Files
fast-next-template/backend/tests/services/test_oauth_provider_service.py
Felipe Cardoso 0ea428b718 Refactor tests for improved readability and fixture consistency
- Reformatted headers in E2E tests to improve readability and ensure consistent style.
- Updated confidential client fixture to use bcrypt for secret hashing, enhancing security and testing backward compatibility with legacy SHA-256 hashes.
- Added new test cases for PKCE verification, rejecting insecure 'plain' methods, and improved error handling.
- Refined session workflows and user agent handling in E2E tests for session management.
- Consolidated schema operation tests and fixed minor formatting inconsistencies.
2025-11-26 00:13:53 +01:00

773 lines
28 KiB
Python

# tests/services/test_oauth_provider_service.py
"""
Tests for OAuth Provider Service (Authorization Server mode for MCP).
Covers:
- Authorization code creation and exchange
- Token generation, refresh, and revocation
- PKCE verification
- Token introspection (RFC 7662)
- Consent management
- Error handling
"""
import base64
import hashlib
import secrets
from unittest.mock import patch
from uuid import uuid4
import pytest
import pytest_asyncio
from app.models.oauth_client import OAuthClient
from app.models.user import User
from app.services import oauth_provider_service as service
from app.utils.test_utils import setup_async_test_db, teardown_async_test_db
@pytest_asyncio.fixture(scope="function")
async def db():
    """Yield an async DB session bound to a freshly created test database.

    The database is built by ``setup_async_test_db`` before the test runs and
    dropped via ``teardown_async_test_db`` after the session has been closed.
    """
    engine, session_factory = await setup_async_test_db()
    async with session_factory() as active_session:
        yield active_session
    # The ``async with`` above has already closed the session, so the engine
    # can be disposed of safely.
    await teardown_async_test_db(engine)
@pytest_asyncio.fixture
async def test_user(db):
    """Persist and return an active, non-superuser ``User``.

    ``password_hash`` is only a bcrypt-prefixed placeholder; the tests in this
    module never log in with a password.
    """
    user_fields = {
        "id": uuid4(),
        "email": "testuser@example.com",
        "password_hash": "$2b$12$test",
        "first_name": "Test",
        "last_name": "User",
        "is_active": True,
        "is_superuser": False,
    }
    new_user = User(**user_fields)
    db.add(new_user)
    await db.commit()
    await db.refresh(new_user)
    return new_user
@pytest_asyncio.fixture
async def public_client(db):
    """Persist and return an active *public* OAuth client.

    The client has no secret and a single localhost callback. It may request
    ``openid``, ``profile``, ``email`` and ``read:users``. Authorization-code
    flows for it must use PKCE.
    """
    client_fields = {
        "id": uuid4(),
        "client_id": "test_public_client",
        "client_name": "Test Public Client",
        "client_type": "public",
        "redirect_uris": ["http://localhost:3000/callback"],
        "allowed_scopes": ["openid", "profile", "email", "read:users"],
        "is_active": True,
    }
    client_record = OAuthClient(**client_fields)
    db.add(client_record)
    await db.commit()
    await db.refresh(client_record)
    return client_record
@pytest_asyncio.fixture
async def confidential_client(db):
    """Persist a confidential OAuth client whose secret is hashed with bcrypt.

    Returns:
        tuple: ``(client, plaintext_secret)`` so tests can authenticate with
        the original secret.
    """
    # Imported lazily, exactly as before, so app.core.auth is only loaded
    # when this fixture is used.
    from app.core.auth import get_password_hash

    plaintext_secret = "test_client_secret"
    # New clients store their secret via the password hasher (bcrypt) rather
    # than a bare SHA-256 digest.
    hashed_secret = get_password_hash(plaintext_secret)
    client_record = OAuthClient(
        id=uuid4(),
        client_id="test_confidential_client",
        client_name="Test Confidential Client",
        client_type="confidential",
        client_secret_hash=hashed_secret,
        redirect_uris=["http://localhost:3000/callback"],
        allowed_scopes=["openid", "profile", "email"],
        is_active=True,
    )
    db.add(client_record)
    await db.commit()
    await db.refresh(client_record)
    return client_record, plaintext_secret
