Add async CRUD base, async database configuration, soft delete for users, and composite indexes

- Introduced `CRUDBaseAsync` for reusable async operations.
- Configured async database connection using SQLAlchemy 2.0 patterns with `asyncpg`.
- Added `deleted_at` column and soft delete functionality to the `User` model, including related Alembic migration.
- Optimized queries by adding composite indexes for common user filtering scenarios.
- Extended tests: added cases for token-based security utilities and user management endpoints.
This commit is contained in:
Felipe Cardoso
2025-10-30 16:45:01 +01:00
parent c684f2ba95
commit 313e6691b5
9 changed files with 1251 additions and 10 deletions

View File

@@ -0,0 +1,34 @@
"""add_soft_delete_to_users
Revision ID: 2d0fcec3b06d
Revises: 9e4f2a1b8c7d
Create Date: 2025-10-30 16:40:21.000021
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '2d0fcec3b06d'
down_revision: Union[str, None] = '9e4f2a1b8c7d'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add deleted_at column for soft deletes
op.add_column('users', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True))
# Add index on deleted_at for efficient queries
op.create_index('ix_users_deleted_at', 'users', ['deleted_at'])
def downgrade() -> None:
# Remove index
op.drop_index('ix_users_deleted_at', table_name='users')
# Remove column
op.drop_column('users', 'deleted_at')

View File

@@ -0,0 +1,52 @@
"""add_composite_indexes
Revision ID: b76c725fc3cf
Revises: 2d0fcec3b06d
Create Date: 2025-10-30 16:41:33.273135
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'b76c725fc3cf'
down_revision: Union[str, None] = '2d0fcec3b06d'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add composite indexes for common query patterns
# Composite index for filtering active users by role
op.create_index(
'ix_users_active_superuser',
'users',
['is_active', 'is_superuser'],
postgresql_where=sa.text('deleted_at IS NULL')
)
# Composite index for sorting active users by creation date
op.create_index(
'ix_users_active_created',
'users',
['is_active', 'created_at'],
postgresql_where=sa.text('deleted_at IS NULL')
)
# Composite index for email lookup of non-deleted users
op.create_index(
'ix_users_email_not_deleted',
'users',
['email', 'deleted_at']
)
def downgrade() -> None:
# Remove composite indexes
op.drop_index('ix_users_email_not_deleted', table_name='users')
op.drop_index('ix_users_active_created', table_name='users')
op.drop_index('ix_users_active_superuser', table_name='users')

View File

