Files
fast-next-template/backend/tests/crud/test_oauth.py
Felipe Cardoso 16ee4e0cb3 Initial implementation of OAuth models, endpoints, and migrations
- Added models for `OAuthClient`, `OAuthState`, and `OAuthAccount`.
- Created Pydantic schemas to support OAuth flows, client management, and linked accounts.
- Implemented skeleton endpoints for OAuth Provider mode: authorization, token, and revocation.
- Updated router imports to include new `/oauth` and `/oauth/provider` routes.
- Added Alembic migration script to create OAuth-related database tables.
- Enhanced `users` table to allow OAuth-only accounts by making `password_hash` nullable.
2025-11-25 00:37:23 +01:00

538 lines
21 KiB
Python

# tests/crud/test_oauth.py
"""
Comprehensive tests for OAuth CRUD operations.
"""
from datetime import UTC, datetime, timedelta
import pytest
from app.crud.oauth import oauth_account, oauth_client, oauth_state
from app.schemas.oauth import OAuthAccountCreate, OAuthClientCreate, OAuthStateCreate
class TestOAuthAccountCRUD:
    """Tests for OAuth account CRUD operations.

    Most tests write in one session and read back in a fresh session, so the
    assertions observe committed state rather than objects cached in the
    writing session.
    """

    @pytest.mark.asyncio
    async def test_create_account(self, async_test_db, async_test_user):
        """Test creating an OAuth account link."""
        _test_engine, AsyncTestingSessionLocal = async_test_db
        async with AsyncTestingSessionLocal() as session:
            account_data = OAuthAccountCreate(
                user_id=async_test_user.id,
                provider="google",
                provider_user_id="google_123456",
                provider_email="user@gmail.com",
            )
            account = await oauth_account.create_account(session, obj_in=account_data)
            assert account is not None
            assert account.provider == "google"
            assert account.provider_user_id == "google_123456"
            assert account.user_id == async_test_user.id

    @pytest.mark.asyncio
    async def test_create_account_same_provider_twice_fails(
        self, async_test_db, async_test_user
    ):
        """Test creating same OAuth account for same user twice raises error."""
        _test_engine, AsyncTestingSessionLocal = async_test_db
        async with AsyncTestingSessionLocal() as session:
            account_data = OAuthAccountCreate(
                user_id=async_test_user.id,
                provider="google",
                provider_user_id="google_dup_123",
                provider_email="user@gmail.com",
            )
            await oauth_account.create_account(session, obj_in=account_data)

        # Try to create same account again (same provider + provider_user_id)
        async with AsyncTestingSessionLocal() as session:
            account_data2 = OAuthAccountCreate(
                user_id=async_test_user.id,  # Same user
                provider="google",
                provider_user_id="google_dup_123",  # Same provider_user_id
                provider_email="user@gmail.com",
            )
            # SQLite returns different error message than PostgreSQL, so the
            # pattern accepts either wording.
            with pytest.raises(
                ValueError, match="(already linked|UNIQUE constraint failed)"
            ):
                await oauth_account.create_account(session, obj_in=account_data2)

    @pytest.mark.asyncio
    async def test_get_by_provider_id(self, async_test_db, async_test_user):
        """Test getting OAuth account by provider and provider user ID."""
        _test_engine, AsyncTestingSessionLocal = async_test_db
        async with AsyncTestingSessionLocal() as session:
            account_data = OAuthAccountCreate(
                user_id=async_test_user.id,
                provider="github",
                provider_user_id="github_789",
                provider_email="user@github.com",
            )
            await oauth_account.create_account(session, obj_in=account_data)

        async with AsyncTestingSessionLocal() as session:
            result = await oauth_account.get_by_provider_id(
                session,
                provider="github",
                provider_user_id="github_789",
            )
            assert result is not None
            assert result.provider == "github"
            assert result.user is not None  # Eager loaded

    @pytest.mark.asyncio
    async def test_get_by_provider_id_not_found(self, async_test_db):
        """Test getting non-existent OAuth account returns None."""
        _test_engine, AsyncTestingSessionLocal = async_test_db
        async with AsyncTestingSessionLocal() as session:
            result = await oauth_account.get_by_provider_id(
                session,
                provider="google",
                provider_user_id="nonexistent",
            )
            assert result is None

    @pytest.mark.asyncio
    async def test_get_user_accounts(self, async_test_db, async_test_user):
        """Test getting all OAuth accounts for a user."""
        _test_engine, AsyncTestingSessionLocal = async_test_db
        async with AsyncTestingSessionLocal() as session:
            # Create two accounts for the same user
            for provider in ["google", "github"]:
                account_data = OAuthAccountCreate(
                    user_id=async_test_user.id,
                    provider=provider,
                    provider_user_id=f"{provider}_user_123",
                    provider_email=f"user@{provider}.com",
                )
                await oauth_account.create_account(session, obj_in=account_data)

        async with AsyncTestingSessionLocal() as session:
            accounts = await oauth_account.get_user_accounts(
                session, user_id=async_test_user.id
            )
            assert len(accounts) == 2
            providers = {a.provider for a in accounts}
            assert providers == {"google", "github"}

    @pytest.mark.asyncio
    async def test_get_user_account_by_provider(self, async_test_db, async_test_user):
        """Test getting specific OAuth account for user and provider."""
        _test_engine, AsyncTestingSessionLocal = async_test_db
        async with AsyncTestingSessionLocal() as session:
            account_data = OAuthAccountCreate(
                user_id=async_test_user.id,
                provider="google",
                provider_user_id="google_specific",
            )
            await oauth_account.create_account(session, obj_in=account_data)

        async with AsyncTestingSessionLocal() as session:
            result = await oauth_account.get_user_account_by_provider(
                session,
                user_id=async_test_user.id,
                provider="google",
            )
            assert result is not None
            assert result.provider == "google"

            # Test not found
            result2 = await oauth_account.get_user_account_by_provider(
                session,
                user_id=async_test_user.id,
                provider="github",  # Not linked
            )
            assert result2 is None

    @pytest.mark.asyncio
    async def test_delete_account(self, async_test_db, async_test_user):
        """Test deleting an OAuth account link."""
        _test_engine, AsyncTestingSessionLocal = async_test_db
        async with AsyncTestingSessionLocal() as session:
            account_data = OAuthAccountCreate(
                user_id=async_test_user.id,
                provider="google",
                provider_user_id="google_to_delete",
            )
            await oauth_account.create_account(session, obj_in=account_data)

        async with AsyncTestingSessionLocal() as session:
            deleted = await oauth_account.delete_account(
                session,
                user_id=async_test_user.id,
                provider="google",
            )
            assert deleted is True

        # Verify deletion from a fresh session so it reflects committed state.
        async with AsyncTestingSessionLocal() as session:
            result = await oauth_account.get_user_account_by_provider(
                session,
                user_id=async_test_user.id,
                provider="google",
            )
            assert result is None

    @pytest.mark.asyncio
    async def test_delete_account_not_found(self, async_test_db, async_test_user):
        """Test deleting non-existent account returns False."""
        _test_engine, AsyncTestingSessionLocal = async_test_db
        async with AsyncTestingSessionLocal() as session:
            deleted = await oauth_account.delete_account(
                session,
                user_id=async_test_user.id,
                provider="nonexistent",
            )
            assert deleted is False

    @pytest.mark.asyncio
    async def test_get_by_provider_email(self, async_test_db, async_test_user):
        """Test getting OAuth account by provider and email."""
        _test_engine, AsyncTestingSessionLocal = async_test_db
        async with AsyncTestingSessionLocal() as session:
            account_data = OAuthAccountCreate(
                user_id=async_test_user.id,
                provider="google",
                provider_user_id="google_email_test",
                provider_email="unique@gmail.com",
            )
            await oauth_account.create_account(session, obj_in=account_data)

        async with AsyncTestingSessionLocal() as session:
            result = await oauth_account.get_by_provider_email(
                session,
                provider="google",
                email="unique@gmail.com",
            )
            assert result is not None
            assert result.provider_email == "unique@gmail.com"

            # Test not found
            result2 = await oauth_account.get_by_provider_email(
                session,
                provider="google",
                email="nonexistent@gmail.com",
            )
            assert result2 is None

    @pytest.mark.asyncio
    async def test_update_tokens(self, async_test_db, async_test_user):
        """Test updating OAuth tokens.

        Creates an account without tokens, reloads it in a fresh session, then
        calls ``update_tokens`` with new access/refresh token values and an
        expiry one hour ahead, and checks that all three were applied.
        """
        # datetime/UTC/timedelta come from the module-level import; the former
        # function-local re-import only shadowed the same names.
        _test_engine, AsyncTestingSessionLocal = async_test_db
        async with AsyncTestingSessionLocal() as session:
            account_data = OAuthAccountCreate(
                user_id=async_test_user.id,
                provider="google",
                provider_user_id="google_token_test",
            )
            await oauth_account.create_account(session, obj_in=account_data)

        async with AsyncTestingSessionLocal() as session:
            # Get the account first, from a fresh session
            account = await oauth_account.get_by_provider_id(
                session, provider="google", provider_user_id="google_token_test"
            )
            assert account is not None

            # Update tokens
            new_expires = datetime.now(UTC) + timedelta(hours=1)
            updated = await oauth_account.update_tokens(
                session,
                account=account,
                access_token_encrypted="new_access_token",
                refresh_token_encrypted="new_refresh_token",
                token_expires_at=new_expires,
            )
            assert updated.access_token_encrypted == "new_access_token"
            assert updated.refresh_token_encrypted == "new_refresh_token"
            # Only presence is checked: the test backend may be SQLite (see the
            # duplicate-account test), which can return naive datetimes, so an
            # equality check against the aware ``new_expires`` would be
            # backend-dependent.
            assert updated.token_expires_at is not None
