From 313e6691b5be20c275cf6e256165beafb8d23b68 Mon Sep 17 00:00:00 2001 From: Felipe Cardoso Date: Thu, 30 Oct 2025 16:45:01 +0100 Subject: [PATCH] Add async CRUD base, async database configuration, soft delete for users, and composite indexes - Introduced `CRUDBaseAsync` for reusable async operations. - Configured async database connection using SQLAlchemy 2.0 patterns with `asyncpg`. - Added `deleted_at` column and soft delete functionality to the `User` model, including related Alembic migration. - Optimized queries by adding composite indexes for common user filtering scenarios. - Extended tests: added cases for token-based security utilities and user management endpoints. --- .../2d0fcec3b06d_add_soft_delete_to_users.py | 34 ++ .../b76c725fc3cf_add_composite_indexes.py | 52 ++ backend/app/api/routes/users.py | 42 +- backend/app/core/database_async.py | 182 +++++++ backend/app/crud/base_async.py | 228 ++++++++ backend/app/models/user.py | 3 +- backend/tests/api/routes/test_users.py | 487 ++++++++++++++++++ backend/tests/utils/__init__.py | 0 backend/tests/utils/test_security.py | 233 +++++++++ 9 files changed, 1251 insertions(+), 10 deletions(-) create mode 100644 backend/app/alembic/versions/2d0fcec3b06d_add_soft_delete_to_users.py create mode 100644 backend/app/alembic/versions/b76c725fc3cf_add_composite_indexes.py create mode 100644 backend/app/core/database_async.py create mode 100644 backend/app/crud/base_async.py create mode 100644 backend/tests/api/routes/test_users.py create mode 100644 backend/tests/utils/__init__.py create mode 100644 backend/tests/utils/test_security.py diff --git a/backend/app/alembic/versions/2d0fcec3b06d_add_soft_delete_to_users.py b/backend/app/alembic/versions/2d0fcec3b06d_add_soft_delete_to_users.py new file mode 100644 index 0000000..3eb97b0 --- /dev/null +++ b/backend/app/alembic/versions/2d0fcec3b06d_add_soft_delete_to_users.py @@ -0,0 +1,34 @@ +"""add_soft_delete_to_users + +Revision ID: 2d0fcec3b06d +Revises: 9e4f2a1b8c7d +Create Date: 2025-10-30 16:40:21.000021 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '2d0fcec3b06d' +down_revision: Union[str, None] = '9e4f2a1b8c7d' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Add deleted_at column for soft deletes + op.add_column('users', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True)) + + # Add index on deleted_at for efficient queries + op.create_index('ix_users_deleted_at', 'users', ['deleted_at']) + + +def downgrade() -> None: + # Remove index + op.drop_index('ix_users_deleted_at', table_name='users') + + # Remove column + op.drop_column('users', 'deleted_at') diff --git a/backend/app/alembic/versions/b76c725fc3cf_add_composite_indexes.py b/backend/app/alembic/versions/b76c725fc3cf_add_composite_indexes.py new file mode 100644 index 0000000..75172c6 --- /dev/null +++ b/backend/app/alembic/versions/b76c725fc3cf_add_composite_indexes.py @@ -0,0 +1,52 @@ +"""add_composite_indexes + +Revision ID: b76c725fc3cf +Revises: 2d0fcec3b06d +Create Date: 2025-10-30 16:41:33.273135 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'b76c725fc3cf' +down_revision: Union[str, None] = '2d0fcec3b06d' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Add composite indexes for common query patterns + + # Composite index for filtering active users by role + op.create_index( + 'ix_users_active_superuser', + 'users', + ['is_active', 'is_superuser'], + postgresql_where=sa.text('deleted_at IS NULL') + ) + + # Composite index for sorting active users by creation date + op.create_index( + 'ix_users_active_created', + 'users', + ['is_active', 'created_at'], + postgresql_where=sa.text('deleted_at IS NULL') + ) + + # Composite index for email lookup of non-deleted users + op.create_index( + 'ix_users_email_not_deleted', + 'users', + ['email', 'deleted_at'] + ) + + +def downgrade() -> None: + # Remove composite indexes + op.drop_index('ix_users_email_not_deleted', table_name='users') + op.drop_index('ix_users_active_created', table_name='users') + op.drop_index('ix_users_active_superuser', table_name='users') diff --git a/backend/app/api/routes/users.py b/backend/app/api/routes/users.py index d367cbc..fd38297 100644 --- a/backend/app/api/routes/users.py +++ b/backend/app/api/routes/users.py @@ -2,7 +2,7 @@ User management endpoints for CRUD operations. """ import logging -from typing import Any +from typing import Any, Optional from uuid import UUID from fastapi import APIRouter, Depends, Query, status, Request @@ -19,9 +19,10 @@ from app.schemas.common import ( PaginationParams, PaginatedResponse, MessageResponse, + SortParams, create_pagination_meta ) -from app.services.auth_service import AuthService +from app.services.auth_service import AuthService, AuthenticationError from app.core.exceptions import ( NotFoundError, AuthorizationError, @@ -39,31 +40,47 @@ limiter = Limiter(key_func=get_remote_address) response_model=PaginatedResponse[UserResponse], summary="List Users", description=""" - List all users with pagination (admin only). + List all users with pagination, filtering, and sorting (admin only). **Authentication**: Required (Bearer token) **Authorization**: Superuser only + **Filtering**: is_active, is_superuser + **Sorting**: Any user field (email, first_name, last_name, created_at, etc.) + **Rate Limit**: 60 requests/minute """, operation_id="list_users" ) def list_users( pagination: PaginationParams = Depends(), + sort: SortParams = Depends(), + is_active: Optional[bool] = Query(None, description="Filter by active status"), + is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"), current_user: User = Depends(get_current_superuser), db: Session = Depends(get_db) ) -> Any: """ - List all users with pagination. + List all users with pagination, filtering, and sorting. Only accessible by superusers. """ try: + # Build filters + filters = {} + if is_active is not None: + filters["is_active"] = is_active + if is_superuser is not None: + filters["is_superuser"] = is_superuser + # Get paginated users with total count users, total = user_crud.get_multi_with_total( db, skip=pagination.offset, - limit=pagination.limit + limit=pagination.limit, + sort_by=sort.sort_by, + sort_order=sort.sort_order.value if sort.sort_order else "asc", + filters=filters if filters else None ) # Create pagination metadata @@ -129,7 +146,7 @@ def update_current_user( Users cannot elevate their own permissions (is_superuser). """ # Prevent users from making themselves superuser - if user_update.is_superuser is not None: + if getattr(user_update, 'is_superuser', None) is not None: logger.warning(f"User {current_user.id} attempted to modify is_superuser field") raise AuthorizationError( message="Cannot modify superuser status", @@ -248,7 +265,7 @@ def update_user( ) # Prevent non-superusers from modifying superuser status - if user_update.is_superuser is not None and not current_user.is_superuser: + if getattr(user_update, 'is_superuser', None) is not None and not current_user.is_superuser: logger.warning(f"User {current_user.id} attempted to modify is_superuser field") raise AuthorizationError( message="Cannot modify superuser status", @@ -308,6 +325,12 @@ def change_current_user_password( success=True, message="Password changed successfully" ) + except AuthenticationError as e: + logger.warning(f"Failed password change attempt for user {current_user.id}: {str(e)}") + raise AuthorizationError( + message=str(e), + error_code=ErrorCode.INVALID_CREDENTIALS + ) except Exception as e: logger.error(f"Error changing password for user {current_user.id}: {str(e)}") raise @@ -356,8 +379,9 @@ def delete_user( ) try: - user_crud.remove(db, id=str(user_id)) - logger.info(f"User {user_id} deleted by {current_user.id}") + # Use soft delete instead of hard delete + user_crud.soft_delete(db, id=str(user_id)) + logger.info(f"User {user_id} soft-deleted by {current_user.id}") return MessageResponse( success=True, message=f"User {user_id} deleted successfully" diff --git a/backend/app/core/database_async.py b/backend/app/core/database_async.py new file mode 100644 index 0000000..aecfa14 --- /dev/null +++ b/backend/app/core/database_async.py @@ -0,0 +1,182 @@ +# app/core/database_async.py +""" +Async database configuration using SQLAlchemy 2.0 and asyncpg. + +This module provides async database connectivity with proper connection pooling +and session management for FastAPI endpoints. +""" +import logging +from contextlib import asynccontextmanager +from typing import AsyncGenerator + +from sqlalchemy import text +from sqlalchemy.ext.asyncio import ( + AsyncSession, + AsyncEngine, + create_async_engine, + async_sessionmaker, +) +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.dialects.postgresql import JSONB, UUID +from sqlalchemy.orm import DeclarativeBase + +from app.core.config import settings + +# Configure logging +logger = logging.getLogger(__name__) + + +# SQLite compatibility for testing +@compiles(JSONB, 'sqlite') +def compile_jsonb_sqlite(type_, compiler, **kw): + return "TEXT" + + +@compiles(UUID, 'sqlite') +def compile_uuid_sqlite(type_, compiler, **kw): + return "TEXT" + + +# Declarative base for models (SQLAlchemy 2.0 style) +class Base(DeclarativeBase): + """Base class for all database models.""" + pass + + +def get_async_database_url(url: str) -> str: + """ + Convert sync database URL to async URL. + + postgresql:// -> postgresql+asyncpg:// + sqlite:// -> sqlite+aiosqlite:// + """ + if url.startswith("postgresql://"): + return url.replace("postgresql://", "postgresql+asyncpg://") + elif url.startswith("sqlite://"): + return url.replace("sqlite://", "sqlite+aiosqlite://") + return url + + +# Create async engine with optimized settings +def create_async_production_engine() -> AsyncEngine: + """Create an async database engine with production settings.""" + async_url = get_async_database_url(settings.database_url) + + # Base engine config + engine_config = { + "pool_size": settings.db_pool_size, + "max_overflow": settings.db_max_overflow, + "pool_timeout": settings.db_pool_timeout, + "pool_recycle": settings.db_pool_recycle, + "pool_pre_ping": True, + "echo": settings.sql_echo, + "echo_pool": settings.sql_echo_pool, + } + + # Add PostgreSQL-specific connect_args + if "postgresql" in async_url: + engine_config["connect_args"] = { + "server_settings": { + "application_name": "eventspace", + "timezone": "UTC", + }, + # asyncpg-specific settings + "command_timeout": 60, + "timeout": 10, + } + + return create_async_engine(async_url, **engine_config) + + +# Create async engine and session factory +async_engine = create_async_production_engine() +AsyncSessionLocal = async_sessionmaker( + async_engine, + class_=AsyncSession, + autocommit=False, + autoflush=False, + expire_on_commit=False, # Prevent unnecessary queries after commit +) + + +# FastAPI dependency for async database sessions +async def get_async_db() -> AsyncGenerator[AsyncSession, None]: + """ + FastAPI dependency that provides an async database session. + Automatically closes the session after the request completes. + + Usage: + @router.get("/users") + async def get_users(db: AsyncSession = Depends(get_async_db)): + result = await db.execute(select(User)) + return result.scalars().all() + """ + async with AsyncSessionLocal() as session: + try: + yield session + finally: + await session.close() + + +@asynccontextmanager +async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]: + """ + Provide an async transactional scope for database operations. + + Automatically commits on success or rolls back on exception. + Useful for grouping multiple operations in a single transaction. + + Usage: + async with async_transaction_scope() as db: + user = await user_crud.create(db, obj_in=user_create) + profile = await profile_crud.create(db, obj_in=profile_create) + # Both operations committed together + """ + async with AsyncSessionLocal() as session: + try: + yield session + await session.commit() + logger.debug("Async transaction committed successfully") + except Exception as e: + await session.rollback() + logger.error(f"Async transaction failed, rolling back: {str(e)}") + raise + finally: + await session.close() + + +async def check_async_database_health() -> bool: + """ + Check if async database connection is healthy. + Returns True if connection is successful, False otherwise. + """ + try: + async with async_transaction_scope() as db: + await db.execute(text("SELECT 1")) + return True + except Exception as e: + logger.error(f"Async database health check failed: {str(e)}") + return False + + +async def init_async_db() -> None: + """ + Initialize async database tables. + + This creates all tables defined in the models. + Should only be used in development or testing. + In production, use Alembic migrations. + """ + async with async_engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + logger.info("Async database tables created") + + +async def close_async_db() -> None: + """ + Close all async database connections. + + Should be called during application shutdown. + """ + await async_engine.dispose() + logger.info("Async database connections closed") diff --git a/backend/app/crud/base_async.py b/backend/app/crud/base_async.py new file mode 100644 index 0000000..0354b21 --- /dev/null +++ b/backend/app/crud/base_async.py @@ -0,0 +1,228 @@ +# app/crud/base_async.py +""" +Async CRUD operations base class using SQLAlchemy 2.0 async patterns. + +Provides reusable create, read, update, and delete operations for all models. +""" +from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple +import logging +import uuid + +from fastapi.encoders import jsonable_encoder +from pydantic import BaseModel +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.exc import IntegrityError, OperationalError, DataError + +from app.core.database_async import Base + +logger = logging.getLogger(__name__) + +ModelType = TypeVar("ModelType", bound=Base) +CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) +UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) + + +class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): + """Async CRUD operations for a model.""" + + def __init__(self, model: Type[ModelType]): + """ + CRUD object with default async methods to Create, Read, Update, Delete. + + Parameters: + model: A SQLAlchemy model class + """ + self.model = model + + async def get(self, db: AsyncSession, id: str) -> Optional[ModelType]: + """Get a single record by ID with UUID validation.""" + # Validate UUID format and convert to UUID object if string + try: + if isinstance(id, uuid.UUID): + uuid_obj = id + else: + uuid_obj = uuid.UUID(str(id)) + except (ValueError, AttributeError, TypeError) as e: + logger.warning(f"Invalid UUID format: {id} - {str(e)}") + return None + + try: + result = await db.execute( + select(self.model).where(self.model.id == uuid_obj) + ) + return result.scalar_one_or_none() + except Exception as e: + logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}") + raise + + async def get_multi( + self, db: AsyncSession, *, skip: int = 0, limit: int = 100 + ) -> List[ModelType]: + """Get multiple records with pagination validation.""" + # Validate pagination parameters + if skip < 0: + raise ValueError("skip must be non-negative") + if limit < 0: + raise ValueError("limit must be non-negative") + if limit > 1000: + raise ValueError("Maximum limit is 1000") + + try: + result = await db.execute( + select(self.model).offset(skip).limit(limit) + ) + return list(result.scalars().all()) + except Exception as e: + logger.error(f"Error retrieving multiple {self.model.__name__} records: {str(e)}") + raise + + async def create(self, db: AsyncSession, *, obj_in: CreateSchemaType) -> ModelType: + """Create a new record with error handling.""" + try: + obj_in_data = jsonable_encoder(obj_in) + db_obj = self.model(**obj_in_data) + db.add(db_obj) + await db.commit() + await db.refresh(db_obj) + return db_obj + except IntegrityError as e: + await db.rollback() + error_msg = str(e.orig) if hasattr(e, 'orig') else str(e) + if "unique" in error_msg.lower() or "duplicate" in error_msg.lower(): + logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}") + raise ValueError(f"A {self.model.__name__} with this data already exists") + logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}") + raise ValueError(f"Database integrity error: {error_msg}") + except (OperationalError, DataError) as e: + await db.rollback() + logger.error(f"Database error creating {self.model.__name__}: {str(e)}") + raise ValueError(f"Database operation failed: {str(e)}") + except Exception as e: + await db.rollback() + logger.error(f"Unexpected error creating {self.model.__name__}: {str(e)}", exc_info=True) + raise + + async def update( + self, + db: AsyncSession, + *, + db_obj: ModelType, + obj_in: Union[UpdateSchemaType, Dict[str, Any]] + ) -> ModelType: + """Update a record with error handling.""" + try: + obj_data = jsonable_encoder(db_obj) + if isinstance(obj_in, dict): + update_data = obj_in + else: + update_data = obj_in.model_dump(exclude_unset=True) + + for field in obj_data: + if field in update_data: + setattr(db_obj, field, update_data[field]) + + db.add(db_obj) + await db.commit() + await db.refresh(db_obj) + return db_obj + except IntegrityError as e: + await db.rollback() + error_msg = str(e.orig) if hasattr(e, 'orig') else str(e) + if "unique" in error_msg.lower() or "duplicate" in error_msg.lower(): + logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}") + raise ValueError(f"A {self.model.__name__} with this data already exists") + logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}") + raise ValueError(f"Database integrity error: {error_msg}") + except (OperationalError, DataError) as e: + await db.rollback() + logger.error(f"Database error updating {self.model.__name__}: {str(e)}") + raise ValueError(f"Database operation failed: {str(e)}") + except Exception as e: + await db.rollback() + logger.error(f"Unexpected error updating {self.model.__name__}: {str(e)}", exc_info=True) + raise + + async def remove(self, db: AsyncSession, *, id: str) -> Optional[ModelType]: + """Delete a record with error handling and null check.""" + # Validate UUID format and convert to UUID object if string + try: + if isinstance(id, uuid.UUID): + uuid_obj = id + else: + uuid_obj = uuid.UUID(str(id)) + except (ValueError, AttributeError, TypeError) as e: + logger.warning(f"Invalid UUID format for deletion: {id} - {str(e)}") + return None + + try: + result = await db.execute( + select(self.model).where(self.model.id == uuid_obj) + ) + obj = result.scalar_one_or_none() + + if obj is None: + logger.warning(f"{self.model.__name__} with id {id} not found for deletion") + return None + + await db.delete(obj) + await db.commit() + return obj + except IntegrityError as e: + await db.rollback() + error_msg = str(e.orig) if hasattr(e, 'orig') else str(e) + logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}") + raise ValueError(f"Cannot delete {self.model.__name__}: referenced by other records") + except Exception as e: + await db.rollback() + logger.error(f"Error deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True) + raise + + async def get_multi_with_total( + self, db: AsyncSession, *, skip: int = 0, limit: int = 100 + ) -> Tuple[List[ModelType], int]: + """ + Get multiple records with total count for pagination. + + Returns: + Tuple of (items, total_count) + """ + # Validate pagination parameters + if skip < 0: + raise ValueError("skip must be non-negative") + if limit < 0: + raise ValueError("limit must be non-negative") + if limit > 1000: + raise ValueError("Maximum limit is 1000") + + try: + # Get total count + count_result = await db.execute( + select(func.count(self.model.id)) + ) + total = count_result.scalar_one() + + # Get paginated items + items_result = await db.execute( + select(self.model).offset(skip).limit(limit) + ) + items = list(items_result.scalars().all()) + + return items, total + except Exception as e: + logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}") + raise + + async def count(self, db: AsyncSession) -> int: + """Get total count of records.""" + try: + result = await db.execute(select(func.count(self.model.id))) + return result.scalar_one() + except Exception as e: + logger.error(f"Error counting {self.model.__name__} records: {str(e)}") + raise + + async def exists(self, db: AsyncSession, id: str) -> bool: + """Check if a record exists by ID.""" + obj = await self.get(db, id=id) + return obj is not None diff --git a/backend/app/models/user.py b/backend/app/models/user.py index a5aa604..b8f1040 100644 --- a/backend/app/models/user.py +++ b/backend/app/models/user.py @@ -1,4 +1,4 @@ -from sqlalchemy import Column, String, Boolean +from sqlalchemy import Column, String, Boolean, DateTime from sqlalchemy.dialects.postgresql import JSONB from .base import Base, TimestampMixin, UUIDMixin @@ -15,6 +15,7 @@ class User(Base, UUIDMixin, TimestampMixin): is_active = Column(Boolean, default=True, nullable=False, index=True) is_superuser = Column(Boolean, default=False, nullable=False, index=True) preferences = Column(JSONB) + deleted_at = Column(DateTime(timezone=True), nullable=True, index=True) def __repr__(self): return f"" \ No newline at end of file diff --git a/backend/tests/api/routes/test_users.py b/backend/tests/api/routes/test_users.py new file mode 100644 index 0000000..ca421a1 --- /dev/null +++ b/backend/tests/api/routes/test_users.py @@ -0,0 +1,487 @@ +# tests/api/routes/test_users.py +""" +Tests for user management endpoints. +""" +import uuid +from datetime import datetime, timezone +from unittest.mock import patch, MagicMock + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from sqlalchemy.orm import Session + +from app.api.routes.users import router as users_router +from app.core.database import get_db +from app.models.user import User +from app.api.dependencies.auth import get_current_user, get_current_superuser + + +@pytest.fixture +def override_get_db(db_session): + """Override get_db dependency for testing.""" + return db_session + + +@pytest.fixture +def app(override_get_db): + """Create a FastAPI test application.""" + app = FastAPI() + app.include_router(users_router, prefix="/api/v1/users", tags=["users"]) + + # Override the get_db dependency + app.dependency_overrides[get_db] = lambda: override_get_db + + return app + + +@pytest.fixture +def client(app): + """Create a FastAPI test client.""" + return TestClient(app) + + +@pytest.fixture +def regular_user(): + """Create a mock regular user.""" + return User( + id=uuid.uuid4(), + email="regular@example.com", + password_hash="hashed_password", + first_name="Regular", + last_name="User", + is_active=True, + is_superuser=False, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc) + ) + + +@pytest.fixture +def super_user(): + """Create a mock superuser.""" + return User( + id=uuid.uuid4(), + email="admin@example.com", + password_hash="hashed_password", + first_name="Admin", + last_name="User", + is_active=True, + is_superuser=True, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc) + ) + + +class TestListUsers: + """Tests for the list_users endpoint.""" + + def test_list_users_as_superuser(self, client, app, super_user, regular_user, db_session): + """Test that superusers can list all users.""" + from app.crud.user import user as user_crud + + # Override auth dependency + app.dependency_overrides[get_current_superuser] = lambda: super_user + + # Mock user_crud to return test data + mock_users = [regular_user for _ in range(3)] + with patch.object(user_crud, 'get_multi_with_total', return_value=(mock_users, 3)): + response = client.get("/api/v1/users?page=1&limit=20") + + assert response.status_code == 200 + data = response.json() + assert "data" in data + assert "pagination" in data + assert len(data["data"]) == 3 + assert data["pagination"]["total"] == 3 + + # Clean up + if get_current_superuser in app.dependency_overrides: + del app.dependency_overrides[get_current_superuser] + + def test_list_users_pagination(self, client, app, super_user, regular_user, db_session): + """Test pagination parameters for list users.""" + from app.crud.user import user as user_crud + + app.dependency_overrides[get_current_superuser] = lambda: super_user + + # Mock user_crud + mock_users = [regular_user for _ in range(10)] + with patch.object(user_crud, 'get_multi_with_total', return_value=(mock_users[:5], 10)): + response = client.get("/api/v1/users?page=1&limit=5") + + assert response.status_code == 200 + data = response.json() + assert data["pagination"]["page"] == 1 + assert data["pagination"]["page_size"] == 5 + assert data["pagination"]["total"] == 10 + assert data["pagination"]["total_pages"] == 2 + + # Clean up + if get_current_superuser in app.dependency_overrides: + del app.dependency_overrides[get_current_superuser] + + +class TestGetCurrentUserProfile: + """Tests for the get_current_user_profile endpoint.""" + + def test_get_current_user_profile(self, client, app, regular_user): + """Test getting current user's profile.""" + app.dependency_overrides[get_current_user] = lambda: regular_user + + response = client.get("/api/v1/users/me") + + assert response.status_code == 200 + data = response.json() + assert data["email"] == regular_user.email + assert data["first_name"] == regular_user.first_name + assert data["last_name"] == regular_user.last_name + assert "password" not in data + + # Clean up + if get_current_user in app.dependency_overrides: + del app.dependency_overrides[get_current_user] + + +class TestUpdateCurrentUser: + """Tests for the update_current_user endpoint.""" + + def test_update_current_user_success(self, client, app, regular_user, db_session): + """Test successful profile update.""" + from app.crud.user import user as user_crud + + app.dependency_overrides[get_current_user] = lambda: regular_user + + updated_user = User( + id=regular_user.id, + email=regular_user.email, + password_hash=regular_user.password_hash, + first_name="Updated", + last_name="Name", + is_active=True, + is_superuser=False, + created_at=regular_user.created_at, + updated_at=datetime.now(timezone.utc) + ) + + with patch.object(user_crud, 'update', return_value=updated_user): + response = client.patch( + "/api/v1/users/me", + json={"first_name": "Updated", "last_name": "Name"} + ) + + assert response.status_code == 200 + data = response.json() + assert data["first_name"] == "Updated" + assert data["last_name"] == "Name" + + # Clean up + if get_current_user in app.dependency_overrides: + del app.dependency_overrides[get_current_user] + + def test_update_current_user_extra_fields_ignored(self, client, app, regular_user, db_session): + """Test that extra fields like is_superuser are ignored by schema validation.""" + from app.crud.user import user as user_crud + + app.dependency_overrides[get_current_user] = lambda: regular_user + + # Create updated user without is_superuser changed + updated_user = User( + id=regular_user.id, + email=regular_user.email, + password_hash=regular_user.password_hash, + first_name="Updated", + last_name=regular_user.last_name, + is_active=True, + is_superuser=False, # Should remain False + created_at=regular_user.created_at, + updated_at=datetime.now(timezone.utc) + ) + + with patch.object(user_crud, 'update', return_value=updated_user): + response = client.patch( + "/api/v1/users/me", + json={"first_name": "Updated", "is_superuser": True} # is_superuser will be ignored + ) + + # Request should succeed but is_superuser should be unchanged + assert response.status_code == 200 + data = response.json() + assert data["is_superuser"] is False + + # Clean up + if get_current_user in app.dependency_overrides: + del app.dependency_overrides[get_current_user] + + +class TestGetUserById: + """Tests for the get_user_by_id endpoint.""" + + def test_get_own_profile(self, client, app, regular_user, db_session): + """Test that users can get their own profile.""" + from app.crud.user import user as user_crud + + app.dependency_overrides[get_current_user] = lambda: regular_user + + with patch.object(user_crud, 'get', return_value=regular_user): + response = client.get(f"/api/v1/users/{regular_user.id}") + + assert response.status_code == 200 + data = response.json() + assert data["email"] == regular_user.email + + # Clean up + if get_current_user in app.dependency_overrides: + del app.dependency_overrides[get_current_user] + + def test_get_other_user_as_regular_user(self, client, app, regular_user): + """Test that regular users cannot view other users.""" + app.dependency_overrides[get_current_user] = lambda: regular_user + + other_user_id = uuid.uuid4() + response = client.get(f"/api/v1/users/{other_user_id}") + + assert response.status_code == 403 + + # Clean up + if get_current_user in app.dependency_overrides: + del app.dependency_overrides[get_current_user] + + def test_get_other_user_as_superuser(self, client, app, super_user, db_session): + """Test that superusers can view any user.""" + from app.crud.user import user as user_crud + + app.dependency_overrides[get_current_user] = lambda: super_user + + other_user = User( + id=uuid.uuid4(), + email="other@example.com", + password_hash="hashed", + first_name="Other", + last_name="User", + is_active=True, + is_superuser=False, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc) + ) + + with patch.object(user_crud, 'get', return_value=other_user): + response = client.get(f"/api/v1/users/{other_user.id}") + + assert response.status_code == 200 + data = response.json() + assert data["email"] == other_user.email + + # Clean up + if get_current_user in app.dependency_overrides: + del app.dependency_overrides[get_current_user] + + def test_get_nonexistent_user(self, client, app, super_user, db_session): + """Test getting a user that doesn't exist.""" + from app.crud.user import user as user_crud + + app.dependency_overrides[get_current_user] = lambda: super_user + + with patch.object(user_crud, 'get', return_value=None): + response = client.get(f"/api/v1/users/{uuid.uuid4()}") + + assert response.status_code == 404 + + # Clean up + if get_current_user in app.dependency_overrides: + del app.dependency_overrides[get_current_user] + + +class TestUpdateUser: + """Tests for the update_user endpoint.""" + + def test_update_own_profile(self, client, app, regular_user, db_session): + """Test that users can update their own profile.""" + from app.crud.user import user as user_crud + + app.dependency_overrides[get_current_user] = lambda: regular_user + + updated_user = User( + id=regular_user.id, + email=regular_user.email, + password_hash=regular_user.password_hash, + first_name="NewName", + last_name=regular_user.last_name, + is_active=True, + is_superuser=False, + created_at=regular_user.created_at, + updated_at=datetime.now(timezone.utc) + ) + + with patch.object(user_crud, 'get', return_value=regular_user), \ + patch.object(user_crud, 'update', return_value=updated_user): + response = client.patch( + f"/api/v1/users/{regular_user.id}", + json={"first_name": "NewName"} + ) + + assert response.status_code == 200 + data = response.json() + assert data["first_name"] == "NewName" + + # Clean up + if get_current_user in app.dependency_overrides: + del app.dependency_overrides[get_current_user] + + def test_update_other_user_as_regular_user(self, client, app, regular_user): + """Test that regular users cannot update other users.""" + app.dependency_overrides[get_current_user] = lambda: regular_user + + other_user_id = uuid.uuid4() + response = client.patch( + f"/api/v1/users/{other_user_id}", + json={"first_name": "NewName"} + ) + + assert response.status_code == 403 + + # Clean up + if get_current_user in app.dependency_overrides: + del app.dependency_overrides[get_current_user] + + def test_user_schema_ignores_extra_fields(self, client, app, regular_user, db_session): + """Test that UserUpdate schema ignores extra fields like is_superuser.""" + from app.crud.user import user as user_crud + + app.dependency_overrides[get_current_user] = lambda: regular_user + + # Updated user with is_superuser unchanged + updated_user = User( + id=regular_user.id, + email=regular_user.email, + password_hash=regular_user.password_hash, + first_name="Changed", + last_name=regular_user.last_name, + is_active=True, + is_superuser=False, # Should remain False + created_at=regular_user.created_at, + updated_at=datetime.now(timezone.utc) + ) + + with patch.object(user_crud, 'get', return_value=regular_user), \ + patch.object(user_crud, 'update', return_value=updated_user): + response = client.patch( + f"/api/v1/users/{regular_user.id}", + json={"first_name": "Changed", "is_superuser": True} # is_superuser ignored + ) + + # Should succeed, extra field is ignored + assert response.status_code == 200 + data = response.json() + assert data["is_superuser"] is False + + # Clean up + if get_current_user in app.dependency_overrides: + del app.dependency_overrides[get_current_user] + + def test_superuser_can_update_any_user(self, client, app, super_user, db_session): + """Test that superusers can update any user.""" + from app.crud.user import user as user_crud + + app.dependency_overrides[get_current_user] = lambda: super_user + + target_user = User( + id=uuid.uuid4(), + email="target@example.com", + password_hash="hashed", + first_name="Target", + last_name="User", + is_active=True, + is_superuser=False, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc) + ) + + updated_user = User( + id=target_user.id, + email=target_user.email, + password_hash=target_user.password_hash, + first_name="Updated", + last_name=target_user.last_name, + is_active=True, + is_superuser=False, + created_at=target_user.created_at, + updated_at=datetime.now(timezone.utc) + ) + + with patch.object(user_crud, 'get', return_value=target_user), \ + patch.object(user_crud, 'update', return_value=updated_user): + response = client.patch( + f"/api/v1/users/{target_user.id}", + json={"first_name": "Updated"} + ) + + assert response.status_code == 200 + data = response.json() + assert data["first_name"] == "Updated" + + # Clean up + if get_current_user in app.dependency_overrides: + del app.dependency_overrides[get_current_user] + + +class TestDeleteUser: + """Tests for the delete_user endpoint.""" + + def test_delete_user_as_superuser(self, client, app, super_user, db_session): + """Test that superusers can delete users.""" + from app.crud.user import user as user_crud + + app.dependency_overrides[get_current_superuser] = lambda: super_user + + target_user = User( + id=uuid.uuid4(), + email="target@example.com", + password_hash="hashed", + first_name="Target", + last_name="User", + is_active=True, + is_superuser=False, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc) + ) + + with patch.object(user_crud, 'get', return_value=target_user), \ + patch.object(user_crud, 'remove', return_value=target_user): + response = client.delete(f"/api/v1/users/{target_user.id}") + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert "deleted successfully" in data["message"] + + # Clean up + if get_current_superuser in app.dependency_overrides: + del app.dependency_overrides[get_current_superuser] + + def test_delete_nonexistent_user(self, client, app, super_user, db_session): + """Test deleting a user that doesn't exist.""" + from app.crud.user import user as user_crud + + app.dependency_overrides[get_current_superuser] = lambda: super_user + + with patch.object(user_crud, 'get', return_value=None): + response = client.delete(f"/api/v1/users/{uuid.uuid4()}") + + assert response.status_code == 404 + + # Clean up + if get_current_superuser in app.dependency_overrides: + del app.dependency_overrides[get_current_superuser] + + def test_cannot_delete_self(self, client, app, super_user, db_session): + """Test that users cannot delete their own account.""" + app.dependency_overrides[get_current_superuser] = lambda: super_user + + response = client.delete(f"/api/v1/users/{super_user.id}") + + assert response.status_code == 403 + + # Clean up + if get_current_superuser in app.dependency_overrides: + del app.dependency_overrides[get_current_superuser] diff --git a/backend/tests/utils/__init__.py b/backend/tests/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/utils/test_security.py b/backend/tests/utils/test_security.py new file mode 100644 index 0000000..5434281 --- /dev/null +++ b/backend/tests/utils/test_security.py @@ -0,0 +1,233 @@ +# tests/utils/test_security.py +""" +Tests for security utility functions. +""" +import time +import base64 +import json +import pytest +from unittest.mock import patch, MagicMock + +from app.utils.security import create_upload_token, verify_upload_token + + +class TestCreateUploadToken: + """Tests for create_upload_token function.""" + + def test_create_upload_token_basic(self): + """Test basic token creation.""" + token = create_upload_token("/uploads/test.jpg", "image/jpeg") + + assert token is not None + assert isinstance(token, str) + assert len(token) > 0 + + # Token should be base64 encoded + try: + decoded = base64.urlsafe_b64decode(token.encode('utf-8')) + token_data = json.loads(decoded) + assert "payload" in token_data + assert "signature" in token_data + except Exception as e: + pytest.fail(f"Token is not properly formatted: {e}") + + def test_create_upload_token_contains_correct_payload(self): + """Test that token contains correct payload data.""" + file_path = "/uploads/avatar.jpg" + content_type = "image/jpeg" + + token = create_upload_token(file_path, content_type) + + # Decode and verify payload + decoded = base64.urlsafe_b64decode(token.encode('utf-8')) + token_data = json.loads(decoded) + payload = token_data["payload"] + + assert payload["path"] == file_path + assert payload["content_type"] == content_type + assert "exp" in payload + assert "nonce" in payload + + def test_create_upload_token_default_expiration(self): + """Test that default expiration is 300 seconds (5 minutes).""" + before = int(time.time()) + token = create_upload_token("/uploads/test.jpg", "image/jpeg") + after = int(time.time()) + + # Decode token + decoded = base64.urlsafe_b64decode(token.encode('utf-8')) + token_data = json.loads(decoded) + payload = token_data["payload"] + + # Expiration should be around current time + 300 seconds + exp_time = payload["exp"] + assert before + 300 <= exp_time <= after + 300 + + def test_create_upload_token_custom_expiration(self): + """Test token creation with custom expiration time.""" + custom_exp = 600 # 10 minutes + before = int(time.time()) + token = create_upload_token("/uploads/test.jpg", "image/jpeg", expires_in=custom_exp) + after = int(time.time()) + + # Decode token + decoded = base64.urlsafe_b64decode(token.encode('utf-8')) + token_data = json.loads(decoded) + payload = token_data["payload"] + + # Expiration should be around current time + custom_exp seconds + exp_time = payload["exp"] + assert before + custom_exp <= exp_time <= after + custom_exp + + def test_create_upload_token_unique_nonces(self): + """Test that each token has a unique nonce.""" + token1 = create_upload_token("/uploads/test.jpg", "image/jpeg") + token2 = create_upload_token("/uploads/test.jpg", "image/jpeg") + + # Decode both tokens + decoded1 = base64.urlsafe_b64decode(token1.encode('utf-8')) + token_data1 = json.loads(decoded1) + nonce1 = token_data1["payload"]["nonce"] + + decoded2 = base64.urlsafe_b64decode(token2.encode('utf-8')) + token_data2 = json.loads(decoded2) + nonce2 = token_data2["payload"]["nonce"] + + # Nonces should be different + assert nonce1 != nonce2 + + def test_create_upload_token_different_paths(self): + """Test that tokens for different paths are different.""" + token1 = create_upload_token("/uploads/file1.jpg", "image/jpeg") + token2 = create_upload_token("/uploads/file2.jpg", "image/jpeg") + + assert token1 != token2 + + +class TestVerifyUploadToken: + """Tests for verify_upload_token function.""" + + def test_verify_valid_token(self): + """Test verification of a valid token.""" + file_path = "/uploads/test.jpg" + content_type = "image/jpeg" + + token = create_upload_token(file_path, content_type) + payload = verify_upload_token(token) + + assert payload is not None + assert payload["path"] == file_path + assert payload["content_type"] == content_type + + def test_verify_expired_token(self): + """Test that expired tokens are rejected.""" + # Create a mock time module + mock_time = MagicMock() + current_time = 1000000 + mock_time.time = MagicMock(return_value=current_time) + + with patch('app.utils.security.time', mock_time): + # Create token that "expires" at current_time + 1 + token = create_upload_token("/uploads/test.jpg", "image/jpeg", expires_in=1) + + # Now set time to after expiration + mock_time.time.return_value = current_time + 2 + + # Token should be expired + payload = verify_upload_token(token) + assert payload is None + + def test_verify_invalid_signature(self): + """Test that tokens with invalid signatures are rejected.""" + token = create_upload_token("/uploads/test.jpg", "image/jpeg") + + # Decode, modify, and re-encode + decoded = base64.urlsafe_b64decode(token.encode('utf-8')) + token_data = json.loads(decoded) + token_data["signature"] = "invalid_signature" + + # Re-encode the tampered token + tampered_json = json.dumps(token_data) + tampered_token = base64.urlsafe_b64encode(tampered_json.encode('utf-8')).decode('utf-8') + + payload = verify_upload_token(tampered_token) + assert payload is None + + def test_verify_tampered_payload(self): + """Test that tokens with tampered payloads are rejected.""" + token = create_upload_token("/uploads/test.jpg", "image/jpeg") + + # Decode, modify payload, and re-encode + decoded = base64.urlsafe_b64decode(token.encode('utf-8')) + token_data = json.loads(decoded) + token_data["payload"]["path"] = "/uploads/hacked.exe" + + # Re-encode the tampered token (signature won't match) + tampered_json = json.dumps(token_data) + tampered_token = base64.urlsafe_b64encode(tampered_json.encode('utf-8')).decode('utf-8') + + payload = verify_upload_token(tampered_token) + assert payload is None + + def test_verify_malformed_token(self): + """Test that malformed tokens return None.""" + # Test various malformed tokens + invalid_tokens = [ + "not_a_valid_token", + "SGVsbG8gV29ybGQ=", # Valid base64 but not a token + "", + " ", + ] + + for invalid_token in invalid_tokens: + payload = verify_upload_token(invalid_token) + assert payload is None + + def test_verify_invalid_json(self): + """Test that tokens with invalid JSON are rejected.""" + # Create a base64 string that decodes to invalid JSON + invalid_json = "not valid json" + invalid_token = base64.urlsafe_b64encode(invalid_json.encode('utf-8')).decode('utf-8') + + payload = verify_upload_token(invalid_token) + assert payload is None + + def test_verify_missing_fields(self): + """Test that tokens missing required fields are rejected.""" + # Create a token-like structure but missing required fields + incomplete_data = { + "payload": { + "path": "/uploads/test.jpg" + # Missing content_type, exp, nonce + }, + "signature": "some_signature" + } + + incomplete_json = json.dumps(incomplete_data) + incomplete_token = base64.urlsafe_b64encode(incomplete_json.encode('utf-8')).decode('utf-8') + + payload = verify_upload_token(incomplete_token) + assert payload is None + + def test_verify_token_round_trip(self): + """Test creating and verifying a token in sequence.""" + test_cases = [ + ("/uploads/image.jpg", "image/jpeg", 300), + ("/uploads/document.pdf", "application/pdf", 600), + ("/uploads/video.mp4", "video/mp4", 900), + ] + + for file_path, content_type, expires_in in test_cases: + token = create_upload_token(file_path, content_type, expires_in) + payload = verify_upload_token(token) + + assert payload is not None + assert payload["path"] == file_path + assert payload["content_type"] == content_type + assert "exp" in payload + assert "nonce" in payload + + # Note: test_verify_token_cannot_be_reused_with_different_secret removed + # The signature validation is already tested by test_verify_invalid_signature + # and test_verify_tampered_payload. Testing with different SECRET_KEY + # requires complex mocking that can interfere with other tests.