@pytest_asyncio.fixture
async def confidential_client_legacy_hash(db):
    """Persist a confidential client whose secret uses the legacy SHA-256 hash.

    This covers backward compatibility with older clients, which stored an
    unsalted SHA-256 hex digest of the secret instead of a bcrypt hash.

    Returns:
        tuple: ``(client, plaintext_secret)``.
    """
    plaintext_secret = "test_legacy_secret"
    legacy_digest = hashlib.sha256(plaintext_secret.encode()).hexdigest()
    legacy_client = OAuthClient(
        id=uuid4(),
        client_id="test_legacy_client",
        client_name="Test Legacy Client",
        client_type="confidential",
        client_secret_hash=legacy_digest,
        redirect_uris=["http://localhost:3000/callback"],
        allowed_scopes=["openid", "profile"],
        is_active=True,
    )
    db.add(legacy_client)
    await db.commit()
    await db.refresh(legacy_client)
    return legacy_client, plaintext_secret
class TestHelperFunctions:
    """Tests for the service's code/token generation, hashing and scope helpers."""

    def test_generate_code_length(self):
        """Authorization codes are long enough to be unguessable."""
        code = service.generate_code()
        # Base64-encoding 64 random bytes yields more than 64 characters.
        assert len(code) > 64

    def test_generate_code_unique(self):
        """Authorization codes do not repeat across many generations."""
        codes = [service.generate_code() for _ in range(100)]
        assert len(set(codes)) == 100

    def test_generate_token(self):
        """Generated tokens exceed 32 characters."""
        token = service.generate_token()
        assert len(token) > 32

    def test_generate_jti(self):
        """Generated JWT IDs exceed 20 characters."""
        jti = service.generate_jti()
        assert len(jti) > 20

    def test_hash_token(self):
        """Token hashes are 64-character hex digests (SHA-256)."""
        hashed = service.hash_token("test_token")
        assert len(hashed) == 64  # SHA-256 hex digest
        # A hex digest contains only lowercase hexadecimal characters.
        assert set(hashed) <= set("0123456789abcdef")

    def test_hash_token_deterministic(self):
        """Hashing the same token twice produces the same digest."""
        token = "test_token"
        hash1 = service.hash_token(token)
        hash2 = service.hash_token(token)
        assert hash1 == hash2

    def test_hash_token_distinct_inputs(self):
        """Different tokens hash to different digests."""
        assert service.hash_token("token_a") != service.hash_token("token_b")

    def test_parse_scope(self):
        """Scopes are split on whitespace, ignoring surrounding blanks."""
        assert service.parse_scope("openid profile email") == [
            "openid",
            "profile",
            "email",
        ]
        assert service.parse_scope("") == []
        assert service.parse_scope(" openid profile ") == ["openid", "profile"]

    def test_join_scope(self):
        """Scope joining sorts and de-duplicates the scope list."""
        result = service.join_scope(["profile", "openid", "profile"])
        # Pin both documented properties. Membership checks alone would accept
        # a duplicated or unsorted result.
        assert result.split() == ["openid", "profile"]
