Files
syndarix/backend/tests/services/test_oauth_service.py
Felipe Cardoso 16ee4e0cb3 Initial implementation of OAuth models, endpoints, and migrations
- Added models for `OAuthClient`, `OAuthState`, and `OAuthAccount`.
- Created Pydantic schemas to support OAuth flows, client management, and linked accounts.
- Implemented skeleton endpoints for OAuth Provider mode: authorization, token, and revocation.
- Updated router imports to include new `/oauth` and `/oauth/provider` routes.
- Added Alembic migration script to create OAuth-related database tables.
- Enhanced `users` table to allow OAuth-only accounts by making `password_hash` nullable.
2025-11-25 00:37:23 +01:00

404 lines
16 KiB
Python

# tests/services/test_oauth_service.py
"""
Tests for OAuthService covering authorization URL creation,
callback handling, and account management.
"""
from datetime import UTC, datetime, timedelta
from unittest.mock import patch
from uuid import uuid4

import pytest
from sqlalchemy import select

from app.core.exceptions import AuthenticationError
from app.crud.oauth import oauth_account, oauth_state
from app.models.user import User
from app.schemas.oauth import OAuthAccountCreate, OAuthStateCreate
from app.services.oauth_service import OAUTH_PROVIDERS, OAuthService
class TestGetEnabledProviders:
    """Tests for ``OAuthService.get_enabled_providers``."""

    # Module attribute replaced by a MagicMock in every test below.
    _SETTINGS_TARGET = "app.services.oauth_service.settings"

    def test_returns_empty_when_disabled(self):
        """With OAuth switched off, the result is disabled with no providers."""
        with patch(
            self._SETTINGS_TARGET,
            OAUTH_ENABLED=False,
            enabled_oauth_providers=[],
        ):
            outcome = OAuthService.get_enabled_providers()

            assert outcome.enabled is False
            assert outcome.providers == []

    def test_returns_configured_providers(self):
        """Each known provider listed in settings is reported when enabled."""
        with patch(
            self._SETTINGS_TARGET,
            OAUTH_ENABLED=True,
            enabled_oauth_providers=["google", "github"],
        ):
            outcome = OAuthService.get_enabled_providers()
            names = {entry.provider for entry in outcome.providers}

            assert outcome.enabled is True
            assert len(outcome.providers) == 2
            assert {"google", "github"} <= names

    def test_filters_unknown_providers(self):
        """Provider names not present in the registry are filtered out."""
        with patch(
            self._SETTINGS_TARGET,
            OAUTH_ENABLED=True,
            enabled_oauth_providers=["google", "unknown_provider"],
        ):
            outcome = OAuthService.get_enabled_providers()

            assert outcome.enabled is True
            assert len(outcome.providers) == 1
            (only_entry,) = outcome.providers
            assert only_entry.provider == "google"
class TestGetProviderCredentials:
    """Tests for the private ``OAuthService._get_provider_credentials`` helper."""

    # Module attribute replaced by a MagicMock where settings matter.
    _SETTINGS_TARGET = "app.services.oauth_service.settings"

    def test_returns_google_credentials(self):
        """Google client id and secret are taken from settings."""
        with patch(
            self._SETTINGS_TARGET,
            OAUTH_GOOGLE_CLIENT_ID="google_client_id",
            OAUTH_GOOGLE_CLIENT_SECRET="google_secret",
        ):
            creds = OAuthService._get_provider_credentials("google")

            assert tuple(creds) == ("google_client_id", "google_secret")

    def test_returns_github_credentials(self):
        """GitHub client id and secret are taken from settings."""
        with patch(
            self._SETTINGS_TARGET,
            OAUTH_GITHUB_CLIENT_ID="github_client_id",
            OAUTH_GITHUB_CLIENT_SECRET="github_secret",
        ):
            creds = OAuthService._get_provider_credentials("github")

            assert tuple(creds) == ("github_client_id", "github_secret")

    def test_raises_for_unknown_provider(self):
        """A provider name with no registry entry is rejected."""
        with pytest.raises(AuthenticationError, match="Unknown OAuth provider"):
            OAuthService._get_provider_credentials("unknown")

    def test_raises_when_credentials_not_configured(self):
        """A missing client id is reported as "not configured"."""
        with (
            patch(
                self._SETTINGS_TARGET,
                OAUTH_GOOGLE_CLIENT_ID=None,
                OAUTH_GOOGLE_CLIENT_SECRET="secret",
            ),
            pytest.raises(AuthenticationError, match="not configured"),
        ):
            OAuthService._get_provider_credentials("google")
