- Extended OAuth callback tests to cover various scenarios (e.g., account linking, user creation, inactive users, and token/user info failures). - Added `app/init_db.py` to the excluded files in `pyproject.toml`.
601 lines
23 KiB
Python
601 lines
23 KiB
Python
# tests/crud/test_oauth.py
|
|
"""
|
|
Comprehensive tests for OAuth CRUD operations.
|
|
"""
|
|
|
|
from datetime import UTC, datetime, timedelta
|
|
|
|
import pytest
|
|
|
|
from app.crud.oauth import oauth_account, oauth_client, oauth_state
|
|
from app.schemas.oauth import OAuthAccountCreate, OAuthClientCreate, OAuthStateCreate
|
|
|
|
|
|
class TestOAuthAccountCRUD:
|
|
"""Tests for OAuth account CRUD operations."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_create_account(self, async_test_db, async_test_user):
|
|
"""Test creating an OAuth account link."""
|
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
|
|
|
async with AsyncTestingSessionLocal() as session:
|
|
account_data = OAuthAccountCreate(
|
|
user_id=async_test_user.id,
|
|
provider="google",
|
|
provider_user_id="google_123456",
|
|
provider_email="user@gmail.com",
|
|
)
|
|
account = await oauth_account.create_account(session, obj_in=account_data)
|
|
|
|
assert account is not None
|
|
assert account.provider == "google"
|
|
assert account.provider_user_id == "google_123456"
|
|
assert account.user_id == async_test_user.id
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_create_account_same_provider_twice_fails(
|
|
self, async_test_db, async_test_user
|
|
):
|
|
"""Test creating same OAuth account for same user twice raises error."""
|
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
|
|
|
async with AsyncTestingSessionLocal() as session:
|
|
account_data = OAuthAccountCreate(
|
|
user_id=async_test_user.id,
|
|
provider="google",
|
|
provider_user_id="google_dup_123",
|
|
provider_email="user@gmail.com",
|
|
)
|
|
await oauth_account.create_account(session, obj_in=account_data)
|
|
|
|
# Try to create same account again (same provider + provider_user_id)
|
|
async with AsyncTestingSessionLocal() as session:
|
|
account_data2 = OAuthAccountCreate(
|
|
user_id=async_test_user.id, # Same user
|
|
provider="google",
|
|
provider_user_id="google_dup_123", # Same provider_user_id
|
|
provider_email="user@gmail.com",
|
|
)
|
|
|
|
# SQLite returns different error message than PostgreSQL
|
|
with pytest.raises(
|
|
ValueError, match="(already linked|UNIQUE constraint failed)"
|
|
):
|
|
await oauth_account.create_account(session, obj_in=account_data2)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_get_by_provider_id(self, async_test_db, async_test_user):
|
|
"""Test getting OAuth account by provider and provider user ID."""
|
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
|
|
|
async with AsyncTestingSessionLocal() as session:
|
|
account_data = OAuthAccountCreate(
|
|
user_id=async_test_user.id,
|
|
provider="github",
|
|
provider_user_id="github_789",
|
|
provider_email="user@github.com",
|
|
)
|
|
await oauth_account.create_account(session, obj_in=account_data)
|
|
|
|
async with AsyncTestingSessionLocal() as session:
|
|
result = await oauth_account.get_by_provider_id(
|
|
session,
|
|
provider="github",
|
|
provider_user_id="github_789",
|
|
)
|
|
assert result is not None
|
|
assert result.provider == "github"
|
|
assert result.user is not None # Eager loaded
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_get_by_provider_id_not_found(self, async_test_db):
|
|
"""Test getting non-existent OAuth account returns None."""
|
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
|
|
|
async with AsyncTestingSessionLocal() as session:
|
|
result = await oauth_account.get_by_provider_id(
|
|
session,
|
|
provider="google",
|
|
provider_user_id="nonexistent",
|
|
)
|
|
assert result is None
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_get_user_accounts(self, async_test_db, async_test_user):
|
|
"""Test getting all OAuth accounts for a user."""
|
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
|
|
|
async with AsyncTestingSessionLocal() as session:
|
|
# Create two accounts for the same user
|
|
for provider in ["google", "github"]:
|
|
account_data = OAuthAccountCreate(
|
|
user_id=async_test_user.id,
|
|
provider=provider,
|
|
provider_user_id=f"{provider}_user_123",
|
|
provider_email=f"user@{provider}.com",
|
|
)
|
|
await oauth_account.create_account(session, obj_in=account_data)
|
|
|
|
async with AsyncTestingSessionLocal() as session:
|
|
accounts = await oauth_account.get_user_accounts(
|
|
session, user_id=async_test_user.id
|
|
)
|
|
assert len(accounts) == 2
|
|
providers = {a.provider for a in accounts}
|
|
assert providers == {"google", "github"}
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_get_user_account_by_provider(self, async_test_db, async_test_user):
|
|
"""Test getting specific OAuth account for user and provider."""
|
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
|
|
|
async with AsyncTestingSessionLocal() as session:
|
|
account_data = OAuthAccountCreate(
|
|
user_id=async_test_user.id,
|
|
provider="google",
|
|
provider_user_id="google_specific",
|
|
)
|
|
await oauth_account.create_account(session, obj_in=account_data)
|
|
|
|
async with AsyncTestingSessionLocal() as session:
|
|
result = await oauth_account.get_user_account_by_provider(
|
|
session,
|
|
user_id=async_test_user.id,
|
|
provider="google",
|
|
)
|
|
assert result is not None
|
|
assert result.provider == "google"
|
|
|
|
# Test not found
|
|
result2 = await oauth_account.get_user_account_by_provider(
|
|
session,
|
|
user_id=async_test_user.id,
|
|
provider="github", # Not linked
|
|
)
|
|
assert result2 is None
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_delete_account(self, async_test_db, async_test_user):
|
|
"""Test deleting an OAuth account link."""
|
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
|
|
|
async with AsyncTestingSessionLocal() as session:
|
|
account_data = OAuthAccountCreate(
|
|
user_id=async_test_user.id,
|
|
provider="google",
|
|
provider_user_id="google_to_delete",
|
|
)
|
|
await oauth_account.create_account(session, obj_in=account_data)
|
|
|
|
async with AsyncTestingSessionLocal() as session:
|
|
deleted = await oauth_account.delete_account(
|
|
session,
|
|
user_id=async_test_user.id,
|
|
provider="google",
|
|
)
|
|
assert deleted is True
|
|
|
|
# Verify deletion
|
|
async with AsyncTestingSessionLocal() as session:
|
|
result = await oauth_account.get_user_account_by_provider(
|
|
session,
|
|
user_id=async_test_user.id,
|
|
provider="google",
|
|
)
|
|
assert result is None
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_delete_account_not_found(self, async_test_db, async_test_user):
|
|
"""Test deleting non-existent account returns False."""
|
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
|
|
|
async with AsyncTestingSessionLocal() as session:
|
|
deleted = await oauth_account.delete_account(
|
|
session,
|
|
user_id=async_test_user.id,
|
|
provider="nonexistent",
|
|
)
|
|
assert deleted is False
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_get_by_provider_email(self, async_test_db, async_test_user):
|
|
"""Test getting OAuth account by provider and email."""
|
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
|
|
|
async with AsyncTestingSessionLocal() as session:
|
|
account_data = OAuthAccountCreate(
|
|
user_id=async_test_user.id,
|
|
provider="google",
|
|
provider_user_id="google_email_test",
|
|
provider_email="unique@gmail.com",
|
|
)
|
|
await oauth_account.create_account(session, obj_in=account_data)
|
|
|
|
async with AsyncTestingSessionLocal() as session:
|
|
result = await oauth_account.get_by_provider_email(
|
|
session,
|
|
provider="google",
|
|
email="unique@gmail.com",
|
|
)
|
|
assert result is not None
|
|
assert result.provider_email == "unique@gmail.com"
|
|
|
|
# Test not found
|
|
result2 = await oauth_account.get_by_provider_email(
|
|
session,
|
|
provider="google",
|
|
email="nonexistent@gmail.com",
|
|
)
|
|
assert result2 is None
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_update_tokens(self, async_test_db, async_test_user):
|
|
"""Test updating OAuth tokens."""
|
|
from datetime import UTC, datetime, timedelta
|
|
|
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
|
|
|
async with AsyncTestingSessionLocal() as session:
|
|
account_data = OAuthAccountCreate(
|
|
user_id=async_test_user.id,
|
|
provider="google",
|
|
provider_user_id="google_token_test",
|
|
)
|
|
account = await oauth_account.create_account(session, obj_in=account_data)
|
|
|
|
async with AsyncTestingSessionLocal() as session:
|
|
# Get the account first
|
|
account = await oauth_account.get_by_provider_id(
|
|
session, provider="google", provider_user_id="google_token_test"
|
|
)
|
|
assert account is not None
|
|
|
|
# Update tokens
|
|
new_expires = datetime.now(UTC) + timedelta(hours=1)
|
|
updated = await oauth_account.update_tokens(
|
|
session,
|
|
account=account,
|
|
access_token_encrypted="new_access_token",
|
|
refresh_token_encrypted="new_refresh_token",
|
|
token_expires_at=new_expires,
|
|
)
|
|
|
|
assert updated.access_token_encrypted == "new_access_token"
|
|
assert updated.refresh_token_encrypted == "new_refresh_token"
|
|
|
|
|
|
class TestOAuthStateCRUD:
|
|
"""Tests for OAuth state CRUD operations."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_create_state(self, async_test_db):
|
|
"""Test creating OAuth state."""
|
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
|
|
|
async with AsyncTestingSessionLocal() as session:
|
|
state_data = OAuthStateCreate(
|
|
state="random_state_123",
|
|
code_verifier="pkce_verifier",
|
|
nonce="oidc_nonce",
|
|
provider="google",
|
|
redirect_uri="http://localhost:3000/callback",
|
|
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
|
)
|
|
state = await oauth_state.create_state(session, obj_in=state_data)
|
|
|
|
assert state is not None
|
|
assert state.state == "random_state_123"
|
|
assert state.code_verifier == "pkce_verifier"
|
|
assert state.provider == "google"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_get_and_consume_state(self, async_test_db):
|
|
"""Test getting and consuming OAuth state."""
|
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
|
|
|
async with AsyncTestingSessionLocal() as session:
|
|
state_data = OAuthStateCreate(
|
|
state="consume_state_123",
|
|
provider="github",
|
|
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
|
)
|
|
await oauth_state.create_state(session, obj_in=state_data)
|
|
|
|
# Consume the state
|
|
async with AsyncTestingSessionLocal() as session:
|
|
result = await oauth_state.get_and_consume_state(
|
|
session, state="consume_state_123"
|
|
)
|
|
assert result is not None
|
|
assert result.provider == "github"
|
|
|
|
# Try to consume again - should be None (already consumed)
|
|
async with AsyncTestingSessionLocal() as session:
|
|
result2 = await oauth_state.get_and_consume_state(
|
|
session, state="consume_state_123"
|
|
)
|
|
assert result2 is None
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_get_and_consume_expired_state(self, async_test_db):
|
|
"""Test consuming expired state returns None."""
|
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
|
|
|
async with AsyncTestingSessionLocal() as session:
|
|
# Create expired state
|
|
state_data = OAuthStateCreate(
|
|
state="expired_state_123",
|
|
provider="google",
|
|
expires_at=datetime.now(UTC) - timedelta(minutes=1), # Already expired
|
|
)
|
|
await oauth_state.create_state(session, obj_in=state_data)
|
|
|
|
async with AsyncTestingSessionLocal() as session:
|
|
result = await oauth_state.get_and_consume_state(
|
|
session, state="expired_state_123"
|
|
)
|
|
assert result is None
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_cleanup_expired_states(self, async_test_db):
|
|
"""Test cleaning up expired OAuth states."""
|
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
|
|
|
async with AsyncTestingSessionLocal() as session:
|
|
# Create expired state
|
|
expired_state = OAuthStateCreate(
|
|
state="cleanup_expired",
|
|
provider="google",
|
|
expires_at=datetime.now(UTC) - timedelta(minutes=5),
|
|
)
|
|
await oauth_state.create_state(session, obj_in=expired_state)
|
|
|
|
# Create valid state
|
|
valid_state = OAuthStateCreate(
|
|
state="cleanup_valid",
|
|
provider="google",
|
|
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
|
)
|
|
await oauth_state.create_state(session, obj_in=valid_state)
|
|
|
|
# Cleanup
|
|
async with AsyncTestingSessionLocal() as session:
|
|
count = await oauth_state.cleanup_expired(session)
|
|
assert count == 1
|
|
|
|
# Verify only expired was deleted
|
|
async with AsyncTestingSessionLocal() as session:
|
|
result = await oauth_state.get_and_consume_state(
|
|
session, state="cleanup_valid"
|
|
)
|
|
assert result is not None
|
|
|
|
|
|
class TestOAuthClientCRUD:
|
|
"""Tests for OAuth client CRUD operations (provider mode)."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_create_public_client(self, async_test_db):
|
|
"""Test creating a public OAuth client."""
|
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
|
|
|
async with AsyncTestingSessionLocal() as session:
|
|
client_data = OAuthClientCreate(
|
|
client_name="Test MCP App",
|
|
client_description="A test application",
|
|
redirect_uris=["http://localhost:3000/callback"],
|
|
allowed_scopes=["read:users"],
|
|
client_type="public",
|
|
)
|
|
client, secret = await oauth_client.create_client(
|
|
session, obj_in=client_data
|
|
)
|
|
|
|
assert client is not None
|
|
assert client.client_name == "Test MCP App"
|
|
assert client.client_type == "public"
|
|
assert secret is None # Public clients don't have secrets
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_create_confidential_client(self, async_test_db):
|
|
"""Test creating a confidential OAuth client."""
|
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
|
|
|
async with AsyncTestingSessionLocal() as session:
|
|
client_data = OAuthClientCreate(
|
|
client_name="Confidential App",
|
|
redirect_uris=["http://localhost:8080/callback"],
|
|
allowed_scopes=["read:users", "write:users"],
|
|
client_type="confidential",
|
|
)
|
|
client, secret = await oauth_client.create_client(
|
|
session, obj_in=client_data
|
|
)
|
|
|
|
assert client is not None
|
|
assert client.client_type == "confidential"
|
|
assert secret is not None # Confidential clients have secrets
|
|
assert len(secret) > 20 # Should be a reasonably long secret
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_get_by_client_id(self, async_test_db):
|
|
"""Test getting OAuth client by client_id."""
|
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
|
|
|
created_client_id = None
|
|
async with AsyncTestingSessionLocal() as session:
|
|
client_data = OAuthClientCreate(
|
|
client_name="Lookup Test",
|
|
redirect_uris=["http://localhost:3000/callback"],
|
|
allowed_scopes=["read:users"],
|
|
)
|
|
client, _ = await oauth_client.create_client(session, obj_in=client_data)
|
|
created_client_id = client.client_id
|
|
|
|
async with AsyncTestingSessionLocal() as session:
|
|
result = await oauth_client.get_by_client_id(
|
|
session, client_id=created_client_id
|
|
)
|
|
assert result is not None
|
|
assert result.client_name == "Lookup Test"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_get_inactive_client_not_found(self, async_test_db):
|
|
"""Test getting inactive OAuth client returns None."""
|
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
|
|
|
created_client_id = None
|
|
async with AsyncTestingSessionLocal() as session:
|
|
client_data = OAuthClientCreate(
|
|
client_name="Inactive Client",
|
|
redirect_uris=["http://localhost:3000/callback"],
|
|
allowed_scopes=["read:users"],
|
|
)
|
|
client, _ = await oauth_client.create_client(session, obj_in=client_data)
|
|
created_client_id = client.client_id
|
|
|
|
# Deactivate
|
|
await oauth_client.deactivate_client(session, client_id=created_client_id)
|
|
|
|
async with AsyncTestingSessionLocal() as session:
|
|
result = await oauth_client.get_by_client_id(
|
|
session, client_id=created_client_id
|
|
)
|
|
assert result is None # Inactive clients not returned
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_validate_redirect_uri(self, async_test_db):
|
|
"""Test redirect URI validation."""
|
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
|
|
|
created_client_id = None
|
|
async with AsyncTestingSessionLocal() as session:
|
|
client_data = OAuthClientCreate(
|
|
client_name="URI Test",
|
|
redirect_uris=[
|
|
"http://localhost:3000/callback",
|
|
"http://localhost:8080/oauth",
|
|
],
|
|
allowed_scopes=["read:users"],
|
|
)
|
|
client, _ = await oauth_client.create_client(session, obj_in=client_data)
|
|
created_client_id = client.client_id
|
|
|
|
async with AsyncTestingSessionLocal() as session:
|
|
# Valid URI
|
|
valid = await oauth_client.validate_redirect_uri(
|
|
session,
|
|
client_id=created_client_id,
|
|
redirect_uri="http://localhost:3000/callback",
|
|
)
|
|
assert valid is True
|
|
|
|
# Invalid URI
|
|
invalid = await oauth_client.validate_redirect_uri(
|
|
session,
|
|
client_id=created_client_id,
|
|
redirect_uri="http://evil.com/callback",
|
|
)
|
|
assert invalid is False
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_verify_client_secret(self, async_test_db):
|
|
"""Test client secret verification."""
|
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
|
|
|
created_client_id = None
|
|
created_secret = None
|
|
async with AsyncTestingSessionLocal() as session:
|
|
client_data = OAuthClientCreate(
|
|
client_name="Secret Test",
|
|
redirect_uris=["http://localhost:3000/callback"],
|
|
allowed_scopes=["read:users"],
|
|
client_type="confidential",
|
|
)
|
|
client, secret = await oauth_client.create_client(
|
|
session, obj_in=client_data
|
|
)
|
|
created_client_id = client.client_id
|
|
created_secret = secret
|
|
|
|
async with AsyncTestingSessionLocal() as session:
|
|
# Valid secret
|
|
valid = await oauth_client.verify_client_secret(
|
|
session,
|
|
client_id=created_client_id,
|
|
client_secret=created_secret,
|
|
)
|
|
assert valid is True
|
|
|
|
# Invalid secret
|
|
invalid = await oauth_client.verify_client_secret(
|
|
session,
|
|
client_id=created_client_id,
|
|
client_secret="wrong_secret",
|
|
)
|
|
assert invalid is False
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_deactivate_nonexistent_client(self, async_test_db):
|
|
"""Test deactivating non-existent client returns None."""
|
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
|
|
|
async with AsyncTestingSessionLocal() as session:
|
|
result = await oauth_client.deactivate_client(
|
|
session, client_id="nonexistent_client_id"
|
|
)
|
|
assert result is None
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_validate_redirect_uri_nonexistent_client(self, async_test_db):
|
|
"""Test validate_redirect_uri returns False for non-existent client."""
|
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
|
|
|
async with AsyncTestingSessionLocal() as session:
|
|
valid = await oauth_client.validate_redirect_uri(
|
|
session,
|
|
client_id="nonexistent_client_id",
|
|
redirect_uri="http://localhost:3000/callback",
|
|
)
|
|
assert valid is False
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_verify_secret_nonexistent_client(self, async_test_db):
|
|
"""Test verify_client_secret returns False for non-existent client."""
|
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
|
|
|
async with AsyncTestingSessionLocal() as session:
|
|
valid = await oauth_client.verify_client_secret(
|
|
session,
|
|
client_id="nonexistent_client_id",
|
|
client_secret="any_secret",
|
|
)
|
|
assert valid is False
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_verify_secret_public_client(self, async_test_db):
|
|
"""Test verify_client_secret returns False for public client (no secret)."""
|
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
|
|
|
async with AsyncTestingSessionLocal() as session:
|
|
client_data = OAuthClientCreate(
|
|
client_name="Public Client",
|
|
redirect_uris=["http://localhost:3000/callback"],
|
|
allowed_scopes=["read:users"],
|
|
client_type="public", # Public client - no secret
|
|
)
|
|
client, secret = await oauth_client.create_client(
|
|
session, obj_in=client_data
|
|
)
|
|
assert secret is None
|
|
|
|
async with AsyncTestingSessionLocal() as session:
|
|
# Public clients don't have secrets, so verification should fail
|
|
valid = await oauth_client.verify_client_secret(
|
|
session,
|
|
client_id=client.client_id,
|
|
client_secret="any_secret",
|
|
)
|
|
assert valid is False
|