class TestPKCEVerification:
    """PKCE (RFC 7636) code_verifier / code_challenge verification tests."""

    @staticmethod
    def _s256_challenge(verifier):
        """Return the unpadded base64url SHA-256 challenge for *verifier*."""
        raw_digest = hashlib.sha256(verifier.encode()).digest()
        return base64.urlsafe_b64encode(raw_digest).rstrip(b"=").decode()

    def test_verify_pkce_s256_valid(self):
        """A verifier matches the S256 challenge derived from it."""
        verifier = secrets.token_urlsafe(64)
        challenge = self._s256_challenge(verifier)
        assert service.verify_pkce(verifier, challenge, "S256") is True

    def test_verify_pkce_s256_invalid(self):
        """A different verifier does not match another verifier's challenge."""
        genuine_verifier = secrets.token_urlsafe(64)
        impostor_verifier = secrets.token_urlsafe(64)
        challenge = self._s256_challenge(genuine_verifier)
        assert service.verify_pkce(impostor_verifier, challenge, "S256") is False

    def test_verify_pkce_plain_rejected(self):
        """The 'plain' method is refused even when verifier equals challenge."""
        # SECURITY: 'plain' offers no protection against an intercepted code,
        # so only S256 is accepted (see RFC 7636 Section 4.3 / 7.2).
        verifier = "test_verifier"
        assert service.verify_pkce(verifier, verifier, "plain") is False

    def test_verify_pkce_unknown_method(self):
        """An unrecognised challenge method never verifies."""
        outcome = service.verify_pkce("verifier", "challenge", "unknown")
        assert outcome is False
class TestClientValidation:
    """Client lookup, client authentication and redirect URI validation."""

    @pytest.mark.asyncio
    async def test_get_client_success(self, db, public_client):
        """An active, registered client is found by its client_id."""
        found = await service.get_client(db, public_client.client_id)
        assert found is not None
        assert found.client_id == public_client.client_id

    @pytest.mark.asyncio
    async def test_get_client_not_found(self, db):
        """An unknown client_id resolves to None."""
        found = await service.get_client(db, "nonexistent")
        assert found is None

    @pytest.mark.asyncio
    async def test_get_client_inactive(self, db, public_client):
        """A deactivated client is treated as if it did not exist."""
        public_client.is_active = False
        await db.commit()
        found = await service.get_client(db, public_client.client_id)
        assert found is None

    @pytest.mark.asyncio
    async def test_validate_client_public(self, db, public_client):
        """A public client validates without presenting a secret."""
        validated = await service.validate_client(db, public_client.client_id)
        assert validated.client_id == public_client.client_id

    @pytest.mark.asyncio
    async def test_validate_client_confidential_with_secret(
        self, db, confidential_client
    ):
        """A confidential client validates with its correct (bcrypt) secret."""
        registered, plaintext = confidential_client
        validated = await service.validate_client(
            db, registered.client_id, plaintext
        )
        assert validated.client_id == registered.client_id

    @pytest.mark.asyncio
    async def test_validate_client_confidential_wrong_secret(
        self, db, confidential_client
    ):
        """A confidential client presenting the wrong secret is rejected."""
        registered = confidential_client[0]
        with pytest.raises(service.InvalidClientError, match="Invalid client secret"):
            await service.validate_client(db, registered.client_id, "wrong_secret")

    @pytest.mark.asyncio
    async def test_validate_client_confidential_no_secret(
        self, db, confidential_client
    ):
        """A confidential client that omits its secret is rejected."""
        registered = confidential_client[0]
        with pytest.raises(service.InvalidClientError, match="Client secret required"):
            await service.validate_client(db, registered.client_id)

    @pytest.mark.asyncio
    async def test_validate_client_legacy_sha256_hash(
        self, db, confidential_client_legacy_hash
    ):
        """A legacy SHA-256-hashed secret still validates (backward compatibility)."""
        registered, plaintext = confidential_client_legacy_hash
        validated = await service.validate_client(
            db, registered.client_id, plaintext
        )
        assert validated.client_id == registered.client_id

    @pytest.mark.asyncio
    async def test_validate_client_legacy_sha256_wrong_secret(
        self, db, confidential_client_legacy_hash
    ):
        """A legacy SHA-256 client still rejects a wrong secret."""
        registered = confidential_client_legacy_hash[0]
        with pytest.raises(service.InvalidClientError, match="Invalid client secret"):
            await service.validate_client(db, registered.client_id, "wrong_secret")

    def test_validate_redirect_uri_success(self, public_client):
        """A registered redirect URI is accepted without raising."""
        service.validate_redirect_uri(public_client, "http://localhost:3000/callback")

    def test_validate_redirect_uri_invalid(self, public_client):
        """An unregistered redirect URI is rejected."""
        with pytest.raises(service.InvalidRequestError, match="Invalid redirect_uri"):
            service.validate_redirect_uri(public_client, "http://evil.com/callback")

    def test_validate_redirect_uri_no_uris(self, public_client):
        """A client with no registered redirect URIs rejects every URI."""
        public_client.redirect_uris = []
        with pytest.raises(service.InvalidRequestError, match="no registered"):
            service.validate_redirect_uri(public_client, "http://localhost:3000")