class TestCreateAuthorizationUrl:
    """Tests for ``OAuthService.create_authorization_url``."""

    # Module attribute replaced by a MagicMock in every test below.
    _SETTINGS_TARGET = "app.services.oauth_service.settings"
    # Redirect URI passed to every authorization request in this class.
    _REDIRECT_URI = "http://localhost:3000/callback"

    @classmethod
    async def _authorize(cls, session, provider):
        """Request an authorization URL for *provider* using the shared redirect URI."""
        return await OAuthService.create_authorization_url(
            session,
            provider=provider,
            redirect_uri=cls._REDIRECT_URI,
        )

    @pytest.mark.asyncio
    async def test_raises_when_oauth_disabled(self, async_test_db):
        """Authorization is refused while OAuth is globally disabled."""
        _engine, session_factory = async_test_db
        async with session_factory() as session:
            with (
                patch(self._SETTINGS_TARGET, OAUTH_ENABLED=False),
                pytest.raises(AuthenticationError, match="not enabled"),
            ):
                await self._authorize(session, "google")

    @pytest.mark.asyncio
    async def test_raises_for_unknown_provider(self, async_test_db):
        """A provider name with no registry entry is rejected."""
        _engine, session_factory = async_test_db
        async with session_factory() as session:
            with (
                patch(self._SETTINGS_TARGET, OAUTH_ENABLED=True),
                pytest.raises(AuthenticationError, match="Unknown OAuth provider"),
            ):
                await self._authorize(session, "unknown")

    @pytest.mark.asyncio
    async def test_raises_when_provider_not_enabled(self, async_test_db):
        """A known provider missing from ``enabled_oauth_providers`` is rejected."""
        _engine, session_factory = async_test_db
        async with session_factory() as session:
            with (
                patch(
                    self._SETTINGS_TARGET,
                    OAUTH_ENABLED=True,
                    enabled_oauth_providers=["github"],  # google deliberately absent
                ),
                pytest.raises(AuthenticationError, match="not enabled"),
            ):
                await self._authorize(session, "google")

    @pytest.mark.asyncio
    async def test_creates_authorization_url_for_google(self, async_test_db):
        """Google yields an accounts.google.com URL and a state token.

        Only the host and the state length are checked; the PKCE query
        parameters themselves are not asserted here.
        """
        _engine, session_factory = async_test_db
        google_settings = {
            "OAUTH_ENABLED": True,
            "enabled_oauth_providers": ["google"],
            "OAUTH_GOOGLE_CLIENT_ID": "google_client_id",
            "OAUTH_GOOGLE_CLIENT_SECRET": "google_secret",
            "OAUTH_STATE_EXPIRE_MINUTES": 10,
        }
        async with session_factory() as session:
            with patch(self._SETTINGS_TARGET, **google_settings):
                url, state = await self._authorize(session, "google")

                assert url is not None
                assert "accounts.google.com" in url
                assert state is not None
                assert len(state) > 20

    @pytest.mark.asyncio
    async def test_creates_authorization_url_for_github(self, async_test_db):
        """GitHub yields a github.com/login/oauth/authorize URL and a state token."""
        _engine, session_factory = async_test_db
        github_settings = {
            "OAUTH_ENABLED": True,
            "enabled_oauth_providers": ["github"],
            "OAUTH_GITHUB_CLIENT_ID": "github_client_id",
            "OAUTH_GITHUB_CLIENT_SECRET": "github_secret",
            "OAUTH_STATE_EXPIRE_MINUTES": 10,
        }
        async with session_factory() as session:
            with patch(self._SETTINGS_TARGET, **github_settings):
                url, state = await self._authorize(session, "github")

                assert url is not None
                assert "github.com/login/oauth/authorize" in url
                assert state is not None
class TestHandleCallback:
    """Tests for ``OAuthService.handle_callback``."""

    @pytest.mark.asyncio
    async def test_raises_for_invalid_state(self, async_test_db):
        """A ``state`` value that was never issued is rejected as invalid/expired."""
        _engine, session_factory = async_test_db
        callback_kwargs = {
            "code": "auth_code",
            "state": "invalid_state",
            "redirect_uri": "http://localhost:3000/callback",
        }
        async with session_factory() as session:
            with pytest.raises(AuthenticationError, match="Invalid or expired"):
                await OAuthService.handle_callback(session, **callback_kwargs)
