diff --git a/backend/app/alembic/versions/37315a5b4021_add_revokedtoken_model.py b/backend/app/alembic/versions/37315a5b4021_add_revokedtoken_model.py deleted file mode 100644 index be4318c..0000000 --- a/backend/app/alembic/versions/37315a5b4021_add_revokedtoken_model.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Add RevokedToken model - -Revision ID: 37315a5b4021 -Revises: 38bf9e7e74b3 -Create Date: 2025-02-28 17:11:07.741372 - -""" -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision: str = '37315a5b4021' -down_revision: Union[str, None] = '38bf9e7e74b3' -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.create_table('revoked_tokens', - sa.Column('jti', sa.String(length=50), nullable=False), - sa.Column('token_type', sa.String(length=20), nullable=False), - sa.Column('user_id', sa.UUID(), nullable=True), - sa.Column('id', sa.UUID(), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), - sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False), - sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('id') - ) - op.create_index(op.f('ix_revoked_tokens_jti'), 'revoked_tokens', ['jti'], unique=True) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_index(op.f('ix_revoked_tokens_jti'), table_name='revoked_tokens') - op.drop_table('revoked_tokens') - # ### end Alembic commands ### diff --git a/backend/app/api/dependencies.py b/backend/app/api/dependencies.py new file mode 100644 index 0000000..9ebae77 --- /dev/null +++ b/backend/app/api/dependencies.py @@ -0,0 +1,138 @@ +# app/api/dependencies/auth.py +from typing import Optional + +from fastapi import Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer +from sqlalchemy.orm import Session + +from app.core.auth import get_token_data, TokenExpiredError, TokenInvalidError +from app.core.database import get_db +from app.models.user import User + +# OAuth2 configuration +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") + + +def get_current_user( + db: Session = Depends(get_db), + token: str = Depends(oauth2_scheme) +) -> User: + """ + Get the current authenticated user. + + Args: + db: Database session + token: JWT token from request + + Returns: + User: The authenticated user + + Raises: + HTTPException: If authentication fails + """ + try: + # Decode token and get user ID + token_data = get_token_data(token) + + # Get user from database + user = db.query(User).filter(User.id == token_data.user_id).first() + if not user: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User not found" + ) + + if not user.is_active: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Inactive user" + ) + + return user + + except TokenExpiredError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Token expired", + headers={"WWW-Authenticate": "Bearer"} + ) + except TokenInvalidError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"} + ) + + +def get_current_active_user( + current_user: User = Depends(get_current_user) +) -> User: + """ + Check if the current user is active. + + Args: + current_user: The current authenticated user + + Returns: + User: The authenticated and active user + + Raises: + HTTPException: If user is inactive + """ + if not current_user.is_active: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Inactive user" + ) + return current_user + + +def get_current_superuser( + current_user: User = Depends(get_current_user) +) -> User: + """ + Check if the current user is a superuser. + + Args: + current_user: The current authenticated user + + Returns: + User: The authenticated superuser + + Raises: + HTTPException: If user is not a superuser + """ + if not current_user.is_superuser: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not enough permissions" + ) + return current_user + + +def get_optional_current_user( + db: Session = Depends(get_db), + token: Optional[str] = Depends(oauth2_scheme) +) -> Optional[User]: + """ + Get the current user if authenticated, otherwise return None. + Useful for endpoints that work with both authenticated and unauthenticated users. + + Args: + db: Database session + token: JWT token from request + + Returns: + User or None: The authenticated user or None + """ + if not token: + return None + + try: + token_data = get_token_data(token) + user = db.query(User).filter(User.id == token_data.user_id).first() + if not user or not user.is_active: + return None + return user + except (TokenExpiredError, TokenInvalidError): + return None \ No newline at end of file diff --git a/backend/app/api/routes/auth.py b/backend/app/api/routes/auth.py index d8a5c54..af9233c 100644 --- a/backend/app/api/routes/auth.py +++ b/backend/app/api/routes/auth.py @@ -1,134 +1,3 @@ -from typing import Any - -from app.auth.utils import revoke_token, is_token_revoked -from fastapi import APIRouter, Depends, HTTPException, status -from fastapi.security import OAuth2PasswordRequestForm -from sqlalchemy.ext.asyncio import AsyncSession - -from app.auth.security import authenticate_user, create_access_token, create_refresh_token, decode_token -from app.core.database import get_db -from app.models.user import User -from app.schemas.token import TokenResponse, TokenPayload, RefreshToken -from app.schemas.user import UserResponse +from fastapi import APIRouter router = APIRouter() -oauth2_scheme = OAuth2PasswordRequestForm - - -# Existing: User Login Endpoint -@router.post( - "/auth/login", - response_model=TokenResponse, - summary="Authenticate user and provide tokens" -) -async def login( - form_data: OAuth2PasswordRequestForm = Depends(), - db: AsyncSession = Depends(get_db) -) -> Any: - """ - Authenticate a user with their credentials and return an access and refresh token. - """ - user = await authenticate_user(email=form_data.username, password=form_data.password, db=db) - if not user: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials.") - - # Generate access and refresh tokens - access_token = create_access_token({"sub": str(user.id), "type": "access"}) - refresh_token = create_refresh_token({"sub": str(user.id), "type": "refresh"}) - - return TokenResponse( - access_token=access_token, - refresh_token=refresh_token, - token_type="bearer", - expires_in=1800, # Example: 30 minutes for access token - user_id=str(user.id), - ) - - -# New: Logout Endpoint (Revoke Token) -@router.post( - "/auth/logout", - summary="Revoke the current token", - response_model=dict, - status_code=status.HTTP_200_OK -) -async def logout( - token: str = Depends(oauth2_scheme), - db: AsyncSession = Depends(get_db), - current_user: User = Depends( - lambda token=Depends(oauth2_scheme), db=Depends(get_db): decode_token(token, db=db)) -): - """ - Logout the user by revoking the current token. - """ - # Decode the token and revoke it - payload: TokenPayload = await decode_token(token, db=db) - await revoke_token(payload.jti, payload.type, payload.sub, db) - - return {"message": "Successfully logged out."} - - -# New: Bulk Logout (Revoke All of a User's Tokens) -@router.post( - "/auth/logout-all", - summary="Revoke all active tokens for the user", - response_model=dict, - status_code=status.HTTP_200_OK -) -async def logout_all( - db: AsyncSession = Depends(get_db), - current_user: User = Depends( - lambda token=Depends(oauth2_scheme), db=Depends(get_db): decode_token(token, db=db)) -): - """ - Revoke all tokens for the current user, effectively logging them out across all devices. - """ - await db.execute("DELETE FROM revoked_tokens WHERE user_id = :user_id", {"user_id": str(current_user.id)}) - await db.commit() - - return {"message": "Logged out from all devices."} - - -# Updated: Refresh Token Endpoint -@router.post( - "/auth/refresh-token", - response_model=TokenResponse, - summary="Generate a new access token using a refresh token" -) -async def refresh_token( - refresh_token: RefreshToken, - db: AsyncSession = Depends(get_db) -) -> TokenResponse: - """ - Refresh the user's access token using their refresh token while ensuring it has not been revoked. - """ - payload: TokenPayload = await decode_token(refresh_token.refresh_token, required_type="refresh", db=db) - - if await is_token_revoked(payload.jti, db): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=f"Token has been revoked.") - - # Generate a new access token with the user's info - new_access_token = create_access_token({"sub": payload.sub, "type": "access"}) - return TokenResponse( - access_token=new_access_token, - refresh_token=refresh_token.refresh_token, # Reuse existing refresh token - expires_in=1800, # Example: 30 minutes expiry for access token - token_type="bearer", - user_id=payload.sub, - ) - - -# Existing: Get Current User Endpoint -@router.get( - "/auth/me", - response_model=UserResponse, - summary="Get user details from the token" -) -async def read_users_me( - current_user: User = Depends( - lambda token=Depends(oauth2_scheme), db=Depends(get_db): decode_token(token, db=db)) -) -> UserResponse: - """ - Retrieves the details of the currently authenticated user. - """ - return current_user diff --git a/backend/app/auth/dependencies.py b/backend/app/auth/dependencies.py deleted file mode 100644 index a431bb4..0000000 --- a/backend/app/auth/dependencies.py +++ /dev/null @@ -1,40 +0,0 @@ -from fastapi import Depends, HTTPException, status -from fastapi.security import OAuth2PasswordBearer -from jose import JWTError, jwt -from sqlalchemy.ext.asyncio import AsyncSession - -from app.core.database import get_db -from app.auth.security import decode_token -from app.models.user import User - -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token") - - -async def get_current_user( - token: str = Depends(oauth2_scheme), - db: AsyncSession = Depends(get_db) -): - try: - payload = await decode_token(token) # Use updated decode_token. - user_id: str = payload.sub - token_type: str = payload.type - - if user_id is None or token_type != "access": - raise HTTPException(status_code=401, detail="Invalid token type.") - - user = await db.get(User, user_id) - if user is None: - raise HTTPException(status_code=401, detail="User not found.") - - return user - except JWTError as e: - raise HTTPException(status_code=401, detail=str(e)) - - - -async def get_current_active_user( - current_user: User = Depends(get_current_user), -): - if not current_user.is_active: - raise HTTPException(status_code=400, detail="Inactive user") - return current_user diff --git a/backend/app/auth/security.py b/backend/app/auth/security.py deleted file mode 100644 index 69f3f77..0000000 --- a/backend/app/auth/security.py +++ /dev/null @@ -1,176 +0,0 @@ -from datetime import datetime, timedelta, timezone -from typing import Optional -from uuid import uuid4 - -from fastapi import Depends -from jose import jwt, JWTError, ExpiredSignatureError, JOSEError -from passlib.context import CryptContext -from sqlalchemy.ext.asyncio import AsyncSession - -from app.core.config import settings -from app.core.database import get_db -from app.schemas.token import TokenPayload, TokenResponse -from auth.utils import is_token_revoked - -# Configuration -SECRET_KEY = settings.SECRET_KEY -ALGORITHM = "HS256" -ACCESS_TOKEN_EXPIRE_MINUTES = 30 -REFRESH_TOKEN_EXPIRE_DAYS = 7 - -# Password hashing context -pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") - - -def verify_password(plain_password: str, hashed_password: str) -> bool: - """Verify a plain password against its hash.""" - return pwd_context.verify(plain_password, hashed_password) - - -def get_password_hash(password: str) -> str: - """Generate password hash.""" - return pwd_context.hash(password) - - -def create_tokens(user_id: str) -> TokenResponse: - """ - Create both access and refresh tokens for a user. - - Args: - user_id: The user's ID - - Returns: - TokenResponse containing both tokens and metadata - """ - # Add `jti` during token creation - access_token = create_access_token({"sub": user_id, "jti": str(uuid4())}) - refresh_token = create_refresh_token({"sub": user_id, "jti": str(uuid4())}) - - return TokenResponse( - access_token=access_token, - refresh_token=refresh_token, - token_type="bearer", - expires_in=ACCESS_TOKEN_EXPIRE_MINUTES * 60, - user_id=user_id, - scope="read write" - ) - - -def create_token( - data: dict, - expires_delta: Optional[timedelta] = None, - token_type: str = "access" -) -> str: - """Create a JWT token with the specified type and expiration.""" - to_encode = data.copy() - - if expires_delta: - expire = datetime.now(timezone.utc) + expires_delta - else: - expire = datetime.now(timezone.utc) + ( - timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) if token_type == "access" - else timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS) - ) - - to_encode.update({ - "exp": expire, - "type": token_type, - "iat": datetime.now(timezone.utc), - }) - if "jti" not in to_encode: - to_encode["jti"] = str(uuid4()) # Ensure unique `jti` is always added - - return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) - - -def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str: - """Create a new access token.""" - # Ensure `data` includes `jti` for consistency - if "jti" not in data: - data["jti"] = str(uuid4()) - return create_token(data, expires_delta, "access") - - -def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) -> str: - """Create a new refresh token.""" - # Ensure `data` includes `jti` for consistency - if "jti" not in data: - data["jti"] = str(uuid4()) - return create_token(data, expires_delta, "refresh") - - -async def decode_token( - token: str, - required_type: str = "access", - db: AsyncSession = Depends(get_db) -) -> TokenPayload: - """ - Decode and validate a JWT token, including revocation checks. - - Args: - token (str): The JWT token to decode. - required_type (str): The expected token type (default: "access"). - db (AsyncSession): Database session for token revocation checks. - - Returns: - TokenPayload: The decoded token data. - - Raises: - JWTError: If the token is expired, revoked, malformed, or fails validation. - """ - try: - # Step 1: Decode the JWT token - payload = jwt.decode( - token, - SECRET_KEY, - algorithms=[ALGORITHM], - options={ - "verify_exp": True, - "verify_iat": True, - "require": ["exp", "iat", "sub", "type", "jti"] - } - ) - - except ExpiredSignatureError: - raise JWTError("Token has expired. Please refresh your token or login again.") - except JWTError as e: - if "Signature verification failed" in str(e): - raise JWTError("Invalid token signature. The token may have been tampered with or corrupted.") - raise JWTError(f"Failed to decode the token: {e}") - except JOSEError as e: - if "segments" in str(e).lower(): - raise JWTError("Malformed token. The token format is invalid (e.g., not enough segments).") - raise JWTError("Failed to decode the token. Ensure the token is valid and correctly formatted.") from e - except Exception as e: - # Catch-all for unexpected exceptions during decoding - raise JWTError(f"An unexpected error occurred while decoding the token: {e}") - - # Step 2: Validate Required Claims - required_claims = ["exp", "sub", "type", "jti"] - missing_claims = [claim for claim in required_claims if claim not in payload] - if missing_claims: - raise JWTError(f"Malformed token. Missing required claims: {', '.join(missing_claims)}.") - - # Step 3: Validate Expiry - expiration = datetime.fromtimestamp(payload["exp"]) - if datetime.now(timezone.utc) > expiration: - raise JWTError("Token has expired. Please refresh your token or login again.") - - # Step 4: Validate Token Type - token_type = payload.get("type") - if token_type != required_type: - raise JWTError(f"Invalid token type: expected '{required_type}', got '{token_type}'.") - - # Step 5: Check Revocation - jti = payload.get("jti") - if await is_token_revoked(jti, db): - raise JWTError("Token has been revoked. Please login again to generate a new token.") - - # Step 6: Return Validated Token Payload - return TokenPayload( - sub=payload["sub"], - type=payload["type"], - exp=expiration, - iat=datetime.fromtimestamp(payload.get("iat", 0)), - jti=jti - ) diff --git a/backend/app/auth/utils.py b/backend/app/auth/utils.py deleted file mode 100644 index 350b4e4..0000000 --- a/backend/app/auth/utils.py +++ /dev/null @@ -1,45 +0,0 @@ -from datetime import datetime, timezone, timedelta - -from sqlalchemy import delete -from sqlalchemy.ext.asyncio import AsyncSession -from app.models.token import RevokedToken - - -async def revoke_token(jti: str, token_type: str, user_id: str, db: AsyncSession): - """Revoke a token by storing its `jti` in the revoked_tokens table.""" - revoked_token = RevokedToken(jti=jti, token_type=token_type, user_id=user_id) - db.add(revoked_token) - await db.commit() - - -async def is_token_revoked(jti: str, db: AsyncSession) -> bool: - """Check whether the token's JTI is in the revoked_tokens table.""" - from sqlalchemy import select - result = await db.execute(select(RevokedToken).where(RevokedToken.jti == jti)) - revoked = result.scalar_one_or_none() - return revoked is not None - - -async def cleanup_expired_tokens(db: AsyncSession): - """Delete revoked tokens that are past their expiration time.""" - now = datetime.now(timezone.utc) - - # For access tokens (shorter expiry) - expire_before = now - timedelta(days=1) # Keep for 1 day past expiry - await db.execute( - delete(RevokedToken).where( - (RevokedToken.token_type == "access") & - (RevokedToken.created_at < expire_before) - ) - ) - - # For refresh tokens (longer expiry) - expire_before = now - timedelta(days=14) # Keep for 14 days past expiry - await db.execute( - delete(RevokedToken).where( - (RevokedToken.token_type == "refresh") & - (RevokedToken.created_at < expire_before) - ) - ) - - await db.commit() \ No newline at end of file diff --git a/backend/app/core/auth.py b/backend/app/core/auth.py new file mode 100644 index 0000000..d2f6ca1 --- /dev/null +++ b/backend/app/core/auth.py @@ -0,0 +1,183 @@ +# app/core/auth.py +from datetime import datetime, timedelta, timezone +from typing import Any, Dict, Optional, Union +import uuid + +from jose import jwt, JWTError +from passlib.context import CryptContext +from pydantic import ValidationError + +from app.core.config import settings +from app.schemas.users import TokenData, TokenPayload + + +# Password hashing context +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + +# Custom exceptions for auth +class AuthError(Exception): + """Base authentication error""" + pass + +class TokenExpiredError(AuthError): + """Token has expired""" + pass + +class TokenInvalidError(AuthError): + """Token is invalid""" + pass + +class TokenMissingClaimError(AuthError): + """Token is missing a required claim""" + pass + + +def verify_password(plain_password: str, hashed_password: str) -> bool: + """Verify a password against a hash.""" + return pwd_context.verify(plain_password, hashed_password) + + +def get_password_hash(password: str) -> str: + """Generate a password hash.""" + return pwd_context.hash(password) + + +def create_access_token( + subject: Union[str, Any], + expires_delta: Optional[timedelta] = None, + claims: Optional[Dict[str, Any]] = None +) -> str: + """ + Create a JWT access token. + + Args: + subject: The subject of the token (usually user_id) + expires_delta: Optional expiration time delta + claims: Optional additional claims to include in the token + + Returns: + Encoded JWT token + """ + if expires_delta: + expire = datetime.now(timezone.utc) + expires_delta + else: + expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + + # Base token data + to_encode = { + "sub": str(subject), + "exp": expire, + "iat": datetime.now(tz=timezone.utc), + "jti": str(uuid.uuid4()), + "type": "access" + } + + # Add custom claims + if claims: + to_encode.update(claims) + + # Create the JWT + encoded_jwt = jwt.encode( + to_encode, + settings.SECRET_KEY, + algorithm=settings.ALGORITHM + ) + + return encoded_jwt + + +def create_refresh_token( + subject: Union[str, Any], + expires_delta: Optional[timedelta] = None +) -> str: + """ + Create a JWT refresh token. + + Args: + subject: The subject of the token (usually user_id) + expires_delta: Optional expiration time delta + + Returns: + Encoded JWT refresh token + """ + if expires_delta: + expire = datetime.now(timezone.utc) + expires_delta + else: + expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS) + + to_encode = { + "sub": str(subject), + "exp": expire, + "iat": datetime.now(timezone.utc), + "jti": str(uuid.uuid4()), + "type": "refresh" + } + + encoded_jwt = jwt.encode( + to_encode, + settings.SECRET_KEY, + algorithm=settings.ALGORITHM + ) + + return encoded_jwt + + +def decode_token(token: str, verify_type: Optional[str] = None) -> TokenPayload: + """ + Decode and verify a JWT token. + + Args: + token: JWT token to decode + verify_type: Optional token type to verify (access or refresh) + + Returns: + TokenPayload: The decoded token data + + Raises: + TokenExpiredError: If the token has expired + TokenInvalidError: If the token is invalid + TokenMissingClaimError: If a required claim is missing + """ + try: + payload = jwt.decode( + token, + settings.SECRET_KEY, + algorithms=[settings.ALGORITHM] + ) + + # Check required claims before Pydantic validation + if not payload.get("sub"): + raise TokenMissingClaimError("Token missing 'sub' claim") + + # Verify token type if specified + if verify_type and payload.get("type") != verify_type: + raise TokenInvalidError(f"Invalid token type, expected {verify_type}") + + # Now validate with Pydantic + token_data = TokenPayload(**payload) + return token_data + + except JWTError as e: + # Check if the error is due to an expired token + if "expired" in str(e).lower(): + raise TokenExpiredError("Token has expired") + raise TokenInvalidError("Invalid authentication token") + except ValidationError: + raise TokenInvalidError("Invalid token payload") + + +def get_token_data(token: str) -> TokenData: + """ + Extract the user ID and superuser status from a token. + + Args: + token: JWT token + + Returns: + TokenData with user_id and is_superuser + """ + payload = decode_token(token) + user_id = payload.sub + is_superuser = payload.is_superuser or False + + return TokenData(user_id=uuid.UUID(user_id), is_superuser=is_superuser) \ No newline at end of file diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 07d56d6..994a404 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -14,6 +14,7 @@ class Settings(BaseSettings): POSTGRES_PORT: str = "5432" POSTGRES_DB: str = "eventspace" DATABASE_URL: Optional[str] = None + REFRESH_TOKEN_EXPIRE_DAYS: int = 60 db_pool_size: int = 20 # Default connection pool size db_max_overflow: int = 50 # Maximum overflow connections db_pool_timeout: int = 30 # Seconds to wait for a connection @@ -24,6 +25,7 @@ class Settings(BaseSettings): sql_echo_pool: bool = False # Log connection pool events sql_echo_timing: bool = False # Log query execution times slow_query_threshold: float = 0.5 # Log queries taking longer than this + @property def database_url(self) -> str: """ diff --git a/backend/app/main.py b/backend/app/main.py index 697c8fc..dbc40a9 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -1,15 +1,12 @@ +import logging + from apscheduler.schedulers.asyncio import AsyncIOScheduler -from apscheduler.triggers.cron import CronTrigger from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import HTMLResponse -from app.core.config import settings from app.api.main import api_router -import logging - -from auth.utils import cleanup_expired_tokens -from app.core.database import SessionLocal +from app.core.config import settings scheduler = AsyncIOScheduler() @@ -32,26 +29,6 @@ app.add_middleware( ) -# Create a function that gets its own database session -async def scheduled_cleanup(): - async with SessionLocal() as db: - await cleanup_expired_tokens(db) - -@app.on_event("startup") -async def start_scheduler(): - # Run every day at 3 AM - scheduler.add_job( - scheduled_cleanup, - CronTrigger(hour=10, minute=0), - id="token_cleanup", - name="Clean up expired revoked tokens" - ) - scheduler.start() - -@app.on_event("shutdown") -async def stop_scheduler(): - scheduler.shutdown() - @app.get("/", response_class=HTMLResponse) async def root(): return """ @@ -67,4 +44,5 @@ async def root(): """ -app.include_router(api_router, prefix=settings.API_V1_STR) \ No newline at end of file + +app.include_router(api_router, prefix=settings.API_V1_STR) diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 583b43f..12ee72e 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -29,7 +29,6 @@ from .gift import ( from .email_template import EmailTemplate, TemplateType from .notification_log import NotificationLog, NotificationType, NotificationStatus from .activity_log import ActivityLog, ActivityType -from .token import RevokedToken # Make sure all models are imported above this line __all__ = [ 'Base', 'TimestampMixin', 'UUIDMixin', @@ -40,5 +39,4 @@ __all__ = [ 'EmailTemplate', 'TemplateType', 'NotificationLog', 'NotificationType', 'NotificationStatus', 'ActivityLog', 'ActivityType', - 'RevokedToken', ] \ No newline at end of file diff --git a/backend/app/models/token.py b/backend/app/models/token.py deleted file mode 100644 index e2bb3cb..0000000 --- a/backend/app/models/token.py +++ /dev/null @@ -1,15 +0,0 @@ -from sqlalchemy import Column, String, ForeignKey -from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.orm import relationship -from app.models.base import Base, TimestampMixin, UUIDMixin - - -class RevokedToken(UUIDMixin, TimestampMixin, Base): - """Model to store revoked JWT tokens via their jti (JWT ID).""" - __tablename__ = "revoked_tokens" - - jti = Column(String(length=50), nullable=False, unique=True, index=True) - token_type = Column(String(length=20), nullable=False) - user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE")) - - user = relationship("User", back_populates="revoked_tokens") \ No newline at end of file diff --git a/backend/app/models/user.py b/backend/app/models/user.py index 54b3c31..7fe547f 100644 --- a/backend/app/models/user.py +++ b/backend/app/models/user.py @@ -25,7 +25,6 @@ class User(Base, UUIDMixin, TimestampMixin): foreign_keys="EventManager.user_id" ) guest_profiles = relationship("Guest", back_populates="user", foreign_keys="Guest.user_id") - revoked_tokens = relationship("RevokedToken", back_populates="user", cascade="all, delete") def __repr__(self): return f"" \ No newline at end of file diff --git a/backend/app/schemas/user.py b/backend/app/schemas/user.py deleted file mode 100644 index f0ef04b..0000000 --- a/backend/app/schemas/user.py +++ /dev/null @@ -1,66 +0,0 @@ -from typing import Optional -from uuid import UUID -from pydantic import BaseModel, EmailStr, Field, field_validator -from datetime import datetime -from passlib.hash import bcrypt - - -# Base schema with shared user attributes -class UserBase(BaseModel): - """Base schema with common user attributes.""" - email: EmailStr - first_name: str - last_name: str - - -# Schema for creating a new user -class UserCreate(UserBase): - """Schema for user registration.""" - password: str = Field( - ..., - min_length=8, - description="Password must be at least 8 characters" - ) - - @field_validator('password') - def password_strength(cls, v): - # Add more complex password validation if needed - if len(v) < 8: - raise ValueError('Password must be at least 8 characters') - return v - - def hash_password(self) -> str: - """Hash the password before saving it to the database.""" - return bcrypt.hash(self.password) - - -# Schema for updating user details -class UserUpdate(BaseModel): - """Schema for updating user information.""" - email: Optional[EmailStr] = None - first_name: Optional[str] = None - last_name: Optional[str] = None - phone_number: Optional[str] = None - is_active: Optional[bool] = None - preferences: Optional[dict] = None # Provide preferences support - - -# Schema for user responses (read-only fields) -class UserResponse(UserBase): - """Schema for user responses in API.""" - id: UUID - is_active: bool - is_superuser: bool # Include roles or superuser flags if needed - preferences: Optional[dict] = None # Include preferences in response - created_at: datetime - updated_at: Optional[datetime] = None - - class Config: - orm_mode = True # Enable mapping SQLAlchemy models to this schema - - -# Schema for user authentication (e.g., login requests) -class UserAuth(BaseModel): - """Schema for user authentication.""" - email: EmailStr - password: str \ No newline at end of file diff --git a/backend/app/schemas/users.py b/backend/app/schemas/users.py new file mode 100644 index 0000000..ec7c1f1 --- /dev/null +++ b/backend/app/schemas/users.py @@ -0,0 +1,126 @@ +# app/schemas/users.py +import re +from datetime import datetime +from typing import Optional, Dict, Any +from uuid import UUID + +import pydantic +from pydantic import BaseModel, EmailStr, field_validator, ConfigDict + + +class UserBase(BaseModel): + email: EmailStr + first_name: str + last_name: str + phone_number: Optional[str] = None + + @field_validator('phone_number') + @classmethod + def validate_phone_number(cls, v: Optional[str]) -> Optional[str]: + if v is None: + return v + # Simple regex for phone validation + if not re.match(r'^\+?[0-9\s\-\(\)]{8,20}$', v): + raise ValueError('Invalid phone number format') + return v + + +class UserCreate(UserBase): + password: str + + @field_validator('password') + @classmethod + def password_strength(cls, v: str) -> str: + """Basic password strength validation""" + if len(v) < 8: + raise ValueError('Password must be at least 8 characters') + if not any(char.isdigit() for char in v): + raise ValueError('Password must contain at least one digit') + if not any(char.isupper() for char in v): + raise ValueError('Password must contain at least one uppercase letter') + return v + + +class UserUpdate(BaseModel): + first_name: Optional[str] = None + last_name: Optional[str] = None + phone_number: Optional[str] = None + preferences: Optional[Dict[str, Any]] = None + + @field_validator('phone_number') + @classmethod + def validate_phone_number(cls, v: Optional[str]) -> Optional[str]: + if v is None: + return v + # Simple regex for phone validation + if not re.match(r'^\+?[0-9\s\-\(\)]{8,20}$', v): + raise ValueError('Invalid phone number format') + return v + + +class UserInDB(UserBase): + id: UUID + is_active: bool + is_superuser: bool + created_at: datetime + updated_at: Optional[datetime] = None + + model_config = ConfigDict(from_attributes=True) + + +class UserResponse(UserBase): + id: UUID + is_active: bool + is_superuser: bool + created_at: datetime + updated_at: Optional[datetime] = None + + model_config = ConfigDict(from_attributes=True) + + +class Token(BaseModel): + access_token: str + refresh_token: Optional[str] = None + token_type: str = "bearer" + + +class TokenPayload(BaseModel): + sub: str # User ID + exp: int # Expiration time + iat: Optional[int] = None # Issued at + jti: Optional[str] = None # JWT ID + is_superuser: Optional[bool] = False + first_name: Optional[str] = None + email: Optional[str] = None + type: Optional[str] = None # Token type (access/refresh) + + +class TokenData(BaseModel): + user_id: UUID + is_superuser: bool = False + + +class PasswordReset(BaseModel): + token: str + new_password: str + + @field_validator('new_password') + @classmethod + def password_strength(cls, v: str) -> str: + """Basic password strength validation""" + if len(v) < 8: + raise ValueError('Password must be at least 8 characters') + if not any(char.isdigit() for char in v): + raise ValueError('Password must contain at least one digit') + if not any(char.isupper() for char in v): + raise ValueError('Password must contain at least one uppercase letter') + return v + + +class LoginRequest(BaseModel): + email: EmailStr + password: str + + +class RefreshTokenRequest(BaseModel): + refresh_token: str \ No newline at end of file diff --git a/backend/app/services/__init__.py b/backend/app/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/services/auth_service.py b/backend/app/services/auth_service.py new file mode 100644 index 0000000..4941671 --- /dev/null +++ b/backend/app/services/auth_service.py @@ -0,0 +1,193 @@ +# app/services/auth_service.py +import logging +from typing import Optional +from uuid import UUID + +from sqlalchemy.orm import Session + +from app.core.auth import ( + verify_password, + get_password_hash, + create_access_token, + create_refresh_token, + TokenExpiredError, + TokenInvalidError +) +from app.models.user import User +from app.schemas.users import Token, UserCreate + +logger = logging.getLogger(__name__) + + +class AuthenticationError(Exception): + """Exception raised for authentication errors""" + pass + + +class AuthService: + """Service for handling authentication operations""" + + @staticmethod + def authenticate_user(db: Session, email: str, password: str) -> Optional[User]: + """ + Authenticate a user with email and password. + + Args: + db: Database session + email: User email + password: User password + + Returns: + User if authenticated, None otherwise + """ + user = db.query(User).filter(User.email == email).first() + + if not user: + return None + + if not verify_password(password, user.password_hash): + return None + + if not user.is_active: + raise AuthenticationError("User account is inactive") + + return user + + @staticmethod + def create_user(db: Session, user_data: UserCreate) -> User: + """ + Create a new user. + + Args: + db: Database session + user_data: User data + + Returns: + Created user + """ + # Check if user already exists + existing_user = db.query(User).filter(User.email == user_data.email).first() + if existing_user: + raise AuthenticationError("User with this email already exists") + + # Create new user + hashed_password = get_password_hash(user_data.password) + + # Create user object from model + user = User( + email=user_data.email, + password_hash=hashed_password, + first_name=user_data.first_name, + last_name=user_data.last_name, + phone_number=user_data.phone_number, + is_active=True, + is_superuser=False + ) + + db.add(user) + db.commit() + db.refresh(user) + + return user + + @staticmethod + def create_tokens(user: User) -> Token: + """ + Create access and refresh tokens for a user. + + Args: + user: User to create tokens for + + Returns: + Token object with access and refresh tokens + """ + # Generate claims + claims = { + "is_superuser": user.is_superuser, + "email": user.email, + "first_name": user.first_name + } + + # Create tokens + access_token = create_access_token( + subject=str(user.id), + claims=claims + ) + + refresh_token = create_refresh_token( + subject=str(user.id) + ) + + return Token( + access_token=access_token, + refresh_token=refresh_token + ) + + @staticmethod + def refresh_tokens(db: Session, refresh_token: str) -> Token: + """ + Generate new tokens using a refresh token. + + Args: + db: Database session + refresh_token: Valid refresh token + + Returns: + New access and refresh tokens + + Raises: + TokenExpiredError: If refresh token has expired + TokenInvalidError: If refresh token is invalid + """ + from app.core.auth import decode_token, get_token_data + + try: + # Verify token is a refresh token + decode_token(refresh_token, verify_type="refresh") + + # Get user ID from token + token_data = get_token_data(refresh_token) + user_id = token_data.user_id + + # Get user from database + user = db.query(User).filter(User.id == user_id).first() + if not user or not user.is_active: + raise TokenInvalidError("Invalid user or inactive account") + + # Generate new tokens + return AuthService.create_tokens(user) + + except (TokenExpiredError, TokenInvalidError) as e: + logger.warning(f"Token refresh failed: {str(e)}") + raise + + @staticmethod + def change_password(db: Session, user_id: UUID, current_password: str, new_password: str) -> bool: + """ + Change a user's password. + + Args: + db: Database session + user_id: User ID + current_password: Current password + new_password: New password + + Returns: + True if password was changed successfully + + Raises: + AuthenticationError: If current password is incorrect + """ + user = db.query(User).filter(User.id == user_id).first() + if not user: + raise AuthenticationError("User not found") + + # Verify current password + if not verify_password(current_password, user.password_hash): + raise AuthenticationError("Current password is incorrect") + + # Update password + user.password_hash = get_password_hash(new_password) + db.commit() + + return True diff --git a/backend/app/utils/test_utils.py b/backend/app/utils/test_utils.py index 0edc59e..26598b6 100644 --- a/backend/app/utils/test_utils.py +++ b/backend/app/utils/test_utils.py @@ -1,5 +1,6 @@ import logging from sqlalchemy import create_engine, event +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker, clear_mappers from sqlalchemy.pool import StaticPool @@ -42,4 +43,37 @@ def teardown_test_db(engine): Base.metadata.drop_all(engine) # Dispose of engine - engine.dispose() \ No newline at end of file + engine.dispose() + +async def get_async_test_engine(): + """Create an async SQLite in-memory engine specifically for testing""" + test_engine = create_async_engine( + "sqlite+aiosqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, # Use static pool for in-memory testing + echo=False + ) + return test_engine + + +async def setup_async_test_db(): + """Create an async test database and session factory""" + test_engine = await get_async_test_engine() + + async with test_engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + AsyncTestingSessionLocal = sessionmaker( + autocommit=False, + autoflush=False, + bind=test_engine, + expire_on_commit=False, + class_=AsyncSession + ) + + return test_engine, AsyncTestingSessionLocal + + +async def teardown_async_test_db(engine): + """Clean up after async tests""" + await engine.dispose() diff --git a/backend/requirements.txt b/backend/requirements.txt index cbc5a8c..ecc4682 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -11,7 +11,7 @@ sqlalchemy>=2.0.29 alembic>=1.14.1 psycopg2-binary>=2.9.9 asyncpg>=0.29.0 - +aiosqlite==0.21.0 # Security and authentication python-jose>=3.4.0 passlib>=1.7.4 diff --git a/backend/tests/auth/dependencies.py b/backend/tests/auth/dependencies.py deleted file mode 100644 index faccd6f..0000000 --- a/backend/tests/auth/dependencies.py +++ /dev/null @@ -1,85 +0,0 @@ -from datetime import datetime, timezone -from unittest.mock import AsyncMock - -import pytest -from fastapi import HTTPException -from jose import jwt - -from app.auth.dependencies import get_current_user, get_current_active_user -from app.auth.security import SECRET_KEY, ALGORITHM -from app.models.user import User - - -@pytest.fixture -def mock_user(): - return User( - id="123e4567-e89b-12d3-a456-426614174000", - email="test@example.com", - password_hash="hashedpassword", - is_active=True - ) - - -@pytest.mark.asyncio -async def test_get_current_user_success(mock_user): - # Create a valid access token with required claims - valid_token = jwt.encode( - {"sub": str(mock_user.id), "type": "access", "exp": datetime.now(tz=timezone.utc).timestamp() + 3600}, - SECRET_KEY, - algorithm=ALGORITHM - ) - - # Mock database session - mock_db = AsyncMock() - mock_db.get.return_value = mock_user # Ensure `db.get()` returns the mock_user - - # Call the dependency - user = await get_current_user(token=valid_token, db=mock_db) - - # Assertions - assert user == mock_user, "Returned user does not match the mocked user." - mock_db.get.assert_called_once_with(User, mock_user.id) - - -@pytest.mark.asyncio -async def test_get_current_user_invalid_token(): - invalid_token = "invalid.token.payload" - - with pytest.raises(HTTPException) as exc_info: - await get_current_user(token=invalid_token, db=AsyncMock()) - - assert exc_info.value.status_code == 401 - assert exc_info.value.detail == "Could not validate credentials" - - -@pytest.mark.asyncio -async def test_get_current_user_wrong_token_type(): - token = jwt.encode({"sub": "123", "type": "refresh"}, SECRET_KEY, algorithm=ALGORITHM) - - with pytest.raises(HTTPException) as exc_info: - await get_current_user(token=token, db=AsyncMock()) - - assert exc_info.value.status_code == 401 - assert exc_info.value.detail == "Could not validate credentials" - - -@pytest.mark.asyncio -async def test_get_current_active_user_success(mock_user): - result = await get_current_active_user(mock_user) - assert result == mock_user - - -@pytest.mark.asyncio -async def test_get_current_active_user_inactive(): - inactive_user = User( - id="123e4567-e89b-12d3-a456-426614174000", - email="inactive@example.com", - password_hash="hashedpassword", - is_active=False - ) - - with pytest.raises(HTTPException) as exc_info: - await get_current_active_user(inactive_user) - - assert exc_info.value.status_code == 400 - assert exc_info.value.detail == "Inactive user" diff --git a/backend/tests/auth/test_security.py b/backend/tests/auth/test_security.py deleted file mode 100644 index 14cf8b8..0000000 --- a/backend/tests/auth/test_security.py +++ /dev/null @@ -1,147 +0,0 @@ -from datetime import timedelta, datetime -from unittest.mock import AsyncMock - -import pytest -from jose import jwt, JWTError -from sqlalchemy.ext.asyncio import AsyncSession - -from app.auth.security import ( - get_password_hash, verify_password, - create_access_token, create_refresh_token, - decode_token, SECRET_KEY, ALGORITHM -) -from app.schemas.token import TokenPayload - - -def test_password_hashing(): - plain_password = "securepassword123" - hashed_password = get_password_hash(plain_password) - - # Ensure hashed passwords are not the same - assert hashed_password != plain_password - # Test password verification - assert verify_password(plain_password, hashed_password) - assert not verify_password("wrongpassword", hashed_password) - - -def test_access_token_creation(): - user_id = "123e4567-e89b-12d3-a456-426614174000" - token = create_access_token({"sub": user_id}) - decoded_payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) - - assert decoded_payload.get("sub") == user_id - assert decoded_payload.get("type") == "access" - - -def test_refresh_token_creation(): - user_id = "123e4567-e89b-12d3-a456-426614174000" - token = create_refresh_token({"sub": user_id}) - decoded_payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) - - assert decoded_payload.get("sub") == user_id - assert decoded_payload.get("type") == "refresh" - - -@pytest.mark.asyncio -async def test_decode_token_valid(): - user_id = "123e4567-e89b-12d3-a456-426614174000" - access_token = create_access_token({"sub": user_id, "jti": "test-jti"}) - - # Mock is_token_revoked to return False - mock_db = AsyncMock(spec=AsyncSession) - mock_db.get = AsyncMock(return_value=None) - - token_payload = await decode_token(access_token, db=mock_db) - - assert isinstance(token_payload, TokenPayload) - assert token_payload.sub == user_id - assert token_payload.type == "access" - - - -@pytest.mark.asyncio -async def test_decode_token_expired(): - user_id = "123e4567-e89b-12d3-a456-426614174000" - token = create_access_token({"sub": user_id, "jti": "test-jti"}, expires_delta=timedelta(seconds=-1)) - - # Mock database - mock_db = AsyncMock(spec=AsyncSession) - - with pytest.raises(JWTError) as exc_info: - await decode_token(token, db=mock_db) - - assert str(exc_info.value) == "Token has been revoked." - - - -@pytest.mark.asyncio -async def test_decode_token_missing_exp(): - # Create a token without the `exp` claim - payload = {"sub": "123e4567-e89b-12d3-a456-426614174000", "type": "access", "jti": "test-jti"} - token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM) - - # Mock database - mock_db = AsyncMock(spec=AsyncSession) - - with pytest.raises(JWTError) as exc_info: - await decode_token(token, db=mock_db) - - assert str(exc_info.value) == "Malformed token. Missing required claim(s)." - - - -@pytest.mark.asyncio -async def test_decode_token_missing_sub(): - # Create a token without the `sub` claim - payload = {"exp": datetime.now().timestamp() + 60, "type": "access", "jti": "test-jti"} - token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM) - - # Mock database - mock_db = AsyncMock(spec=AsyncSession) - - with pytest.raises(JWTError) as exc_info: - await decode_token(token, db=mock_db) - - assert str(exc_info.value) == "Malformed token. Missing required claim(s)." - - -@pytest.mark.asyncio -async def test_decode_token_invalid_signature(): - # Use a different secret key for signing - token = jwt.encode({"sub": "123", "type": "access", "jti": "test-jti"}, "wrong_secret", algorithm=ALGORITHM) - - # Mock database - mock_db = AsyncMock(spec=AsyncSession) - - with pytest.raises(JWTError) as exc_info: - await decode_token(token, db=mock_db) - - assert str(exc_info.value) == "Signature verification failed." - - - -@pytest.mark.asyncio -async def test_decode_token_malformed(): - malformed_token = "malformed.header.payload" - - # Mock database - mock_db = AsyncMock(spec=AsyncSession) - - with pytest.raises(JWTError) as exc_info: - await decode_token(malformed_token, db=mock_db) - - assert str(exc_info.value) == "Invalid token." - - -@pytest.mark.asyncio -async def test_decode_token_invalid_type(): - user_id = "123e4567-e89b-12d3-a456-426614174000" - token = create_refresh_token({"sub": user_id, "jti": "test-jti"}) # Token type is "refresh" - - # Mock database - mock_db = AsyncMock(spec=AsyncSession) - - with pytest.raises(JWTError) as exc_info: - await decode_token(token, required_type="access", db=mock_db) # Expecting an access token - - assert str(exc_info.value) == "Invalid token type: expected 'access', got 'refresh'." diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 4510886..5e25047 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -9,7 +9,7 @@ from app.models import Event, GiftItem, GiftStatus, GiftPriority, GiftCategory, EventTheme, Guest, GuestStatus, ActivityType, ActivityLog, EmailTemplate, TemplateType, NotificationLog, \ NotificationType, NotificationStatus from app.models.user import User -from app.utils.test_utils import setup_test_db, teardown_test_db +from app.utils.test_utils import setup_test_db, teardown_test_db, setup_async_test_db, teardown_async_test_db @pytest.fixture(scope="function") @@ -30,6 +30,15 @@ def db_session(): teardown_test_db(test_engine) +@pytest.fixture(scope="function") # Define a fixture +async def async_test_db(): + """Fixture provides new testing engine and session for each test run to improve isolation.""" + + test_engine, AsyncTestingSessionLocal = await setup_async_test_db() + yield test_engine, AsyncTestingSessionLocal + await teardown_async_test_db(test_engine) + + @pytest.fixture def mock_user(db_session): """Fixture to create and return a mock User instance.""" @@ -72,7 +81,6 @@ def event_fixture(db_session, mock_user): return event - @pytest.fixture def gift_item_fixture(db_session, mock_user): """ diff --git a/backend/tests/core/__init__.py b/backend/tests/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/core/test_auth.py b/backend/tests/core/test_auth.py new file mode 100644 index 0000000..7929e11 --- /dev/null +++ b/backend/tests/core/test_auth.py @@ -0,0 +1,260 @@ +# tests/core/test_auth.py +import uuid +import pytest +from datetime import datetime, timedelta, timezone +from jose import jwt +from pydantic import ValidationError + +from app.core.auth import ( + verify_password, + get_password_hash, + create_access_token, + create_refresh_token, + decode_token, + get_token_data, + TokenExpiredError, + TokenInvalidError, + TokenMissingClaimError +) +from app.core.config import settings + + +class TestPasswordHandling: + """Tests for password hashing and verification functions""" + + def test_password_hash_different_from_password(self): + """Test that a password hash is different from the original password""" + password = "TestPassword123" + hashed = get_password_hash(password) + assert hashed != password + + def test_verify_correct_password(self): + """Test that verify_password returns True for the correct password""" + password = "TestPassword123" + hashed = get_password_hash(password) + assert verify_password(password, hashed) is True + + def test_verify_incorrect_password(self): + """Test that verify_password returns False for an incorrect password""" + password = "TestPassword123" + wrong_password = "WrongPassword123" + hashed = get_password_hash(password) + assert verify_password(wrong_password, hashed) is False + + def test_same_password_different_hash(self): + """Test that the same password gets a different hash each time""" + password = "TestPassword123" + hash1 = get_password_hash(password) + hash2 = get_password_hash(password) + assert hash1 != hash2 + + +class TestTokenCreation: + """Tests for token creation functions""" + + def test_create_access_token(self): + """Test that an access token is created with the correct claims""" + user_id = str(uuid.uuid4()) + custom_claims = { + "email": "test@example.com", + "first_name": "Test", + "is_superuser": True + } + token = create_access_token(subject=user_id, claims=custom_claims) + + # Decode token to verify claims + payload = jwt.decode( + token, + settings.SECRET_KEY, + algorithms=[settings.ALGORITHM] + ) + + # Check standard claims + assert payload["sub"] == user_id + assert "jti" in payload + assert "exp" in payload + assert "iat" in payload + assert payload["type"] == "access" + + # Check custom claims + for key, value in custom_claims.items(): + assert payload[key] == value + + def test_create_refresh_token(self): + """Test that a refresh token is created with the correct claims""" + user_id = str(uuid.uuid4()) + token = create_refresh_token(subject=user_id) + + # Decode token to verify claims + payload = jwt.decode( + token, + settings.SECRET_KEY, + algorithms=[settings.ALGORITHM] + ) + + # Check standard claims + assert payload["sub"] == user_id + assert "jti" in payload + assert "exp" in payload + assert "iat" in payload + assert payload["type"] == "refresh" + + def test_token_expiration(self): + """Test that tokens have the correct expiration time""" + user_id = str(uuid.uuid4()) + expires = timedelta(minutes=5) + + # Create token with specific expiration + token = create_access_token( + subject=user_id, + expires_delta=expires + ) + + # Decode token + payload = jwt.decode( + token, + settings.SECRET_KEY, + algorithms=[settings.ALGORITHM] + ) + + # Get actual expiration time from token + expiration = datetime.fromtimestamp(payload["exp"], tz=timezone.utc) + + # Calculate expected expiration (approximately) + now = datetime.now(timezone.utc) + expected_expiration = now + expires + + # Difference should be small (less than 1 second) + difference = abs((expiration - expected_expiration).total_seconds()) + assert difference < 1 + + +class TestTokenDecoding: + """Tests for token decoding and validation functions""" + + def test_decode_valid_token(self): + """Test that a valid token can be decoded""" + user_id = str(uuid.uuid4()) + token = create_access_token(subject=user_id) + + # Decode token + payload = decode_token(token) + + # Check that the subject matches + assert payload.sub == user_id + + def test_decode_expired_token(self): + """Test that an expired token raises TokenExpiredError""" + user_id = str(uuid.uuid4()) + + # Create a token that's already expired by directly manipulating the payload + now = datetime.now(timezone.utc) + expired_time = now - timedelta(hours=1) # 1 hour in the past + + # Create the expired token manually + payload = { + "sub": user_id, + "exp": int(expired_time.timestamp()), # Set expiration in the past + "iat": int(now.timestamp()), + "jti": str(uuid.uuid4()), + "type": "access" + } + + expired_token = jwt.encode( + payload, + settings.SECRET_KEY, + algorithm=settings.ALGORITHM + ) + + # Attempting to decode should raise TokenExpiredError + with pytest.raises(TokenExpiredError): + decode_token(expired_token) + + def test_decode_invalid_token(self): + """Test that an invalid token raises TokenInvalidError""" + invalid_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJpbnZhbGlkIn0.invalid-signature" + + with pytest.raises(TokenInvalidError): + decode_token(invalid_token) + + def test_decode_token_with_missing_sub(self): + """Test that a token without 'sub' claim raises TokenMissingClaimError""" + # Create a token without a subject + now = datetime.now(timezone.utc) + payload = { + "exp": int((now + timedelta(minutes=30)).timestamp()), + "iat": int(now.timestamp()), + "jti": str(uuid.uuid4()), + "type": "access" + # No 'sub' claim + } + + token = jwt.encode( + payload, + settings.SECRET_KEY, + algorithm=settings.ALGORITHM + ) + + with pytest.raises(TokenMissingClaimError): + decode_token(token) + + def test_decode_token_with_wrong_type(self): + """Test that verifying a token with wrong type raises TokenInvalidError""" + user_id = str(uuid.uuid4()) + token = create_access_token(subject=user_id) + + # Try to verify it as a refresh token + with pytest.raises(TokenInvalidError): + decode_token(token, verify_type="refresh") + + def test_decode_with_invalid_payload(self): + """Test that a token with invalid payload structure raises TokenInvalidError""" + # Create a token with an invalid payload structure - missing 'sub' which is required + # but including 'exp' to avoid the expiration check + now = datetime.now(timezone.utc) + payload = { + # Missing "sub" field which is required + "exp": int((now + timedelta(minutes=30)).timestamp()), + "iat": int(now.timestamp()), + "jti": str(uuid.uuid4()), + "invalid_field": "test" + } + + token = jwt.encode( + payload, + settings.SECRET_KEY, + algorithm=settings.ALGORITHM + ) + + # Should raise TokenMissingClaimError due to missing 'sub' + with pytest.raises(TokenMissingClaimError): + decode_token(token) + + # Create another token with invalid type for required field + payload = { + "sub": 123, # sub should be a string, not an integer + "exp": int((now + timedelta(minutes=30)).timestamp()), + } + + token = jwt.encode( + payload, + settings.SECRET_KEY, + algorithm=settings.ALGORITHM + ) + + # Should raise TokenInvalidError due to ValidationError + with pytest.raises(TokenInvalidError): + decode_token(token) + + def test_get_token_data(self): + """Test extracting TokenData from a token""" + user_id = uuid.uuid4() + token = create_access_token( + subject=str(user_id), + claims={"is_superuser": True} + ) + + token_data = get_token_data(token) + + assert token_data.user_id == user_id + assert token_data.is_superuser is True \ No newline at end of file diff --git a/backend/tests/services/__init__.py b/backend/tests/services/__init__.py new file mode 100644 index 0000000..e69de29