class TestScopeValidation:
    """Filtering requested scopes against a client's allowed scopes."""

    def test_validate_scopes_all_valid(self, public_client):
        """When every requested scope is allowed, all of them are kept."""
        granted = service.validate_scopes(public_client, ["openid", "profile"])
        for expected in ("openid", "profile"):
            assert expected in granted

    def test_validate_scopes_partial_valid(self, public_client):
        """Disallowed scopes are dropped while allowed ones survive."""
        requested = ["openid", "invalid_scope"]
        granted = service.validate_scopes(public_client, requested)
        assert "openid" in granted
        assert "invalid_scope" not in granted

    def test_validate_scopes_empty_uses_all_allowed(self, public_client):
        """An empty request falls back to the client's full allowed set."""
        granted = service.validate_scopes(public_client, [])
        assert set(public_client.allowed_scopes) == set(granted)

    def test_validate_scopes_none_valid(self, public_client):
        """A request containing no allowed scope raises InvalidScopeError."""
        bogus = ["invalid1", "invalid2"]
        with pytest.raises(service.InvalidScopeError):
            service.validate_scopes(public_client, bogus)
class TestAuthorizationCode:
    """Tests for authorization code creation and exchange.

    Shared setup lives in private helpers so each test states only what it
    varies: S256 PKCE pair generation, code issuance, and the patched JWT
    settings. pytest does not collect underscore-prefixed helpers.
    """

    # The callback registered on the ``public_client`` fixture.
    REDIRECT_URI = "http://localhost:3000/callback"

    @staticmethod
    def _pkce_pair():
        """Return a fresh ``(code_verifier, code_challenge)`` pair for S256.

        The challenge is the unpadded base64url SHA-256 digest of the verifier.
        """
        code_verifier = secrets.token_urlsafe(64)
        digest = hashlib.sha256(code_verifier.encode()).digest()
        code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
        return code_verifier, code_challenge

    @staticmethod
    def _patched_settings():
        """Patch the service settings with a fixed issuer and HS256 signing key.

        ``patch`` passes the keyword arguments to the replacement MagicMock, so
        these attributes are preconfigured on it.
        """
        return patch(
            "app.services.oauth_provider_service.settings",
            OAUTH_ISSUER="http://localhost:8000",
            SECRET_KEY="test_secret_key_for_jwt_signing_123456",
            ALGORITHM="HS256",
        )

    @classmethod
    async def _issue_code(cls, db, client, user, scope):
        """Create an S256-protected authorization code for ``REDIRECT_URI``.

        Returns:
            tuple: ``(code, code_verifier)``.
        """
        code_verifier, code_challenge = cls._pkce_pair()
        code = await service.create_authorization_code(
            db=db,
            client=client,
            user=user,
            redirect_uri=cls.REDIRECT_URI,
            scope=scope,
            code_challenge=code_challenge,
            code_challenge_method="S256",
        )
        return code, code_verifier

    @pytest.mark.asyncio
    async def test_create_authorization_code_public_with_pkce(
        self, db, public_client, test_user
    ):
        """Test creating authorization code for public client with PKCE."""
        code, _ = await self._issue_code(
            db, public_client, test_user, "openid profile"
        )
        assert code is not None
        assert len(code) > 64

    @pytest.mark.asyncio
    async def test_create_authorization_code_public_without_pkce_fails(
        self, db, public_client, test_user
    ):
        """Test creating authorization code for public client without PKCE fails."""
        with pytest.raises(service.InvalidRequestError, match="PKCE"):
            await service.create_authorization_code(
                db=db,
                client=public_client,
                user=test_user,
                redirect_uri=self.REDIRECT_URI,
                scope="openid",
            )

    @pytest.mark.asyncio
    async def test_exchange_authorization_code_success(
        self, db, public_client, test_user
    ):
        """Test exchanging valid authorization code for tokens."""
        code, code_verifier = await self._issue_code(
            db, public_client, test_user, "openid profile"
        )
        # Token issuance signs JWTs, so it needs the patched settings.
        with self._patched_settings():
            result = await service.exchange_authorization_code(
                db=db,
                code=code,
                client_id=public_client.client_id,
                redirect_uri=self.REDIRECT_URI,
                code_verifier=code_verifier,
            )
        assert "access_token" in result
        assert "refresh_token" in result
        assert result["token_type"] == "Bearer"
        assert "expires_in" in result

    @pytest.mark.asyncio
    async def test_exchange_authorization_code_invalid_code(self, db, public_client):
        """Test exchanging invalid code fails."""
        with pytest.raises(service.InvalidGrantError, match="Invalid authorization"):
            await service.exchange_authorization_code(
                db=db,
                code="invalid_code",
                client_id=public_client.client_id,
                redirect_uri=self.REDIRECT_URI,
            )

    @pytest.mark.asyncio
    async def test_exchange_authorization_code_wrong_redirect_uri(
        self, db, public_client, test_user
    ):
        """Test exchanging code with wrong redirect_uri fails."""
        code, code_verifier = await self._issue_code(
            db, public_client, test_user, "openid"
        )
        with pytest.raises(service.InvalidGrantError, match="redirect_uri mismatch"):
            await service.exchange_authorization_code(
                db=db,
                code=code,
                client_id=public_client.client_id,
                redirect_uri="http://different.com/callback",
                code_verifier=code_verifier,
            )

    @pytest.mark.asyncio
    async def test_exchange_authorization_code_invalid_pkce(
        self, db, public_client, test_user
    ):
        """Test exchanging code with invalid PKCE verifier fails."""
        # The genuine verifier is discarded; a wrong one is presented instead.
        code, _ = await self._issue_code(db, public_client, test_user, "openid")
        with pytest.raises(service.InvalidGrantError, match="Invalid code_verifier"):
            await service.exchange_authorization_code(
                db=db,
                code=code,
                client_id=public_client.client_id,
                redirect_uri=self.REDIRECT_URI,
                code_verifier="wrong_verifier",
            )
