- Updated import statements and test logic to align with `repositories` naming changes. - Adjusted documentation and test names for consistency with the updated naming convention. - Improved test descriptions to reflect the repository-based structure.
114 lines
3.8 KiB
Python
114 lines
3.8 KiB
Python
# app/repositories/oauth_state.py
|
|
"""Repository for OAuthState model async database operations."""
|
|
|
|
import logging
|
|
from datetime import UTC, datetime
|
|
|
|
from pydantic import BaseModel
|
|
from sqlalchemy import delete, select
|
|
from sqlalchemy.exc import IntegrityError
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.repository_exceptions import DuplicateEntryError
|
|
from app.models.oauth_state import OAuthState
|
|
from app.repositories.base import BaseRepository
|
|
from app.schemas.oauth import OAuthStateCreate
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class EmptySchema(BaseModel):
|
|
"""Placeholder schema for repository operations that don't need update schemas."""
|
|
|
|
|
|
class OAuthStateRepository(BaseRepository[OAuthState, OAuthStateCreate, EmptySchema]):
|
|
"""Repository for OAuth state (CSRF protection)."""
|
|
|
|
async def create_state(
|
|
self, db: AsyncSession, *, obj_in: OAuthStateCreate
|
|
) -> OAuthState:
|
|
"""Create a new OAuth state for CSRF protection."""
|
|
try:
|
|
db_obj = OAuthState(
|
|
state=obj_in.state,
|
|
code_verifier=obj_in.code_verifier,
|
|
nonce=obj_in.nonce,
|
|
provider=obj_in.provider,
|
|
redirect_uri=obj_in.redirect_uri,
|
|
user_id=obj_in.user_id,
|
|
expires_at=obj_in.expires_at,
|
|
)
|
|
db.add(db_obj)
|
|
await db.commit()
|
|
await db.refresh(db_obj)
|
|
|
|
logger.debug("OAuth state created for %s", obj_in.provider)
|
|
return db_obj
|
|
except IntegrityError as e: # pragma: no cover
|
|
await db.rollback()
|
|
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
|
logger.error("OAuth state collision: %s", error_msg)
|
|
raise DuplicateEntryError("Failed to create OAuth state, please retry")
|
|
except Exception as e: # pragma: no cover
|
|
await db.rollback()
|
|
logger.exception("Error creating OAuth state: %s", e)
|
|
raise
|
|
|
|
async def get_and_consume_state(
|
|
self, db: AsyncSession, *, state: str
|
|
) -> OAuthState | None:
|
|
"""Get and delete OAuth state (consume it)."""
|
|
try:
|
|
result = await db.execute(
|
|
select(OAuthState).where(OAuthState.state == state)
|
|
)
|
|
db_obj = result.scalar_one_or_none()
|
|
|
|
if db_obj is None:
|
|
logger.warning("OAuth state not found: %s...", state[:8])
|
|
return None
|
|
|
|
now = datetime.now(UTC)
|
|
expires_at = db_obj.expires_at
|
|
if expires_at.tzinfo is None:
|
|
expires_at = expires_at.replace(tzinfo=UTC)
|
|
|
|
if expires_at < now:
|
|
logger.warning("OAuth state expired: %s...", state[:8])
|
|
await db.delete(db_obj)
|
|
await db.commit()
|
|
return None
|
|
|
|
await db.delete(db_obj)
|
|
await db.commit()
|
|
|
|
logger.debug("OAuth state consumed: %s...", state[:8])
|
|
return db_obj
|
|
except Exception as e: # pragma: no cover
|
|
await db.rollback()
|
|
logger.error("Error consuming OAuth state: %s", e)
|
|
raise
|
|
|
|
async def cleanup_expired(self, db: AsyncSession) -> int:
|
|
"""Clean up expired OAuth states."""
|
|
try:
|
|
now = datetime.now(UTC)
|
|
|
|
stmt = delete(OAuthState).where(OAuthState.expires_at < now)
|
|
result = await db.execute(stmt)
|
|
await db.commit()
|
|
|
|
count = result.rowcount
|
|
if count > 0:
|
|
logger.info("Cleaned up %s expired OAuth states", count)
|
|
|
|
return count
|
|
except Exception as e: # pragma: no cover
|
|
await db.rollback()
|
|
logger.error("Error cleaning up expired OAuth states: %s", e)
|
|
raise
|
|
|
|
|
|
# Singleton instance
|
|
oauth_state_repo = OAuthStateRepository(OAuthState)
|