@@ -2,7 +2,7 @@
User management endpoints for CRUD operations.
"""
import logging
from typing import Any
from typing import Any, Optional
from uuid import UUID
from fastapi import APIRouter, Depends, Query, status, Request
@@ -19,9 +19,10 @@ from app.schemas.common import (
PaginationParams,
PaginatedResponse,
MessageResponse,
SortParams,
create_pagination_meta
)
from app.services.auth_service import AuthService
from app.services.auth_service import AuthService, AuthenticationError
from app.core.exceptions import (
NotFoundError,
AuthorizationError,
@@ -39,31 +40,47 @@ limiter = Limiter(key_func=get_remote_address)
response_model=PaginatedResponse[UserResponse],
summary="List Users",
description="""
List all users with pagination (admin only).
List all users with pagination, filtering, and sorting (admin only).
**Authentication**: Required (Bearer token)
**Authorization**: Superuser only
**Filtering**: is_active, is_superuser
**Sorting**: Any user field (email, first_name, last_name, created_at, etc.)
**Rate Limit**: 60 requests/minute
""",
operation_id="list_users"
)
def list_users(
pagination: PaginationParams = Depends(),
sort: SortParams = Depends(),
is_active: Optional[bool] = Query(None, description="Filter by active status"),
is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"),
current_user: User = Depends(get_current_superuser),
db: Session = Depends(get_db)
) -> Any:
"""
List all users with pagination.
List all users with pagination, filtering, and sorting.
Only accessible by superusers.
"""
try:
# Build filters
filters = {}
if is_active is not None:
filters["is_active"] = is_active
if is_superuser is not None:
filters["is_superuser"] = is_superuser
# Get paginated users with total count
users, total = user_crud.get_multi_with_total(
db,
skip=pagination.offset,
limit=pagination.limit
limit=pagination.limit,
sort_by=sort.sort_by,
sort_order=sort.sort_order.value if sort.sort_order else "asc",
filters=filters if filters else None
)
# Create pagination metadata
@@ -129,7 +146,7 @@ def update_current_user(
Users cannot elevate their own permissions (is_superuser).
"""
# Prevent users from making themselves superuser
if user_update.is_superuser is not None:
if getattr(user_update, 'is_superuser', None) is not None:
logger.warning(f"User {current_user.id} attempted to modify is_superuser field")
raise AuthorizationError(
message="Cannot modify superuser status",
@@ -248,7 +265,7 @@ def update_user(
)
# Prevent non-superusers from modifying superuser status
if user_update.is_superuser is not None and not current_user.is_superuser:
if getattr(user_update, 'is_superuser', None) is not None and not current_user.is_superuser:
logger.warning(f"User {current_user.id} attempted to modify is_superuser field")
raise AuthorizationError(
message="Cannot modify superuser status",
@@ -308,6 +325,12 @@ def change_current_user_password(
success=True,
message="Password changed successfully"
)
except AuthenticationError as e:
logger.warning(f"Failed password change attempt for user {current_user.id}: {str(e)}")
raise AuthorizationError(
message=str(e),
error_code=ErrorCode.INVALID_CREDENTIALS
)
except Exception as e:
logger.error(f"Error changing password for user {current_user.id}: {str(e)}")
raise
@@ -356,8 +379,9 @@ def delete_user(
)
try:
user_crud.remove(db, id=str(user_id))
logger.info(f"User {user_id} deleted by {current_user.id}")
# Use soft delete instead of hard delete
user_crud.soft_delete(db, id=str(user_id))
logger.info(f"User {user_id} soft-deleted by {current_user.id}")
return MessageResponse(
success=True,
message=f"User {user_id} deleted successfully"

View File

@@ -0,0 +1,182 @@
# app/core/database_async.py
"""
Async database configuration using SQLAlchemy 2.0 and asyncpg.
This module provides async database connectivity with proper connection pooling
and session management for FastAPI endpoints.
"""
import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from sqlalchemy import text
from sqlalchemy.ext.asyncio import (
AsyncSession,
AsyncEngine,
create_async_engine,
async_sessionmaker,
)
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import DeclarativeBase
from app.core.config import settings
# Configure logging
logger = logging.getLogger(__name__)
# SQLite compatibility for testing
@compiles(JSONB, 'sqlite')
def compile_jsonb_sqlite(type_, compiler, **kw):
return "TEXT"
@compiles(UUID, 'sqlite')
def compile_uuid_sqlite(type_, compiler, **kw):
return "TEXT"
# Declarative base for models (SQLAlchemy 2.0 style)
class Base(DeclarativeBase):
"""Base class for all database models."""
pass
def get_async_database_url(url: str) -> str:
"""
Convert sync database URL to async URL.
postgresql:// -> postgresql+asyncpg://
sqlite:// -> sqlite+aiosqlite://
"""
if url.startswith("postgresql://"):
return url.replace("postgresql://", "postgresql+asyncpg://")
elif url.startswith("sqlite://"):
return url.replace("sqlite://", "sqlite+aiosqlite://")
return url
# Create async engine with optimized settings
def create_async_production_engine() -> AsyncEngine:
"""Create an async database engine with production settings."""
async_url = get_async_database_url(settings.database_url)
# Base engine config
engine_config = {
"pool_size": settings.db_pool_size,
"max_overflow": settings.db_max_overflow,
"pool_timeout": settings.db_pool_timeout,
"pool_recycle": settings.db_pool_recycle,
"pool_pre_ping": True,
"echo": settings.sql_echo,
"echo_pool": settings.sql_echo_pool,
}
# Add PostgreSQL-specific connect_args
if "postgresql" in async_url:
engine_config["connect_args"] = {
"server_settings": {
"application_name": "eventspace",
"timezone": "UTC",
},
# asyncpg-specific settings
"command_timeout": 60,
"timeout": 10,
}
return create_async_engine(async_url, **engine_config)
# Create async engine and session factory
async_engine = create_async_production_engine()
AsyncSessionLocal = async_sessionmaker(
async_engine,
class_=AsyncSession,
autocommit=False,
autoflush=False,
expire_on_commit=False, # Prevent unnecessary queries after commit
)
# FastAPI dependency for async database sessions
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
"""
FastAPI dependency that provides an async database session.
Automatically closes the session after the request completes.
Usage:
@router.get("/users")
async def get_users(db: AsyncSession = Depends(get_async_db)):
result = await db.execute(select(User))
return result.scalars().all()
"""
async with AsyncSessionLocal() as session:
try:
yield session
finally:
await session.close()
@asynccontextmanager
async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]:
"""
Provide an async transactional scope for database operations.
Automatically commits on success or rolls back on exception.
Useful for grouping multiple operations in a single transaction.
Usage:
async with async_transaction_scope() as db:
user = await user_crud.create(db, obj_in=user_create)
profile = await profile_crud.create(db, obj_in=profile_create)
# Both operations committed together
"""
async with AsyncSessionLocal() as session:
try:
yield session
await session.commit()
logger.debug("Async transaction committed successfully")
except Exception as e:
await session.rollback()
logger.error(f"Async transaction failed, rolling back: {str(e)}")
raise
finally:
await session.close()
async def check_async_database_health() -> bool:
"""
Check if async database connection is healthy.
Returns True if connection is successful, False otherwise.
"""
try:
async with async_transaction_scope() as db:
await db.execute(text("SELECT 1"))
return True
except Exception as e:
logger.error(f"Async database health check failed: {str(e)}")
return False
async def init_async_db() -> None:
"""
Initialize async database tables.
This creates all tables defined in the models.
Should only be used in development or testing.
In production, use Alembic migrations.
"""
async with async_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
logger.info("Async database tables created")
async def close_async_db() -> None:
"""
Close all async database connections.
Should be called during application shutdown.
"""
await async_engine.dispose()
logger.info("Async database connections closed")

View File

@@ -0,0 +1,228 @@
# app/crud/base_async.py
"""
Async CRUD operations base class using SQLAlchemy 2.0 async patterns.
Provides reusable create, read, update, and delete operations for all models.
"""
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple
import logging
import uuid
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
from app.core.database_async import Base
logger = logging.getLogger(__name__)
ModelType = TypeVar("ModelType", bound=Base)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
"""Async CRUD operations for a model."""
def __init__(self, model: Type[ModelType]):
"""
CRUD object with default async methods to Create, Read, Update, Delete.
Parameters:
model: A SQLAlchemy model class
"""
self.model = model
async def get(self, db: AsyncSession, id: str) -> Optional[ModelType]:
"""Get a single record by ID with UUID validation."""
# Validate UUID format and convert to UUID object if string
try:
if isinstance(id, uuid.UUID):
uuid_obj = id
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format: {id} - {str(e)}")
return None
try:
result = await db.execute(
select(self.model).where(self.model.id == uuid_obj)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}")
raise
async def get_multi(
self, db: AsyncSession, *, skip: int = 0, limit: int = 100
) -> List[ModelType]:
"""Get multiple records with pagination validation."""
# Validate pagination parameters
if skip < 0:
raise ValueError("skip must be non-negative")
if limit < 0:
raise ValueError("limit must be non-negative")
if limit > 1000:
raise ValueError("Maximum limit is 1000")
try:
result = await db.execute(
select(self.model).offset(skip).limit(limit)
)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error retrieving multiple {self.model.__name__} records: {str(e)}")
raise
async def create(self, db: AsyncSession, *, obj_in: CreateSchemaType) -> ModelType:
"""Create a new record with error handling."""
try:
obj_in_data = jsonable_encoder(obj_in)
db_obj = self.model(**obj_in_data)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
raise ValueError(f"A {self.model.__name__} with this data already exists")
logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except (OperationalError, DataError) as e:
await db.rollback()
logger.error(f"Database error creating {self.model.__name__}: {str(e)}")
raise ValueError(f"Database operation failed: {str(e)}")
except Exception as e:
await db.rollback()
logger.error(f"Unexpected error creating {self.model.__name__}: {str(e)}", exc_info=True)
raise
async def update(
self,
db: AsyncSession,
*,
db_obj: ModelType,
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
) -> ModelType:
"""Update a record with error handling."""
try:
obj_data = jsonable_encoder(db_obj)
if isinstance(obj_in, dict):
update_data = obj_in
else:
update_data = obj_in.model_dump(exclude_unset=True)
for field in obj_data:
if field in update_data:
setattr(db_obj, field, update_data[field])
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
raise ValueError(f"A {self.model.__name__} with this data already exists")
logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except (OperationalError, DataError) as e:
await db.rollback()
logger.error(f"Database error updating {self.model.__name__}: {str(e)}")
raise ValueError(f"Database operation failed: {str(e)}")
except Exception as e:
await db.rollback()
logger.error(f"Unexpected error updating {self.model.__name__}: {str(e)}", exc_info=True)
raise
async def remove(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
"""Delete a record with error handling and null check."""
# Validate UUID format and convert to UUID object if string
try:
if isinstance(id, uuid.UUID):
uuid_obj = id
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for deletion: {id} - {str(e)}")
return None
try:
result = await db.execute(
select(self.model).where(self.model.id == uuid_obj)
)
obj = result.scalar_one_or_none()
if obj is None:
logger.warning(f"{self.model.__name__} with id {id} not found for deletion")
return None
await db.delete(obj)
await db.commit()
return obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}")
raise ValueError(f"Cannot delete {self.model.__name__}: referenced by other records")
except Exception as e:
await db.rollback()
logger.error(f"Error deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
raise
async def get_multi_with_total(
self, db: AsyncSession, *, skip: int = 0, limit: int = 100
) -> Tuple[List[ModelType], int]:
"""
Get multiple records with total count for pagination.
Returns:
Tuple of (items, total_count)
"""
# Validate pagination parameters
if skip < 0:
raise ValueError("skip must be non-negative")
if limit < 0:
raise ValueError("limit must be non-negative")
if limit > 1000:
raise ValueError("Maximum limit is 1000")
try:
# Get total count
count_result = await db.execute(
select(func.count(self.model.id))
)
total = count_result.scalar_one()
# Get paginated items
items_result = await db.execute(
select(self.model).offset(skip).limit(limit)
)
items = list(items_result.scalars().all())
return items, total
except Exception as e:
logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}")
raise
async def count(self, db: AsyncSession) -> int:
"""Get total count of records."""
try:
result = await db.execute(select(func.count(self.model.id)))
return result.scalar_one()
except Exception as e:
logger.error(f"Error counting {self.model.__name__} records: {str(e)}")
raise
async def exists(self, db: AsyncSession, id: str) -> bool:
"""Check if a record exists by ID."""
obj = await self.get(db, id=id)
return obj is not None

View File

@@ -1,4 +1,4 @@
from sqlalchemy import Column, String, Boolean
from sqlalchemy import Column, String, Boolean, DateTime
from sqlalchemy.dialects.postgresql import JSONB
from .base import Base, TimestampMixin, UUIDMixin
@@ -15,6 +15,7 @@ class User(Base, UUIDMixin, TimestampMixin):
is_active = Column(Boolean, default=True, nullable=False, index=True)
is_superuser = Column(Boolean, default=False, nullable=False, index=True)
preferences = Column(JSONB)
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
def __repr__(self):
return f"<User {self.email}>"

View File

@@ -0,0 +1,487 @@
# tests/api/routes/test_users.py
"""
Tests for user management endpoints.
"""
import uuid
from datetime import datetime, timezone
from unittest.mock import patch, MagicMock
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session
from app.api.routes.users import router as users_router
from app.core.database import get_db
from app.models.user import User
from app.api.dependencies.auth import get_current_user, get_current_superuser
@pytest.fixture
def override_get_db(db_session):
"""Override get_db dependency for testing."""
return db_session
@pytest.fixture
def app(override_get_db):
"""Create a FastAPI test application."""
app = FastAPI()
app.include_router(users_router, prefix="/api/v1/users", tags=["users"])
# Override the get_db dependency
app.dependency_overrides[get_db] = lambda: override_get_db
return app
@pytest.fixture
def client(app):
"""Create a FastAPI test client."""
return TestClient(app)
@pytest.fixture
def regular_user():
"""Create a mock regular user."""
return User(
id=uuid.uuid4(),
email="regular@example.com",
password_hash="hashed_password",
first_name="Regular",
last_name="User",
is_active=True,
is_superuser=False,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc)
)
@pytest.fixture
def super_user():
"""Create a mock superuser."""
return User(
id=uuid.uuid4(),
email="admin@example.com",
password_hash="hashed_password",
first_name="Admin",
last_name="User",
is_active=True,
is_superuser=True,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc)
)
class TestListUsers:
"""Tests for the list_users endpoint."""
def test_list_users_as_superuser(self, client, app, super_user, regular_user, db_session):
"""Test that superusers can list all users."""
from app.crud.user import user as user_crud
# Override auth dependency
app.dependency_overrides[get_current_superuser] = lambda: super_user
# Mock user_crud to return test data
mock_users = [regular_user for _ in range(3)]
with patch.object(user_crud, 'get_multi_with_total', return_value=(mock_users, 3)):
response = client.get("/api/v1/users?page=1&limit=20")
assert response.status_code == 200
data = response.json()
assert "data" in data
assert "pagination" in data
assert len(data["data"]) == 3
assert data["pagination"]["total"] == 3
# Clean up
if get_current_superuser in app.dependency_overrides:
del app.dependency_overrides[get_current_superuser]
def test_list_users_pagination(self, client, app, super_user, regular_user, db_session):
"""Test pagination parameters for list users."""
from app.crud.user import user as user_crud
app.dependency_overrides[get_current_superuser] = lambda: super_user
# Mock user_crud
mock_users = [regular_user for _ in range(10)]
with patch.object(user_crud, 'get_multi_with_total', return_value=(mock_users[:5], 10)):
response = client.get("/api/v1/users?page=1&limit=5")
assert response.status_code == 200
data = response.json()
assert data["pagination"]["page"] == 1
assert data["pagination"]["page_size"] == 5
assert data["pagination"]["total"] == 10
assert data["pagination"]["total_pages"] == 2
# Clean up
if get_current_superuser in app.dependency_overrides:
del app.dependency_overrides[get_current_superuser]
class TestGetCurrentUserProfile:
"""Tests for the get_current_user_profile endpoint."""
def test_get_current_user_profile(self, client, app, regular_user):
"""Test getting current user's profile."""
app.dependency_overrides[get_current_user] = lambda: regular_user
response = client.get("/api/v1/users/me")
assert response.status_code == 200
data = response.json()
assert data["email"] == regular_user.email
assert data["first_name"] == regular_user.first_name
assert data["last_name"] == regular_user.last_name
assert "password" not in data
# Clean up
if get_current_user in app.dependency_overrides:
del app.dependency_overrides[get_current_user]
class TestUpdateCurrentUser:
"""Tests for the update_current_user endpoint."""
def test_update_current_user_success(self, client, app, regular_user, db_session):
"""Test successful profile update."""
from app.crud.user import user as user_crud
app.dependency_overrides[get_current_user] = lambda: regular_user
updated_user = User(
id=regular_user.id,
email=regular_user.email,
password_hash=regular_user.password_hash,
first_name="Updated",
last_name="Name",
is_active=True,
is_superuser=False,
created_at=regular_user.created_at,
updated_at=datetime.now(timezone.utc)
)
with patch.object(user_crud, 'update', return_value=updated_user):
response = client.patch(
"/api/v1/users/me",
json={"first_name": "Updated", "last_name": "Name"}
)
assert response.status_code == 200
data = response.json()
assert data["first_name"] == "Updated"
assert data["last_name"] == "Name"
# Clean up
if get_current_user in app.dependency_overrides:
del app.dependency_overrides[get_current_user]
def test_update_current_user_extra_fields_ignored(self, client, app, regular_user, db_session):
"""Test that extra fields like is_superuser are ignored by schema validation."""
from app.crud.user import user as user_crud
app.dependency_overrides[get_current_user] = lambda: regular_user
# Create updated user without is_superuser changed
updated_user = User(
id=regular_user.id,
email=regular_user.email,
password_hash=regular_user.password_hash,
first_name="Updated",
last_name=regular_user.last_name,
is_active=True,
is_superuser=False, # Should remain False
created_at=regular_user.created_at,
updated_at=datetime.now(timezone.utc)
)
with patch.object(user_crud, 'update', return_value=updated_user):
response = client.patch(
"/api/v1/users/me",
json={"first_name": "Updated", "is_superuser": True} # is_superuser will be ignored
)
# Request should succeed but is_superuser should be unchanged
assert response.status_code == 200
data = response.json()
assert data["is_superuser"] is False
# Clean up
if get_current_user in app.dependency_overrides:
del app.dependency_overrides[get_current_user]
class TestGetUserById:
"""Tests for the get_user_by_id endpoint."""
def test_get_own_profile(self, client, app, regular_user, db_session):
"""Test that users can get their own profile."""
from app.crud.user import user as user_crud
app.dependency_overrides[get_current_user] = lambda: regular_user
with patch.object(user_crud, 'get', return_value=regular_user):
response = client.get(f"/api/v1/users/{regular_user.id}")
assert response.status_code == 200
data = response.json()
assert data["email"] == regular_user.email
# Clean up
if get_current_user in app.dependency_overrides:
del app.dependency_overrides[get_current_user]
def test_get_other_user_as_regular_user(self, client, app, regular_user):
"""Test that regular users cannot view other users."""
app.dependency_overrides[get_current_user] = lambda: regular_user
other_user_id = uuid.uuid4()
response = client.get(f"/api/v1/users/{other_user_id}")
assert response.status_code == 403
# Clean up
if get_current_user in app.dependency_overrides:
del app.dependency_overrides[get_current_user]
def test_get_other_user_as_superuser(self, client, app, super_user, db_session):
"""Test that superusers can view any user."""
from app.crud.user import user as user_crud
app.dependency_overrides[get_current_user] = lambda: super_user
other_user = User(
id=uuid.uuid4(),
email="other@example.com",
password_hash="hashed",
first_name="Other",
last_name="User",
is_active=True,
is_superuser=False,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc)
)
with patch.object(user_crud, 'get', return_value=other_user):
response = client.get(f"/api/v1/users/{other_user.id}")
assert response.status_code == 200
data = response.json()
assert data["email"] == other_user.email
# Clean up
if get_current_user in app.dependency_overrides:
del app.dependency_overrides[get_current_user]
def test_get_nonexistent_user(self, client, app, super_user, db_session):
"""Test getting a user that doesn't exist."""
from app.crud.user import user as user_crud
app.dependency_overrides[get_current_user] = lambda: super_user
with patch.object(user_crud, 'get', return_value=None):
response = client.get(f"/api/v1/users/{uuid.uuid4()}")
assert response.status_code == 404
# Clean up
if get_current_user in app.dependency_overrides:
del app.dependency_overrides[get_current_user]
class TestUpdateUser:
"""Tests for the update_user endpoint."""
def test_update_own_profile(self, client, app, regular_user, db_session):
"""Test that users can update their own profile."""
from app.crud.user import user as user_crud
app.dependency_overrides[get_current_user] = lambda: regular_user
updated_user = User(
id=regular_user.id,
email=regular_user.email,
password_hash=regular_user.password_hash,
first_name="NewName",
last_name=regular_user.last_name,
is_active=True,
is_superuser=False,
created_at=regular_user.created_at,
updated_at=datetime.now(timezone.utc)
)
with patch.object(user_crud, 'get', return_value=regular_user), \
patch.object(user_crud, 'update', return_value=updated_user):
response = client.patch(
f"/api/v1/users/{regular_user.id}",
json={"first_name": "NewName"}
)
assert response.status_code == 200
data = response.json()
assert data["first_name"] == "NewName"
# Clean up
if get_current_user in app.dependency_overrides:
del app.dependency_overrides[get_current_user]
def test_update_other_user_as_regular_user(self, client, app, regular_user):
"""Test that regular users cannot update other users."""
app.dependency_overrides[get_current_user] = lambda: regular_user
other_user_id = uuid.uuid4()
response = client.patch(
f"/api/v1/users/{other_user_id}",
json={"first_name": "NewName"}
)
assert response.status_code == 403
# Clean up
if get_current_user in app.dependency_overrides:
del app.dependency_overrides[get_current_user]
def test_user_schema_ignores_extra_fields(self, client, app, regular_user, db_session):
"""Test that UserUpdate schema ignores extra fields like is_superuser."""
from app.crud.user import user as user_crud
app.dependency_overrides[get_current_user] = lambda: regular_user
# Updated user with is_superuser unchanged
updated_user = User(
id=regular_user.id,
email=regular_user.email,
password_hash=regular_user.password_hash,
first_name="Changed",
last_name=regular_user.last_name,
is_active=True,
is_superuser=False, # Should remain False
created_at=regular_user.created_at,
updated_at=datetime.now(timezone.utc)
)
with patch.object(user_crud, 'get', return_value=regular_user), \
patch.object(user_crud, 'update', return_value=updated_user):
response = client.patch(
f"/api/v1/users/{regular_user.id}",
json={"first_name": "Changed", "is_superuser": True} # is_superuser ignored
)
# Should succeed, extra field is ignored
assert response.status_code == 200
data = response.json()
assert data["is_superuser"] is False
# Clean up
if get_current_user in app.dependency_overrides:
del app.dependency_overrides[get_current_user]
def test_superuser_can_update_any_user(self, client, app, super_user, db_session):
"""Test that superusers can update any user."""
from app.crud.user import user as user_crud
app.dependency_overrides[get_current_user] = lambda: super_user
target_user = User(
id=uuid.uuid4(),
email="target@example.com",
password_hash="hashed",
first_name="Target",
last_name="User",
is_active=True,
is_superuser=False,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc)
)
updated_user = User(
id=target_user.id,
email=target_user.email,
password_hash=target_user.password_hash,
first_name="Updated",
last_name=target_user.last_name,
is_active=True,
is_superuser=False,
created_at=target_user.created_at,
updated_at=datetime.now(timezone.utc)
)
with patch.object(user_crud, 'get', return_value=target_user), \
patch.object(user_crud, 'update', return_value=updated_user):
response = client.patch(
f"/api/v1/users/{target_user.id}",
json={"first_name": "Updated"}
)
assert response.status_code == 200
data = response.json()
assert data["first_name"] == "Updated"
# Clean up
if get_current_user in app.dependency_overrides:
del app.dependency_overrides[get_current_user]
class TestDeleteUser:
"""Tests for the delete_user endpoint."""
def test_delete_user_as_superuser(self, client, app, super_user, db_session):
"""Test that superusers can delete users."""
from app.crud.user import user as user_crud
app.dependency_overrides[get_current_superuser] = lambda: super_user
target_user = User(
id=uuid.uuid4(),
email="target@example.com",
password_hash="hashed",
first_name="Target",
last_name="User",
is_active=True,
is_superuser=False,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc)
)
with patch.object(user_crud, 'get', return_value=target_user), \
patch.object(user_crud, 'remove', return_value=target_user):
response = client.delete(f"/api/v1/users/{target_user.id}")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert "deleted successfully" in data["message"]
# Clean up
if get_current_superuser in app.dependency_overrides:
del app.dependency_overrides[get_current_superuser]
def test_delete_nonexistent_user(self, client, app, super_user, db_session):
"""Test deleting a user that doesn't exist."""
from app.crud.user import user as user_crud
app.dependency_overrides[get_current_superuser] = lambda: super_user
with patch.object(user_crud, 'get', return_value=None):
response = client.delete(f"/api/v1/users/{uuid.uuid4()}")
assert response.status_code == 404
# Clean up
if get_current_superuser in app.dependency_overrides:
del app.dependency_overrides[get_current_superuser]
def test_cannot_delete_self(self, client, app, super_user, db_session):
"""Test that users cannot delete their own account."""
app.dependency_overrides[get_current_superuser] = lambda: super_user
response = client.delete(f"/api/v1/users/{super_user.id}")
assert response.status_code == 403
# Clean up
if get_current_superuser in app.dependency_overrides:
del app.dependency_overrides[get_current_superuser]