class TestTokenRefresh:
    """Tests for refresh-token exchange, rotation and scope narrowing."""

    @staticmethod
    def _patched_settings():
        """Patch the service settings with a fixed issuer and HS256 signing key.

        ``patch`` passes the keyword arguments to the replacement MagicMock, so
        these attributes are preconfigured on it.
        """
        return patch(
            "app.services.oauth_provider_service.settings",
            OAUTH_ISSUER="http://localhost:8000",
            SECRET_KEY="test_secret_key_for_jwt_signing_123456",
            ALGORITHM="HS256",
        )

    @pytest.mark.asyncio
    async def test_refresh_tokens_success(self, db, public_client, test_user):
        """Refreshing returns a new token pair and rotates the refresh token."""
        with self._patched_settings():
            result = await service.create_tokens(
                db=db,
                client=public_client,
                user=test_user,
                scope="openid profile",
            )
            refresh_token = result["refresh_token"]
            new_result = await service.refresh_tokens(
                db=db,
                refresh_token=refresh_token,
                client_id=public_client.client_id,
            )
        assert "access_token" in new_result
        assert "refresh_token" in new_result
        assert new_result["refresh_token"] != refresh_token  # Token rotation

    @pytest.mark.asyncio
    async def test_refresh_tokens_invalid_token(self, db, public_client):
        """An unknown refresh token is rejected."""
        with pytest.raises(service.InvalidGrantError, match="Invalid refresh token"):
            await service.refresh_tokens(
                db=db,
                refresh_token="invalid_token",
                client_id=public_client.client_id,
            )

    @pytest.mark.asyncio
    async def test_refresh_tokens_scope_reduction(self, db, public_client, test_user):
        """Refreshing with a subset of the original scope narrows it exactly."""
        with self._patched_settings():
            result = await service.create_tokens(
                db=db,
                client=public_client,
                user=test_user,
                scope="openid profile email",
            )
            new_result = await service.refresh_tokens(
                db=db,
                refresh_token=result["refresh_token"],
                client_id=public_client.client_id,
                scope="openid",  # Reduced scope
            )
        # Compare parsed scope tokens, not substrings of the raw scope string.
        # This also verifies that "email" was dropped, which the old
        # substring checks never did.
        assert set(service.parse_scope(new_result["scope"])) == {"openid"}

    @pytest.mark.asyncio
    async def test_refresh_tokens_scope_expansion_fails(
        self, db, public_client, test_user
    ):
        """Refreshing cannot widen the originally granted scope."""
        with self._patched_settings():
            result = await service.create_tokens(
                db=db,
                client=public_client,
                user=test_user,
                scope="openid",
            )
            with pytest.raises(service.InvalidScopeError, match="Cannot expand scope"):
                await service.refresh_tokens(
                    db=db,
                    refresh_token=result["refresh_token"],
                    client_id=public_client.client_id,
                    scope="openid profile",  # Expanded scope
                )