class TestUnlinkProvider:
    """Tests for ``OAuthService.unlink_provider``.

    Each test reloads the ``User`` in the same session that is handed to the
    service, so the instance is bound to that session.
    """

    @staticmethod
    async def _load_user(session, user_id):
        """Return the ``User`` with primary key *user_id*, bound to *session*."""
        result = await session.execute(select(User).where(User.id == user_id))
        return result.scalar_one()

    @staticmethod
    async def _create_oauth_only_user(session, *, user_id, email, first_name):
        """Insert and commit an active user that has no password hash."""
        session.add(
            User(
                id=user_id,
                email=email,
                password_hash=None,  # no password: OAuth is the only login method
                first_name=first_name,
                is_active=True,
            )
        )
        await session.commit()

    @staticmethod
    async def _link(session, *, user_id, provider, provider_user_id):
        """Create an OAuth account linking *user_id* to *provider*."""
        await oauth_account.create_account(
            session,
            obj_in=OAuthAccountCreate(
                user_id=user_id,
                provider=provider,
                provider_user_id=provider_user_id,
            ),
        )

    @staticmethod
    async def _linked_account(session_factory, user_id, provider):
        """Look up the user's *provider* account in a fresh session (or ``None``).

        Callers should only test the result against ``None``; the session is
        closed before the object is returned.
        """
        async with session_factory() as session:
            return await oauth_account.get_user_account_by_provider(
                session, user_id=user_id, provider=provider
            )

    @pytest.mark.asyncio
    async def test_unlink_with_password_succeeds(self, async_test_db, async_test_user):
        """Unlinking succeeds when the user also has a password."""
        _engine, session_factory = async_test_db
        user_id = async_test_user.id

        async with session_factory() as session:
            await self._link(
                session,
                user_id=user_id,
                provider="google",
                provider_user_id="google_123",
            )

        async with session_factory() as session:
            user = await self._load_user(session, user_id)
            success = await OAuthService.unlink_provider(
                session, user=user, provider="google"
            )
        assert success is True

        assert await self._linked_account(session_factory, user_id, "google") is None

    @pytest.mark.asyncio
    async def test_unlink_not_found_raises(self, async_test_db, async_test_user):
        """Unlinking a provider the user never linked raises."""
        _engine, session_factory = async_test_db

        async with session_factory() as session:
            user = await self._load_user(session, async_test_user.id)
            with pytest.raises(AuthenticationError, match="No google account found"):
                await OAuthService.unlink_provider(
                    session, user=user, provider="google"
                )

    @pytest.mark.asyncio
    async def test_unlink_oauth_only_user_blocked(self, async_test_db):
        """An OAuth-only user cannot unlink their single provider."""
        _engine, session_factory = async_test_db
        # Generated up front so the id is never read from an ORM instance
        # after commit (which may have expired its attributes).
        user_id = uuid4()

        async with session_factory() as session:
            await self._create_oauth_only_user(
                session,
                user_id=user_id,
                email="oauthonly@example.com",
                first_name="OAuth",
            )
            await self._link(
                session,
                user_id=user_id,
                provider="google",
                provider_user_id="google_only",
            )

        async with session_factory() as session:
            user = await self._load_user(session, user_id)
            with pytest.raises(AuthenticationError, match="Cannot unlink"):
                await OAuthService.unlink_provider(
                    session, user=user, provider="google"
                )

        # The refused unlink must leave the user's only login method intact.
        assert (
            await self._linked_account(session_factory, user_id, "google")
            is not None
        )

    @pytest.mark.asyncio
    async def test_unlink_with_multiple_providers_succeeds(self, async_test_db):
        """An OAuth-only user with two providers can unlink one of them."""
        _engine, session_factory = async_test_db
        user_id = uuid4()

        async with session_factory() as session:
            await self._create_oauth_only_user(
                session,
                user_id=user_id,
                email="multiauth@example.com",
                first_name="Multi",
            )
            for provider in ("google", "github"):
                await self._link(
                    session,
                    user_id=user_id,
                    provider=provider,
                    provider_user_id=f"{provider}_user",
                )

        async with session_factory() as session:
            user = await self._load_user(session, user_id)
            success = await OAuthService.unlink_provider(
                session, user=user, provider="google"
            )
        assert success is True

        # Only the requested provider is removed; the other link survives.
        assert await self._linked_account(session_factory, user_id, "google") is None
        assert (
            await self._linked_account(session_factory, user_id, "github")
            is not None
        )
class TestCleanupExpiredStates:
    """Tests for ``OAuthService.cleanup_expired_states``."""

    @pytest.mark.asyncio
    async def test_cleanup_removes_expired_states(self, async_test_db):
        """A state whose expiry lies in the past is counted by the cleanup."""
        _engine, session_factory = async_test_db
        five_minutes_ago = datetime.now(UTC) - timedelta(minutes=5)

        # Seed one already-expired state in its own session.
        async with session_factory() as session:
            await oauth_state.create_state(
                session,
                obj_in=OAuthStateCreate(
                    state="expired_cleanup_test",
                    provider="google",
                    expires_at=five_minutes_ago,
                ),
            )

        async with session_factory() as session:
            removed = await OAuthService.cleanup_expired_states(session)

            assert removed >= 1
class TestProviderConfigs:
    """Sanity checks on the static ``OAUTH_PROVIDERS`` registry."""

    @staticmethod
    def _lookup(key):
        """Fetch the registry entry for *key*, failing the test if it is missing."""
        entry = OAUTH_PROVIDERS.get(key)
        assert entry is not None
        return entry

    def test_google_provider_config(self):
        """Google is named "Google", authorizes on accounts.google.com, uses PKCE."""
        google = self._lookup("google")

        assert google["name"] == "Google"
        assert "accounts.google.com" in google["authorize_url"]
        assert google["supports_pkce"] is True

    def test_github_provider_config(self):
        """GitHub is named "GitHub", authorizes on github.com, has no PKCE."""
        github = self._lookup("github")

        assert github["name"] == "GitHub"
        assert "github.com" in github["authorize_url"]
        assert github["supports_pkce"] is False