Initial implementation of OAuth models, endpoints, and migrations
- Added models for `OAuthClient`, `OAuthState`, and `OAuthAccount`. - Created Pydantic schemas to support OAuth flows, client management, and linked accounts. - Implemented skeleton endpoints for OAuth Provider mode: authorization, token, and revocation. - Updated router imports to include new `/oauth` and `/oauth/provider` routes. - Added Alembic migration script to create OAuth-related database tables. - Enhanced `users` table to allow OAuth-only accounts by making `password_hash` nullable.
This commit is contained in:
403
backend/tests/services/test_oauth_service.py
Normal file
403
backend/tests/services/test_oauth_service.py
Normal file
@@ -0,0 +1,403 @@
|
||||
# tests/services/test_oauth_service.py
|
||||
"""
|
||||
Tests for OAuthService covering authorization URL creation,
|
||||
callback handling, and account management.
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.core.exceptions import AuthenticationError
|
||||
from app.crud.oauth import oauth_account, oauth_state
|
||||
from app.schemas.oauth import OAuthAccountCreate, OAuthStateCreate
|
||||
from app.services.oauth_service import OAUTH_PROVIDERS, OAuthService
|
||||
|
||||
|
||||
class TestGetEnabledProviders:
|
||||
"""Tests for get_enabled_providers method."""
|
||||
|
||||
def test_returns_empty_when_disabled(self):
|
||||
"""Test returns empty providers when OAuth is disabled."""
|
||||
with patch("app.services.oauth_service.settings") as mock_settings:
|
||||
mock_settings.OAUTH_ENABLED = False
|
||||
mock_settings.enabled_oauth_providers = []
|
||||
|
||||
result = OAuthService.get_enabled_providers()
|
||||
|
||||
assert result.enabled is False
|
||||
assert result.providers == []
|
||||
|
||||
def test_returns_configured_providers(self):
|
||||
"""Test returns configured providers when enabled."""
|
||||
with patch("app.services.oauth_service.settings") as mock_settings:
|
||||
mock_settings.OAUTH_ENABLED = True
|
||||
mock_settings.enabled_oauth_providers = ["google", "github"]
|
||||
|
||||
result = OAuthService.get_enabled_providers()
|
||||
|
||||
assert result.enabled is True
|
||||
assert len(result.providers) == 2
|
||||
provider_names = [p.provider for p in result.providers]
|
||||
assert "google" in provider_names
|
||||
assert "github" in provider_names
|
||||
|
||||
def test_filters_unknown_providers(self):
|
||||
"""Test filters out unknown providers from list."""
|
||||
with patch("app.services.oauth_service.settings") as mock_settings:
|
||||
mock_settings.OAUTH_ENABLED = True
|
||||
mock_settings.enabled_oauth_providers = ["google", "unknown_provider"]
|
||||
|
||||
result = OAuthService.get_enabled_providers()
|
||||
|
||||
assert result.enabled is True
|
||||
assert len(result.providers) == 1
|
||||
assert result.providers[0].provider == "google"
|
||||
|
||||
|
||||
class TestGetProviderCredentials:
|
||||
"""Tests for _get_provider_credentials method."""
|
||||
|
||||
def test_returns_google_credentials(self):
|
||||
"""Test returns Google credentials when configured."""
|
||||
with patch("app.services.oauth_service.settings") as mock_settings:
|
||||
mock_settings.OAUTH_GOOGLE_CLIENT_ID = "google_client_id"
|
||||
mock_settings.OAUTH_GOOGLE_CLIENT_SECRET = "google_secret"
|
||||
|
||||
client_id, secret = OAuthService._get_provider_credentials("google")
|
||||
|
||||
assert client_id == "google_client_id"
|
||||
assert secret == "google_secret"
|
||||
|
||||
def test_returns_github_credentials(self):
|
||||
"""Test returns GitHub credentials when configured."""
|
||||
with patch("app.services.oauth_service.settings") as mock_settings:
|
||||
mock_settings.OAUTH_GITHUB_CLIENT_ID = "github_client_id"
|
||||
mock_settings.OAUTH_GITHUB_CLIENT_SECRET = "github_secret"
|
||||
|
||||
client_id, secret = OAuthService._get_provider_credentials("github")
|
||||
|
||||
assert client_id == "github_client_id"
|
||||
assert secret == "github_secret"
|
||||
|
||||
def test_raises_for_unknown_provider(self):
|
||||
"""Test raises error for unknown provider."""
|
||||
with pytest.raises(AuthenticationError, match="Unknown OAuth provider"):
|
||||
OAuthService._get_provider_credentials("unknown")
|
||||
|
||||
def test_raises_when_credentials_not_configured(self):
|
||||
"""Test raises error when credentials are not configured."""
|
||||
with patch("app.services.oauth_service.settings") as mock_settings:
|
||||
mock_settings.OAUTH_GOOGLE_CLIENT_ID = None
|
||||
mock_settings.OAUTH_GOOGLE_CLIENT_SECRET = "secret"
|
||||
|
||||
with pytest.raises(AuthenticationError, match="not configured"):
|
||||
OAuthService._get_provider_credentials("google")
|
||||
|
||||
|
||||
class TestCreateAuthorizationUrl:
|
||||
"""Tests for create_authorization_url method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_when_oauth_disabled(self, async_test_db):
|
||||
"""Test raises error when OAuth is disabled."""
|
||||
_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with patch("app.services.oauth_service.settings") as mock_settings:
|
||||
mock_settings.OAUTH_ENABLED = False
|
||||
|
||||
with pytest.raises(AuthenticationError, match="not enabled"):
|
||||
await OAuthService.create_authorization_url(
|
||||
session,
|
||||
provider="google",
|
||||
redirect_uri="http://localhost:3000/callback",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_for_unknown_provider(self, async_test_db):
|
||||
"""Test raises error for unknown provider."""
|
||||
_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with patch("app.services.oauth_service.settings") as mock_settings:
|
||||
mock_settings.OAUTH_ENABLED = True
|
||||
|
||||
with pytest.raises(AuthenticationError, match="Unknown OAuth provider"):
|
||||
await OAuthService.create_authorization_url(
|
||||
session,
|
||||
provider="unknown",
|
||||
redirect_uri="http://localhost:3000/callback",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_when_provider_not_enabled(self, async_test_db):
|
||||
"""Test raises error when provider is not in enabled list."""
|
||||
_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with patch("app.services.oauth_service.settings") as mock_settings:
|
||||
mock_settings.OAUTH_ENABLED = True
|
||||
mock_settings.enabled_oauth_providers = ["github"] # google not enabled
|
||||
|
||||
with pytest.raises(AuthenticationError, match="not enabled"):
|
||||
await OAuthService.create_authorization_url(
|
||||
session,
|
||||
provider="google",
|
||||
redirect_uri="http://localhost:3000/callback",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_creates_authorization_url_for_google(self, async_test_db):
|
||||
"""Test creates authorization URL for Google with PKCE."""
|
||||
_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with patch("app.services.oauth_service.settings") as mock_settings:
|
||||
mock_settings.OAUTH_ENABLED = True
|
||||
mock_settings.enabled_oauth_providers = ["google"]
|
||||
mock_settings.OAUTH_GOOGLE_CLIENT_ID = "google_client_id"
|
||||
mock_settings.OAUTH_GOOGLE_CLIENT_SECRET = "google_secret"
|
||||
mock_settings.OAUTH_STATE_EXPIRE_MINUTES = 10
|
||||
|
||||
url, state = await OAuthService.create_authorization_url(
|
||||
session,
|
||||
provider="google",
|
||||
redirect_uri="http://localhost:3000/callback",
|
||||
)
|
||||
|
||||
assert url is not None
|
||||
assert "accounts.google.com" in url
|
||||
assert state is not None
|
||||
assert len(state) > 20
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_creates_authorization_url_for_github(self, async_test_db):
|
||||
"""Test creates authorization URL for GitHub."""
|
||||
_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with patch("app.services.oauth_service.settings") as mock_settings:
|
||||
mock_settings.OAUTH_ENABLED = True
|
||||
mock_settings.enabled_oauth_providers = ["github"]
|
||||
mock_settings.OAUTH_GITHUB_CLIENT_ID = "github_client_id"
|
||||
mock_settings.OAUTH_GITHUB_CLIENT_SECRET = "github_secret"
|
||||
mock_settings.OAUTH_STATE_EXPIRE_MINUTES = 10
|
||||
|
||||
url, state = await OAuthService.create_authorization_url(
|
||||
session,
|
||||
provider="github",
|
||||
redirect_uri="http://localhost:3000/callback",
|
||||
)
|
||||
|
||||
assert url is not None
|
||||
assert "github.com/login/oauth/authorize" in url
|
||||
assert state is not None
|
||||
|
||||
|
||||
class TestHandleCallback:
|
||||
"""Tests for handle_callback method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_for_invalid_state(self, async_test_db):
|
||||
"""Test raises error for invalid/expired state."""
|
||||
_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(AuthenticationError, match="Invalid or expired"):
|
||||
await OAuthService.handle_callback(
|
||||
session,
|
||||
code="auth_code",
|
||||
state="invalid_state",
|
||||
redirect_uri="http://localhost:3000/callback",
|
||||
)
|
||||
|
||||
|
||||
class TestUnlinkProvider:
|
||||
"""Tests for unlink_provider method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unlink_with_password_succeeds(self, async_test_db, async_test_user):
|
||||
"""Test unlinking succeeds when user has password."""
|
||||
_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create OAuth account
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
account_data = OAuthAccountCreate(
|
||||
user_id=async_test_user.id,
|
||||
provider="google",
|
||||
provider_user_id="google_123",
|
||||
)
|
||||
await oauth_account.create_account(session, obj_in=account_data)
|
||||
|
||||
# Unlink (user has password)
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Need to get fresh user instance
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models.user import User
|
||||
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
user = result.scalar_one()
|
||||
|
||||
success = await OAuthService.unlink_provider(
|
||||
session, user=user, provider="google"
|
||||
)
|
||||
assert success is True
|
||||
|
||||
# Verify unlinked
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
account = await oauth_account.get_user_account_by_provider(
|
||||
session, user_id=async_test_user.id, provider="google"
|
||||
)
|
||||
assert account is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unlink_not_found_raises(self, async_test_db, async_test_user):
|
||||
"""Test unlinking non-existent provider raises error."""
|
||||
_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models.user import User
|
||||
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
user = result.scalar_one()
|
||||
|
||||
with pytest.raises(AuthenticationError, match="No google account found"):
|
||||
await OAuthService.unlink_provider(
|
||||
session, user=user, provider="google"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unlink_oauth_only_user_blocked(self, async_test_db):
|
||||
"""Test unlinking fails for OAuth-only user with single provider."""
|
||||
_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create OAuth-only user
|
||||
from app.models.user import User
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
oauth_user = User(
|
||||
id=uuid4(),
|
||||
email="oauthonly@example.com",
|
||||
password_hash=None, # No password
|
||||
first_name="OAuth",
|
||||
is_active=True,
|
||||
)
|
||||
session.add(oauth_user)
|
||||
await session.commit()
|
||||
|
||||
# Link single OAuth account
|
||||
account_data = OAuthAccountCreate(
|
||||
user_id=oauth_user.id,
|
||||
provider="google",
|
||||
provider_user_id="google_only",
|
||||
)
|
||||
await oauth_account.create_account(session, obj_in=account_data)
|
||||
|
||||
# Try to unlink
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
from sqlalchemy import select
|
||||
|
||||
result = await session.execute(
|
||||
select(User).where(User.email == "oauthonly@example.com")
|
||||
)
|
||||
user = result.scalar_one()
|
||||
|
||||
with pytest.raises(AuthenticationError, match="Cannot unlink"):
|
||||
await OAuthService.unlink_provider(
|
||||
session, user=user, provider="google"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unlink_with_multiple_providers_succeeds(self, async_test_db):
|
||||
"""Test unlinking succeeds when user has multiple providers."""
|
||||
_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
from app.models.user import User
|
||||
|
||||
# Create OAuth-only user with multiple providers
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
oauth_user = User(
|
||||
id=uuid4(),
|
||||
email="multiauth@example.com",
|
||||
password_hash=None,
|
||||
first_name="Multi",
|
||||
is_active=True,
|
||||
)
|
||||
session.add(oauth_user)
|
||||
await session.commit()
|
||||
|
||||
# Link multiple OAuth accounts
|
||||
for provider in ["google", "github"]:
|
||||
account_data = OAuthAccountCreate(
|
||||
user_id=oauth_user.id,
|
||||
provider=provider,
|
||||
provider_user_id=f"{provider}_user",
|
||||
)
|
||||
await oauth_account.create_account(session, obj_in=account_data)
|
||||
|
||||
# Unlink one provider (should succeed)
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
from sqlalchemy import select
|
||||
|
||||
result = await session.execute(
|
||||
select(User).where(User.email == "multiauth@example.com")
|
||||
)
|
||||
user = result.scalar_one()
|
||||
|
||||
success = await OAuthService.unlink_provider(
|
||||
session, user=user, provider="google"
|
||||
)
|
||||
assert success is True
|
||||
|
||||
|
||||
class TestCleanupExpiredStates:
|
||||
"""Tests for cleanup_expired_states method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_removes_expired_states(self, async_test_db):
|
||||
"""Test cleanup removes expired states."""
|
||||
_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create expired state
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
expired_state = OAuthStateCreate(
|
||||
state="expired_cleanup_test",
|
||||
provider="google",
|
||||
expires_at=datetime.now(UTC) - timedelta(minutes=5),
|
||||
)
|
||||
await oauth_state.create_state(session, obj_in=expired_state)
|
||||
|
||||
# Run cleanup
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await OAuthService.cleanup_expired_states(session)
|
||||
assert count >= 1
|
||||
|
||||
|
||||
class TestProviderConfigs:
|
||||
"""Tests for provider configuration constants."""
|
||||
|
||||
def test_google_provider_config(self):
|
||||
"""Test Google provider configuration is correct."""
|
||||
config = OAUTH_PROVIDERS.get("google")
|
||||
assert config is not None
|
||||
assert config["name"] == "Google"
|
||||
assert "accounts.google.com" in config["authorize_url"]
|
||||
assert config["supports_pkce"] is True
|
||||
|
||||
def test_github_provider_config(self):
|
||||
"""Test GitHub provider configuration is correct."""
|
||||
config = OAUTH_PROVIDERS.get("github")
|
||||
assert config is not None
|
||||
assert config["name"] == "GitHub"
|
||||
assert "github.com" in config["authorize_url"]
|
||||
assert config["supports_pkce"] is False
|
||||
Reference in New Issue
Block a user