View File

View File

@@ -0,0 +1,233 @@
# tests/utils/test_security.py
"""
Tests for security utility functions.
"""
import time
import base64
import json
import pytest
from unittest.mock import patch, MagicMock
from app.utils.security import create_upload_token, verify_upload_token
class TestCreateUploadToken:
"""Tests for create_upload_token function."""
def test_create_upload_token_basic(self):
"""Test basic token creation."""
token = create_upload_token("/uploads/test.jpg", "image/jpeg")
assert token is not None
assert isinstance(token, str)
assert len(token) > 0
# Token should be base64 encoded
try:
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
token_data = json.loads(decoded)
assert "payload" in token_data
assert "signature" in token_data
except Exception as e:
pytest.fail(f"Token is not properly formatted: {e}")
def test_create_upload_token_contains_correct_payload(self):
"""Test that token contains correct payload data."""
file_path = "/uploads/avatar.jpg"
content_type = "image/jpeg"
token = create_upload_token(file_path, content_type)
# Decode and verify payload
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
token_data = json.loads(decoded)
payload = token_data["payload"]
assert payload["path"] == file_path
assert payload["content_type"] == content_type
assert "exp" in payload
assert "nonce" in payload
def test_create_upload_token_default_expiration(self):
"""Test that default expiration is 300 seconds (5 minutes)."""
before = int(time.time())
token = create_upload_token("/uploads/test.jpg", "image/jpeg")
after = int(time.time())
# Decode token
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
token_data = json.loads(decoded)
payload = token_data["payload"]
# Expiration should be around current time + 300 seconds
exp_time = payload["exp"]
assert before + 300 <= exp_time <= after + 300
def test_create_upload_token_custom_expiration(self):
"""Test token creation with custom expiration time."""
custom_exp = 600 # 10 minutes
before = int(time.time())
token = create_upload_token("/uploads/test.jpg", "image/jpeg", expires_in=custom_exp)
after = int(time.time())
# Decode token
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
token_data = json.loads(decoded)
payload = token_data["payload"]
# Expiration should be around current time + custom_exp seconds
exp_time = payload["exp"]
assert before + custom_exp <= exp_time <= after + custom_exp
def test_create_upload_token_unique_nonces(self):
"""Test that each token has a unique nonce."""
token1 = create_upload_token("/uploads/test.jpg", "image/jpeg")
token2 = create_upload_token("/uploads/test.jpg", "image/jpeg")
# Decode both tokens
decoded1 = base64.urlsafe_b64decode(token1.encode('utf-8'))
token_data1 = json.loads(decoded1)
nonce1 = token_data1["payload"]["nonce"]
decoded2 = base64.urlsafe_b64decode(token2.encode('utf-8'))
token_data2 = json.loads(decoded2)
nonce2 = token_data2["payload"]["nonce"]
# Nonces should be different
assert nonce1 != nonce2
def test_create_upload_token_different_paths(self):
"""Test that tokens for different paths are different."""
token1 = create_upload_token("/uploads/file1.jpg", "image/jpeg")
token2 = create_upload_token("/uploads/file2.jpg", "image/jpeg")
assert token1 != token2
class TestVerifyUploadToken:
"""Tests for verify_upload_token function."""
def test_verify_valid_token(self):
"""Test verification of a valid token."""
file_path = "/uploads/test.jpg"
content_type = "image/jpeg"
token = create_upload_token(file_path, content_type)
payload = verify_upload_token(token)
assert payload is not None
assert payload["path"] == file_path
assert payload["content_type"] == content_type
def test_verify_expired_token(self):
"""Test that expired tokens are rejected."""
# Create a mock time module
mock_time = MagicMock()
current_time = 1000000
mock_time.time = MagicMock(return_value=current_time)
with patch('app.utils.security.time', mock_time):
# Create token that "expires" at current_time + 1
token = create_upload_token("/uploads/test.jpg", "image/jpeg", expires_in=1)
# Now set time to after expiration
mock_time.time.return_value = current_time + 2
# Token should be expired
payload = verify_upload_token(token)
assert payload is None
def test_verify_invalid_signature(self):
"""Test that tokens with invalid signatures are rejected."""
token = create_upload_token("/uploads/test.jpg", "image/jpeg")
# Decode, modify, and re-encode
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
token_data = json.loads(decoded)
token_data["signature"] = "invalid_signature"
# Re-encode the tampered token
tampered_json = json.dumps(token_data)
tampered_token = base64.urlsafe_b64encode(tampered_json.encode('utf-8')).decode('utf-8')
payload = verify_upload_token(tampered_token)
assert payload is None
def test_verify_tampered_payload(self):
"""Test that tokens with tampered payloads are rejected."""
token = create_upload_token("/uploads/test.jpg", "image/jpeg")
# Decode, modify payload, and re-encode
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
token_data = json.loads(decoded)
token_data["payload"]["path"] = "/uploads/hacked.exe"
# Re-encode the tampered token (signature won't match)
tampered_json = json.dumps(token_data)
tampered_token = base64.urlsafe_b64encode(tampered_json.encode('utf-8')).decode('utf-8')
payload = verify_upload_token(tampered_token)
assert payload is None
def test_verify_malformed_token(self):
"""Test that malformed tokens return None."""
# Test various malformed tokens
invalid_tokens = [
"not_a_valid_token",
"SGVsbG8gV29ybGQ=", # Valid base64 but not a token
"",
" ",
]
for invalid_token in invalid_tokens:
payload = verify_upload_token(invalid_token)
assert payload is None
def test_verify_invalid_json(self):
"""Test that tokens with invalid JSON are rejected."""
# Create a base64 string that decodes to invalid JSON
invalid_json = "not valid json"
invalid_token = base64.urlsafe_b64encode(invalid_json.encode('utf-8')).decode('utf-8')
payload = verify_upload_token(invalid_token)
assert payload is None
def test_verify_missing_fields(self):
"""Test that tokens missing required fields are rejected."""
# Create a token-like structure but missing required fields
incomplete_data = {
"payload": {
"path": "/uploads/test.jpg"
# Missing content_type, exp, nonce
},
"signature": "some_signature"
}
incomplete_json = json.dumps(incomplete_data)
incomplete_token = base64.urlsafe_b64encode(incomplete_json.encode('utf-8')).decode('utf-8')
payload = verify_upload_token(incomplete_token)
assert payload is None
def test_verify_token_round_trip(self):
"""Test creating and verifying a token in sequence."""
test_cases = [
("/uploads/image.jpg", "image/jpeg", 300),
("/uploads/document.pdf", "application/pdf", 600),
("/uploads/video.mp4", "video/mp4", 900),
]
for file_path, content_type, expires_in in test_cases:
token = create_upload_token(file_path, content_type, expires_in)
payload = verify_upload_token(token)
assert payload is not None
assert payload["path"] == file_path
assert payload["content_type"] == content_type
assert "exp" in payload
assert "nonce" in payload
# Note: test_verify_token_cannot_be_reused_with_different_secret removed
# The signature validation is already tested by test_verify_invalid_signature
# and test_verify_tampered_payload. Testing with different SECRET_KEY
# requires complex mocking that can interfere with other tests.