class TestTokenRevocation:
    """Revocation of a single refresh token and of all of a user's tokens."""

    @staticmethod
    def _patched_settings():
        """Patch the service settings with a fixed issuer and HS256 signing key."""
        return patch(
            "app.services.oauth_provider_service.settings",
            OAUTH_ISSUER="http://localhost:8000",
            SECRET_KEY="test_secret_key_for_jwt_signing_123456",
            ALGORITHM="HS256",
        )

    @pytest.mark.asyncio
    async def test_revoke_refresh_token(self, db, public_client, test_user):
        """A revoked refresh token can no longer be exchanged."""
        with self._patched_settings():
            issued = await service.create_tokens(
                db=db,
                client=public_client,
                user=test_user,
                scope="openid",
            )
            stale_refresh = issued["refresh_token"]

            was_revoked = await service.revoke_token(
                db=db,
                token=stale_refresh,
                token_type_hint="refresh_token",
            )
            assert was_revoked is True

            # Presenting the revoked token must now fail.
            with pytest.raises(service.InvalidGrantError, match="revoked"):
                await service.refresh_tokens(
                    db=db,
                    refresh_token=stale_refresh,
                    client_id=public_client.client_id,
                )

    @pytest.mark.asyncio
    async def test_revoke_all_user_tokens(self, db, public_client, test_user):
        """Bulk revocation reports one revoked entry per issued token set."""
        with self._patched_settings():
            # Issue two independent token sets for the same user and client.
            for requested_scope in ("openid", "profile"):
                await service.create_tokens(
                    db=db,
                    client=public_client,
                    user=test_user,
                    scope=requested_scope,
                )
            revoked_count = await service.revoke_all_user_tokens(db, test_user.id)
            assert revoked_count == 2