class TestOAuthStateCRUD:
    """Tests for OAuth state CRUD operations.

    States are written in one session and read/consumed in another, so every
    assertion observes committed state.
    """

    @pytest.mark.asyncio
    async def test_create_state(self, async_test_db):
        """Test creating OAuth state with PKCE verifier, nonce and redirect URI set."""
        _test_engine, AsyncTestingSessionLocal = async_test_db
        async with AsyncTestingSessionLocal() as session:
            state_data = OAuthStateCreate(
                state="random_state_123",
                code_verifier="pkce_verifier",
                nonce="oidc_nonce",
                provider="google",
                redirect_uri="http://localhost:3000/callback",
                expires_at=datetime.now(UTC) + timedelta(minutes=10),
            )
            state = await oauth_state.create_state(session, obj_in=state_data)
            assert state is not None
            assert state.state == "random_state_123"
            assert state.code_verifier == "pkce_verifier"
            # nonce and redirect_uri were supplied above; check they were stored too.
            assert state.nonce == "oidc_nonce"
            assert state.provider == "google"
            assert state.redirect_uri == "http://localhost:3000/callback"

    @pytest.mark.asyncio
    async def test_get_and_consume_state(self, async_test_db):
        """Test getting and consuming OAuth state."""
        _test_engine, AsyncTestingSessionLocal = async_test_db
        async with AsyncTestingSessionLocal() as session:
            state_data = OAuthStateCreate(
                state="consume_state_123",
                provider="github",
                expires_at=datetime.now(UTC) + timedelta(minutes=10),
            )
            await oauth_state.create_state(session, obj_in=state_data)

        # Consume the state
        async with AsyncTestingSessionLocal() as session:
            result = await oauth_state.get_and_consume_state(
                session, state="consume_state_123"
            )
            assert result is not None
            assert result.provider == "github"

        # Try to consume again - should be None (already consumed)
        async with AsyncTestingSessionLocal() as session:
            result2 = await oauth_state.get_and_consume_state(
                session, state="consume_state_123"
            )
            assert result2 is None

    @pytest.mark.asyncio
    async def test_get_and_consume_expired_state(self, async_test_db):
        """Test consuming expired state returns None."""
        _test_engine, AsyncTestingSessionLocal = async_test_db
        async with AsyncTestingSessionLocal() as session:
            # Create expired state
            state_data = OAuthStateCreate(
                state="expired_state_123",
                provider="google",
                expires_at=datetime.now(UTC) - timedelta(minutes=1),  # Already expired
            )
            await oauth_state.create_state(session, obj_in=state_data)

        async with AsyncTestingSessionLocal() as session:
            result = await oauth_state.get_and_consume_state(
                session, state="expired_state_123"
            )
            assert result is None

    @pytest.mark.asyncio
    async def test_cleanup_expired_states(self, async_test_db):
        """Test cleaning up expired OAuth states.

        One expired and one valid state are created. Cleanup must remove exactly
        the expired one, and the removal must persist, while the valid state
        stays consumable.
        """
        _test_engine, AsyncTestingSessionLocal = async_test_db
        async with AsyncTestingSessionLocal() as session:
            # Create expired state
            expired_state = OAuthStateCreate(
                state="cleanup_expired",
                provider="google",
                expires_at=datetime.now(UTC) - timedelta(minutes=5),
            )
            await oauth_state.create_state(session, obj_in=expired_state)

            # Create valid state
            valid_state = OAuthStateCreate(
                state="cleanup_valid",
                provider="google",
                expires_at=datetime.now(UTC) + timedelta(minutes=10),
            )
            await oauth_state.create_state(session, obj_in=valid_state)

        # Cleanup
        # NOTE(review): count == 1 presumes async_test_db gives each test an
        # isolated database (no expired rows left by other tests) — confirm.
        async with AsyncTestingSessionLocal() as session:
            count = await oauth_state.cleanup_expired(session)
            assert count == 1

        # The expired row is really gone: get_and_consume_state returns None for
        # expired states either way, so a second sweep is the visible proof
        # that the deletion was persisted.
        async with AsyncTestingSessionLocal() as session:
            assert await oauth_state.cleanup_expired(session) == 0

        # The valid state survived cleanup and can still be consumed.
        async with AsyncTestingSessionLocal() as session:
            result = await oauth_state.get_and_consume_state(
                session, state="cleanup_valid"
            )
            assert result is not None
class TestOAuthClientCRUD:
    """CRUD tests for registered OAuth clients (the app acting as a provider)."""

    @staticmethod
    async def _register(session_factory, **fields):
        """Create one client inside a dedicated session.

        ``fields`` are forwarded verbatim to ``OAuthClientCreate``. Returns the
        new client's ``client_id`` together with the secret value that
        ``create_client`` hands back alongside the client.
        """
        async with session_factory() as session:
            created, secret = await oauth_client.create_client(
                session, obj_in=OAuthClientCreate(**fields)
            )
            return created.client_id, secret

    @pytest.mark.asyncio
    async def test_create_public_client(self, async_test_db):
        """A public client is created and no secret is issued for it."""
        _engine, session_factory = async_test_db
        payload = OAuthClientCreate(
            client_name="Test MCP App",
            client_description="A test application",
            redirect_uris=["http://localhost:3000/callback"],
            allowed_scopes=["read:users"],
            client_type="public",
        )
        async with session_factory() as session:
            created, issued_secret = await oauth_client.create_client(
                session, obj_in=payload
            )
            assert created is not None
            assert created.client_name == "Test MCP App"
            assert created.client_type == "public"
            # Public clients cannot keep a secret, so none is returned.
            assert issued_secret is None

    @pytest.mark.asyncio
    async def test_create_confidential_client(self, async_test_db):
        """A confidential client is created together with a non-trivial secret."""
        _engine, session_factory = async_test_db
        payload = OAuthClientCreate(
            client_name="Confidential App",
            redirect_uris=["http://localhost:8080/callback"],
            allowed_scopes=["read:users", "write:users"],
            client_type="confidential",
        )
        async with session_factory() as session:
            created, issued_secret = await oauth_client.create_client(
                session, obj_in=payload
            )
            assert created is not None
            assert created.client_type == "confidential"
            assert issued_secret is not None
            # A short secret would be guessable; require a reasonable length.
            assert len(issued_secret) > 20

    @pytest.mark.asyncio
    async def test_get_by_client_id(self, async_test_db):
        """A freshly registered client can be looked up by its client_id."""
        _engine, session_factory = async_test_db
        client_id, _ = await self._register(
            session_factory,
            client_name="Lookup Test",
            redirect_uris=["http://localhost:3000/callback"],
            allowed_scopes=["read:users"],
        )

        async with session_factory() as session:
            found = await oauth_client.get_by_client_id(session, client_id=client_id)
            assert found is not None
            assert found.client_name == "Lookup Test"

    @pytest.mark.asyncio
    async def test_get_inactive_client_not_found(self, async_test_db):
        """A deactivated client is hidden from lookups by client_id."""
        _engine, session_factory = async_test_db
        async with session_factory() as session:
            created, _ = await oauth_client.create_client(
                session,
                obj_in=OAuthClientCreate(
                    client_name="Inactive Client",
                    redirect_uris=["http://localhost:3000/callback"],
                    allowed_scopes=["read:users"],
                ),
            )
            inactive_id = created.client_id
            # Deactivate within the same session that created it.
            await oauth_client.deactivate_client(session, client_id=inactive_id)

        async with session_factory() as session:
            found = await oauth_client.get_by_client_id(session, client_id=inactive_id)
            assert found is None  # Inactive clients are not returned

    @pytest.mark.asyncio
    async def test_validate_redirect_uri(self, async_test_db):
        """Only redirect URIs registered for the client are accepted."""
        _engine, session_factory = async_test_db
        client_id, _ = await self._register(
            session_factory,
            client_name="URI Test",
            redirect_uris=[
                "http://localhost:3000/callback",
                "http://localhost:8080/oauth",
            ],
            allowed_scopes=["read:users"],
        )

        async with session_factory() as session:
            registered_ok = await oauth_client.validate_redirect_uri(
                session,
                client_id=client_id,
                redirect_uri="http://localhost:3000/callback",
            )
            assert registered_ok is True

            foreign_ok = await oauth_client.validate_redirect_uri(
                session,
                client_id=client_id,
                redirect_uri="http://evil.com/callback",
            )
            assert foreign_ok is False

    @pytest.mark.asyncio
    async def test_verify_client_secret(self, async_test_db):
        """The issued secret verifies; any other value is rejected."""
        _engine, session_factory = async_test_db
        client_id, plain_secret = await self._register(
            session_factory,
            client_name="Secret Test",
            redirect_uris=["http://localhost:3000/callback"],
            allowed_scopes=["read:users"],
            client_type="confidential",
        )

        async with session_factory() as session:
            accepted = await oauth_client.verify_client_secret(
                session,
                client_id=client_id,
                client_secret=plain_secret,
            )
            assert accepted is True

            rejected = await oauth_client.verify_client_secret(
                session,
                client_id=client_id,
                client_secret="wrong_secret",
            )
            assert rejected is False