class TestTokenIntrospection:
    """Token introspection (RFC 7662) tests."""

    @staticmethod
    def _patched_settings():
        """Patch the service settings with a fixed issuer and HS256 signing key."""
        return patch(
            "app.services.oauth_provider_service.settings",
            OAUTH_ISSUER="http://localhost:8000",
            SECRET_KEY="test_secret_key_for_jwt_signing_123456",
            ALGORITHM="HS256",
        )

    @pytest.mark.asyncio
    async def test_introspect_valid_access_token(self, db, public_client, test_user):
        """A freshly issued access token introspects as active for its owner."""
        with self._patched_settings():
            issued = await service.create_tokens(
                db=db,
                client=public_client,
                user=test_user,
                scope="openid profile",
            )
            # Introspection must decode with the same patched signing key.
            report = await service.introspect_token(
                db=db,
                token=issued["access_token"],
            )
            assert report["active"] is True
            assert report["client_id"] == public_client.client_id
            assert report["sub"] == str(test_user.id)

    @pytest.mark.asyncio
    async def test_introspect_invalid_token(self, db):
        """A garbage token introspects as inactive instead of raising."""
        report = await service.introspect_token(db=db, token="invalid_token")
        assert report["active"] is False
class TestConsentManagement:
    """Granting, checking and revoking a user's consent for a client."""

    @staticmethod
    async def _has_openid_consent(db, user, client):
        """Return whether *user* has consented to ``openid`` for *client*."""
        return await service.check_consent(
            db=db,
            user_id=user.id,
            client_id=client.client_id,
            requested_scopes=["openid"],
        )

    @pytest.mark.asyncio
    async def test_grant_consent(self, db, public_client, test_user):
        """Granting consent records every requested scope."""
        consent = await service.grant_consent(
            db=db,
            user_id=test_user.id,
            client_id=public_client.client_id,
            scopes=["openid", "profile"],
        )
        assert consent is not None
        for scope_name in ("openid", "profile"):
            assert scope_name in consent.granted_scopes

    @pytest.mark.asyncio
    async def test_check_consent_granted(self, db, public_client, test_user):
        """A subset of previously granted scopes counts as consented."""
        await service.grant_consent(
            db=db,
            user_id=test_user.id,
            client_id=public_client.client_id,
            scopes=["openid", "profile"],
        )
        assert await self._has_openid_consent(db, test_user, public_client) is True

    @pytest.mark.asyncio
    async def test_check_consent_not_granted(self, db, public_client, test_user):
        """Without any grant, consent is reported as missing."""
        assert await self._has_openid_consent(db, test_user, public_client) is False

    @pytest.mark.asyncio
    async def test_revoke_consent(self, db, public_client, test_user):
        """Revoking consent removes the earlier grant."""
        await service.grant_consent(
            db=db,
            user_id=test_user.id,
            client_id=public_client.client_id,
            scopes=["openid"],
        )
        was_revoked = await service.revoke_consent(
            db=db,
            user_id=test_user.id,
            client_id=public_client.client_id,
        )
        assert was_revoked is True
        # The grant must be gone after revocation.
        assert await self._has_openid_consent(db, test_user, public_client) is False
class TestOAuthErrors:
    """Each OAuth error class exposes its error code and the given description."""

    @staticmethod
    def _check(error_cls, expected_code):
        """Build *error_cls* with a sample description and verify its fields."""
        sample = "Test description"
        instance = error_cls(sample)
        assert instance.error == expected_code
        assert instance.error_description == sample

    def test_invalid_client_error(self):
        """InvalidClientError maps to 'invalid_client'."""
        self._check(service.InvalidClientError, "invalid_client")

    def test_invalid_grant_error(self):
        """InvalidGrantError maps to 'invalid_grant'."""
        self._check(service.InvalidGrantError, "invalid_grant")

    def test_invalid_request_error(self):
        """InvalidRequestError maps to 'invalid_request'."""
        self._check(service.InvalidRequestError, "invalid_request")

    def test_invalid_scope_error(self):
        """InvalidScopeError maps to 'invalid_scope'."""
        self._check(service.InvalidScopeError, "invalid_scope")

    def test_access_denied_error(self):
        """AccessDeniedError maps to 'access_denied'."""
        self._check(service.AccessDeniedError, "access_denied")