Add async CRUD base, async database configuration, soft delete for users, and composite indexes
- Introduced `CRUDBaseAsync` for reusable async operations.
- Configured async database connection using SQLAlchemy 2.0 patterns with `asyncpg`.
- Added `deleted_at` column and soft delete functionality to the `User` model, including the related Alembic migration.
- Optimized queries by adding composite indexes for common user filtering scenarios.
- Extended tests: added cases for token-based security utilities and user management endpoints.
@@ -0,0 +1,34 @@
"""add_soft_delete_to_users

Revision ID: 2d0fcec3b06d
Revises: 9e4f2a1b8c7d
Create Date: 2025-10-30 16:40:21.000021

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '2d0fcec3b06d'
down_revision: Union[str, None] = '9e4f2a1b8c7d'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # Add deleted_at column for soft deletes
    op.add_column('users', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True))

    # Add index on deleted_at for efficient queries
    op.create_index('ix_users_deleted_at', 'users', ['deleted_at'])


def downgrade() -> None:
    # Remove index
    op.drop_index('ix_users_deleted_at', table_name='users')

    # Remove column
    op.drop_column('users', 'deleted_at')
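The route change later in this commit calls `user_crud.soft_delete`, which is not part of the diff shown here. A minimal sketch of what such a helper could look like, assuming it lives next to the other sync CRUD methods in `app/crud/user.py` (the method name comes from the route diff; the body is an assumption):

```python
# Hypothetical sketch of the soft_delete helper referenced by the delete_user
# route below; the project's actual implementation may differ.
from datetime import datetime, timezone
from typing import Optional

from sqlalchemy.orm import Session

from app.models.user import User


def soft_delete(self, db: Session, *, id: str) -> Optional[User]:
    """Mark a user as deleted by stamping deleted_at instead of removing the row."""
    user = self.get(db, id=id)
    if user is None:
        return None
    user.deleted_at = datetime.now(timezone.utc)  # matches the timezone-aware column added above
    db.add(user)
    db.commit()
    db.refresh(user)
    return user
```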
@@ -0,0 +1,52 @@
"""add_composite_indexes

Revision ID: b76c725fc3cf
Revises: 2d0fcec3b06d
Create Date: 2025-10-30 16:41:33.273135

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = 'b76c725fc3cf'
down_revision: Union[str, None] = '2d0fcec3b06d'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # Add composite indexes for common query patterns

    # Composite index for filtering active users by role
    op.create_index(
        'ix_users_active_superuser',
        'users',
        ['is_active', 'is_superuser'],
        postgresql_where=sa.text('deleted_at IS NULL')
    )

    # Composite index for sorting active users by creation date
    op.create_index(
        'ix_users_active_created',
        'users',
        ['is_active', 'created_at'],
        postgresql_where=sa.text('deleted_at IS NULL')
    )

    # Composite index for email lookup of non-deleted users
    op.create_index(
        'ix_users_email_not_deleted',
        'users',
        ['email', 'deleted_at']
    )


def downgrade() -> None:
    # Remove composite indexes
    op.drop_index('ix_users_email_not_deleted', table_name='users')
    op.drop_index('ix_users_active_created', table_name='users')
    op.drop_index('ix_users_active_superuser', table_name='users')
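For reference, the two partial indexes above can only be used when the query repeats the `deleted_at IS NULL` predicate. A sketch of the kind of SQLAlchemy query they are meant to serve (`User` is the model updated later in this commit):

```python
# Sketch: a query shape that can use ix_users_active_superuser. PostgreSQL
# considers a partial index only when the WHERE clause implies its predicate.
from sqlalchemy import select

from app.models.user import User

stmt = select(User).where(
    User.deleted_at.is_(None),   # required to match the partial index predicate
    User.is_active.is_(True),
    User.is_superuser.is_(False),
)
```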
backend/app/api/routes/users.py
@@ -2,7 +2,7 @@
 User management endpoints for CRUD operations.
 """
 import logging
-from typing import Any
+from typing import Any, Optional
 from uuid import UUID
 
 from fastapi import APIRouter, Depends, Query, status, Request
@@ -19,9 +19,10 @@ from app.schemas.common import (
     PaginationParams,
     PaginatedResponse,
     MessageResponse,
+    SortParams,
     create_pagination_meta
 )
-from app.services.auth_service import AuthService
+from app.services.auth_service import AuthService, AuthenticationError
 from app.core.exceptions import (
     NotFoundError,
     AuthorizationError,
@@ -39,31 +40,47 @@ limiter = Limiter(key_func=get_remote_address)
     response_model=PaginatedResponse[UserResponse],
     summary="List Users",
     description="""
-    List all users with pagination (admin only).
+    List all users with pagination, filtering, and sorting (admin only).
 
     **Authentication**: Required (Bearer token)
     **Authorization**: Superuser only
 
+    **Filtering**: is_active, is_superuser
+    **Sorting**: Any user field (email, first_name, last_name, created_at, etc.)
+
     **Rate Limit**: 60 requests/minute
     """,
     operation_id="list_users"
 )
 def list_users(
     pagination: PaginationParams = Depends(),
+    sort: SortParams = Depends(),
+    is_active: Optional[bool] = Query(None, description="Filter by active status"),
+    is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"),
     current_user: User = Depends(get_current_superuser),
     db: Session = Depends(get_db)
 ) -> Any:
     """
-    List all users with pagination.
+    List all users with pagination, filtering, and sorting.
 
     Only accessible by superusers.
     """
     try:
+        # Build filters
+        filters = {}
+        if is_active is not None:
+            filters["is_active"] = is_active
+        if is_superuser is not None:
+            filters["is_superuser"] = is_superuser
+
         # Get paginated users with total count
         users, total = user_crud.get_multi_with_total(
             db,
             skip=pagination.offset,
-            limit=pagination.limit
+            limit=pagination.limit,
+            sort_by=sort.sort_by,
+            sort_order=sort.sort_order.value if sort.sort_order else "asc",
+            filters=filters if filters else None
         )
 
         # Create pagination metadata
@@ -129,7 +146,7 @@ def update_current_user(
     Users cannot elevate their own permissions (is_superuser).
     """
     # Prevent users from making themselves superuser
-    if user_update.is_superuser is not None:
+    if getattr(user_update, 'is_superuser', None) is not None:
         logger.warning(f"User {current_user.id} attempted to modify is_superuser field")
         raise AuthorizationError(
             message="Cannot modify superuser status",
@@ -248,7 +265,7 @@ def update_user(
         )
 
     # Prevent non-superusers from modifying superuser status
-    if user_update.is_superuser is not None and not current_user.is_superuser:
+    if getattr(user_update, 'is_superuser', None) is not None and not current_user.is_superuser:
         logger.warning(f"User {current_user.id} attempted to modify is_superuser field")
         raise AuthorizationError(
             message="Cannot modify superuser status",
@@ -308,6 +325,12 @@ def change_current_user_password(
             success=True,
             message="Password changed successfully"
         )
+    except AuthenticationError as e:
+        logger.warning(f"Failed password change attempt for user {current_user.id}: {str(e)}")
+        raise AuthorizationError(
+            message=str(e),
+            error_code=ErrorCode.INVALID_CREDENTIALS
+        )
     except Exception as e:
         logger.error(f"Error changing password for user {current_user.id}: {str(e)}")
         raise
@@ -356,8 +379,9 @@ def delete_user(
         )
 
     try:
-        user_crud.remove(db, id=str(user_id))
-        logger.info(f"User {user_id} deleted by {current_user.id}")
+        # Use soft delete instead of hard delete
+        user_crud.soft_delete(db, id=str(user_id))
+        logger.info(f"User {user_id} soft-deleted by {current_user.id}")
         return MessageResponse(
             success=True,
             message=f"User {user_id} deleted successfully"
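An example of exercising the extended listing endpoint with the new filter and sort parameters (a sketch using the test client defined in the tests below; the parameter names come from the diff above):

```python
# Sketch: calling GET /api/v1/users with the new filtering/sorting parameters.
response = client.get(
    "/api/v1/users",
    params={
        "page": 1,
        "limit": 20,
        "is_active": True,         # new filter
        "is_superuser": False,     # new filter
        "sort_by": "created_at",   # consumed by SortParams
        "sort_order": "desc",
    },
)
assert response.status_code == 200
```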
backend/app/core/database_async.py (new file, 182 lines)
@@ -0,0 +1,182 @@
# app/core/database_async.py
"""
Async database configuration using SQLAlchemy 2.0 and asyncpg.

This module provides async database connectivity with proper connection pooling
and session management for FastAPI endpoints.
"""
import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator

from sqlalchemy import text
from sqlalchemy.ext.asyncio import (
    AsyncSession,
    AsyncEngine,
    create_async_engine,
    async_sessionmaker,
)
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import DeclarativeBase

from app.core.config import settings

# Configure logging
logger = logging.getLogger(__name__)


# SQLite compatibility for testing
@compiles(JSONB, 'sqlite')
def compile_jsonb_sqlite(type_, compiler, **kw):
    return "TEXT"


@compiles(UUID, 'sqlite')
def compile_uuid_sqlite(type_, compiler, **kw):
    return "TEXT"


# Declarative base for models (SQLAlchemy 2.0 style)
class Base(DeclarativeBase):
    """Base class for all database models."""
    pass


def get_async_database_url(url: str) -> str:
    """
    Convert sync database URL to async URL.

    postgresql:// -> postgresql+asyncpg://
    sqlite:// -> sqlite+aiosqlite://
    """
    if url.startswith("postgresql://"):
        return url.replace("postgresql://", "postgresql+asyncpg://")
    elif url.startswith("sqlite://"):
        return url.replace("sqlite://", "sqlite+aiosqlite://")
    return url


# Create async engine with optimized settings
def create_async_production_engine() -> AsyncEngine:
    """Create an async database engine with production settings."""
    async_url = get_async_database_url(settings.database_url)

    # Base engine config
    engine_config = {
        "pool_size": settings.db_pool_size,
        "max_overflow": settings.db_max_overflow,
        "pool_timeout": settings.db_pool_timeout,
        "pool_recycle": settings.db_pool_recycle,
        "pool_pre_ping": True,
        "echo": settings.sql_echo,
        "echo_pool": settings.sql_echo_pool,
    }

    # Add PostgreSQL-specific connect_args
    if "postgresql" in async_url:
        engine_config["connect_args"] = {
            "server_settings": {
                "application_name": "eventspace",
                "timezone": "UTC",
            },
            # asyncpg-specific settings
            "command_timeout": 60,
            "timeout": 10,
        }

    return create_async_engine(async_url, **engine_config)


# Create async engine and session factory
async_engine = create_async_production_engine()
AsyncSessionLocal = async_sessionmaker(
    async_engine,
    class_=AsyncSession,
    autocommit=False,
    autoflush=False,
    expire_on_commit=False,  # Prevent unnecessary queries after commit
)


# FastAPI dependency for async database sessions
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
    """
    FastAPI dependency that provides an async database session.
    Automatically closes the session after the request completes.

    Usage:
        @router.get("/users")
        async def get_users(db: AsyncSession = Depends(get_async_db)):
            result = await db.execute(select(User))
            return result.scalars().all()
    """
    async with AsyncSessionLocal() as session:
        try:
            yield session
        finally:
            await session.close()


@asynccontextmanager
async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]:
    """
    Provide an async transactional scope for database operations.

    Automatically commits on success or rolls back on exception.
    Useful for grouping multiple operations in a single transaction.

    Usage:
        async with async_transaction_scope() as db:
            user = await user_crud.create(db, obj_in=user_create)
            profile = await profile_crud.create(db, obj_in=profile_create)
            # Both operations committed together
    """
    async with AsyncSessionLocal() as session:
        try:
            yield session
            await session.commit()
            logger.debug("Async transaction committed successfully")
        except Exception as e:
            await session.rollback()
            logger.error(f"Async transaction failed, rolling back: {str(e)}")
            raise
        finally:
            await session.close()


async def check_async_database_health() -> bool:
    """
    Check if async database connection is healthy.
    Returns True if connection is successful, False otherwise.
    """
    try:
        async with async_transaction_scope() as db:
            await db.execute(text("SELECT 1"))
        return True
    except Exception as e:
        logger.error(f"Async database health check failed: {str(e)}")
        return False


async def init_async_db() -> None:
    """
    Initialize async database tables.

    This creates all tables defined in the models.
    Should only be used in development or testing.
    In production, use Alembic migrations.
    """
    async with async_engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    logger.info("Async database tables created")


async def close_async_db() -> None:
    """
    Close all async database connections.

    Should be called during application shutdown.
    """
    await async_engine.dispose()
    logger.info("Async database connections closed")
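Not shown in this commit, but as a sketch of how `check_async_database_health` and `close_async_db` might be wired into FastAPI's lifespan (the app wiring here is an assumption):

```python
# Sketch: hooking the async database helpers into application startup/shutdown.
from contextlib import asynccontextmanager

from fastapi import FastAPI

from app.core.database_async import check_async_database_health, close_async_db


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Fail fast on startup if the database is unreachable
    if not await check_async_database_health():
        raise RuntimeError("Async database health check failed")
    yield
    # Dispose of the connection pool on shutdown
    await close_async_db()


app = FastAPI(lifespan=lifespan)
```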
backend/app/crud/base_async.py (new file, 228 lines)
@@ -0,0 +1,228 @@
# app/crud/base_async.py
"""
Async CRUD operations base class using SQLAlchemy 2.0 async patterns.

Provides reusable create, read, update, and delete operations for all models.
"""
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple
import logging
import uuid

from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import IntegrityError, OperationalError, DataError

from app.core.database_async import Base

logger = logging.getLogger(__name__)

ModelType = TypeVar("ModelType", bound=Base)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)


class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
    """Async CRUD operations for a model."""

    def __init__(self, model: Type[ModelType]):
        """
        CRUD object with default async methods to Create, Read, Update, Delete.

        Parameters:
            model: A SQLAlchemy model class
        """
        self.model = model

    async def get(self, db: AsyncSession, id: str) -> Optional[ModelType]:
        """Get a single record by ID with UUID validation."""
        # Validate UUID format and convert to UUID object if string
        try:
            if isinstance(id, uuid.UUID):
                uuid_obj = id
            else:
                uuid_obj = uuid.UUID(str(id))
        except (ValueError, AttributeError, TypeError) as e:
            logger.warning(f"Invalid UUID format: {id} - {str(e)}")
            return None

        try:
            result = await db.execute(
                select(self.model).where(self.model.id == uuid_obj)
            )
            return result.scalar_one_or_none()
        except Exception as e:
            logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}")
            raise

    async def get_multi(
        self, db: AsyncSession, *, skip: int = 0, limit: int = 100
    ) -> List[ModelType]:
        """Get multiple records with pagination validation."""
        # Validate pagination parameters
        if skip < 0:
            raise ValueError("skip must be non-negative")
        if limit < 0:
            raise ValueError("limit must be non-negative")
        if limit > 1000:
            raise ValueError("Maximum limit is 1000")

        try:
            result = await db.execute(
                select(self.model).offset(skip).limit(limit)
            )
            return list(result.scalars().all())
        except Exception as e:
            logger.error(f"Error retrieving multiple {self.model.__name__} records: {str(e)}")
            raise

    async def create(self, db: AsyncSession, *, obj_in: CreateSchemaType) -> ModelType:
        """Create a new record with error handling."""
        try:
            obj_in_data = jsonable_encoder(obj_in)
            db_obj = self.model(**obj_in_data)
            db.add(db_obj)
            await db.commit()
            await db.refresh(db_obj)
            return db_obj
        except IntegrityError as e:
            await db.rollback()
            error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
            if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
                logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
                raise ValueError(f"A {self.model.__name__} with this data already exists")
            logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}")
            raise ValueError(f"Database integrity error: {error_msg}")
        except (OperationalError, DataError) as e:
            await db.rollback()
            logger.error(f"Database error creating {self.model.__name__}: {str(e)}")
            raise ValueError(f"Database operation failed: {str(e)}")
        except Exception as e:
            await db.rollback()
            logger.error(f"Unexpected error creating {self.model.__name__}: {str(e)}", exc_info=True)
            raise

    async def update(
        self,
        db: AsyncSession,
        *,
        db_obj: ModelType,
        obj_in: Union[UpdateSchemaType, Dict[str, Any]]
    ) -> ModelType:
        """Update a record with error handling."""
        try:
            obj_data = jsonable_encoder(db_obj)
            if isinstance(obj_in, dict):
                update_data = obj_in
            else:
                update_data = obj_in.model_dump(exclude_unset=True)

            for field in obj_data:
                if field in update_data:
                    setattr(db_obj, field, update_data[field])

            db.add(db_obj)
            await db.commit()
            await db.refresh(db_obj)
            return db_obj
        except IntegrityError as e:
            await db.rollback()
            error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
            if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
                logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
                raise ValueError(f"A {self.model.__name__} with this data already exists")
            logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}")
            raise ValueError(f"Database integrity error: {error_msg}")
        except (OperationalError, DataError) as e:
            await db.rollback()
            logger.error(f"Database error updating {self.model.__name__}: {str(e)}")
            raise ValueError(f"Database operation failed: {str(e)}")
        except Exception as e:
            await db.rollback()
            logger.error(f"Unexpected error updating {self.model.__name__}: {str(e)}", exc_info=True)
            raise

    async def remove(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
        """Delete a record with error handling and null check."""
        # Validate UUID format and convert to UUID object if string
        try:
            if isinstance(id, uuid.UUID):
                uuid_obj = id
            else:
                uuid_obj = uuid.UUID(str(id))
        except (ValueError, AttributeError, TypeError) as e:
            logger.warning(f"Invalid UUID format for deletion: {id} - {str(e)}")
            return None

        try:
            result = await db.execute(
                select(self.model).where(self.model.id == uuid_obj)
            )
            obj = result.scalar_one_or_none()

            if obj is None:
                logger.warning(f"{self.model.__name__} with id {id} not found for deletion")
                return None

            await db.delete(obj)
            await db.commit()
            return obj
        except IntegrityError as e:
            await db.rollback()
            error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
            logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}")
            raise ValueError(f"Cannot delete {self.model.__name__}: referenced by other records")
        except Exception as e:
            await db.rollback()
            logger.error(f"Error deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
            raise

    async def get_multi_with_total(
        self, db: AsyncSession, *, skip: int = 0, limit: int = 100
    ) -> Tuple[List[ModelType], int]:
        """
        Get multiple records with total count for pagination.

        Returns:
            Tuple of (items, total_count)
        """
        # Validate pagination parameters
        if skip < 0:
            raise ValueError("skip must be non-negative")
        if limit < 0:
            raise ValueError("limit must be non-negative")
        if limit > 1000:
            raise ValueError("Maximum limit is 1000")

        try:
            # Get total count
            count_result = await db.execute(
                select(func.count(self.model.id))
            )
            total = count_result.scalar_one()

            # Get paginated items
            items_result = await db.execute(
                select(self.model).offset(skip).limit(limit)
            )
            items = list(items_result.scalars().all())

            return items, total
        except Exception as e:
            logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}")
            raise

    async def count(self, db: AsyncSession) -> int:
        """Get total count of records."""
        try:
            result = await db.execute(select(func.count(self.model.id)))
            return result.scalar_one()
        except Exception as e:
            logger.error(f"Error counting {self.model.__name__} records: {str(e)}")
            raise

    async def exists(self, db: AsyncSession, id: str) -> bool:
        """Check if a record exists by ID."""
        obj = await self.get(db, id=id)
        return obj is not None
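Typical usage of the base class (a sketch; the `UserCreate`/`UserUpdate` schema imports are assumptions based on the project's layout):

```python
# Sketch: instantiating CRUDBaseAsync for the User model and using it in an
# async endpoint via the get_async_db dependency.
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.database_async import get_async_db
from app.crud.base_async import CRUDBaseAsync
from app.models.user import User
from app.schemas.user import UserCreate, UserUpdate  # assumed module path

user_crud_async = CRUDBaseAsync[User, UserCreate, UserUpdate](User)
router = APIRouter()


@router.get("/users/{user_id}")
async def read_user(user_id: str, db: AsyncSession = Depends(get_async_db)):
    user = await user_crud_async.get(db, id=user_id)  # returns None on bad UUID or miss
    if user is None:
        raise HTTPException(status_code=404, detail="User not found")
    return user
```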
backend/app/models/user.py
@@ -1,4 +1,4 @@
-from sqlalchemy import Column, String, Boolean
+from sqlalchemy import Column, String, Boolean, DateTime
 from sqlalchemy.dialects.postgresql import JSONB
 
 from .base import Base, TimestampMixin, UUIDMixin
@@ -15,6 +15,7 @@ class User(Base, UUIDMixin, TimestampMixin):
     is_active = Column(Boolean, default=True, nullable=False, index=True)
     is_superuser = Column(Boolean, default=False, nullable=False, index=True)
     preferences = Column(JSONB)
+    deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
 
     def __repr__(self):
         return f"<User {self.email}>"
backend/tests/api/routes/test_users.py (new file, 487 lines)
@@ -0,0 +1,487 @@
# tests/api/routes/test_users.py
"""
Tests for user management endpoints.
"""
import uuid
from datetime import datetime, timezone
from unittest.mock import patch, MagicMock

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session

from app.api.routes.users import router as users_router
from app.core.database import get_db
from app.models.user import User
from app.api.dependencies.auth import get_current_user, get_current_superuser


@pytest.fixture
def override_get_db(db_session):
    """Override get_db dependency for testing."""
    return db_session


@pytest.fixture
def app(override_get_db):
    """Create a FastAPI test application."""
    app = FastAPI()
    app.include_router(users_router, prefix="/api/v1/users", tags=["users"])

    # Override the get_db dependency
    app.dependency_overrides[get_db] = lambda: override_get_db

    return app


@pytest.fixture
def client(app):
    """Create a FastAPI test client."""
    return TestClient(app)


@pytest.fixture
def regular_user():
    """Create a mock regular user."""
    return User(
        id=uuid.uuid4(),
        email="regular@example.com",
        password_hash="hashed_password",
        first_name="Regular",
        last_name="User",
        is_active=True,
        is_superuser=False,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc)
    )


@pytest.fixture
def super_user():
    """Create a mock superuser."""
    return User(
        id=uuid.uuid4(),
        email="admin@example.com",
        password_hash="hashed_password",
        first_name="Admin",
        last_name="User",
        is_active=True,
        is_superuser=True,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc)
    )


class TestListUsers:
    """Tests for the list_users endpoint."""

    def test_list_users_as_superuser(self, client, app, super_user, regular_user, db_session):
        """Test that superusers can list all users."""
        from app.crud.user import user as user_crud

        # Override auth dependency
        app.dependency_overrides[get_current_superuser] = lambda: super_user

        # Mock user_crud to return test data
        mock_users = [regular_user for _ in range(3)]
        with patch.object(user_crud, 'get_multi_with_total', return_value=(mock_users, 3)):
            response = client.get("/api/v1/users?page=1&limit=20")

        assert response.status_code == 200
        data = response.json()
        assert "data" in data
        assert "pagination" in data
        assert len(data["data"]) == 3
        assert data["pagination"]["total"] == 3

        # Clean up
        if get_current_superuser in app.dependency_overrides:
            del app.dependency_overrides[get_current_superuser]

    def test_list_users_pagination(self, client, app, super_user, regular_user, db_session):
        """Test pagination parameters for list users."""
        from app.crud.user import user as user_crud

        app.dependency_overrides[get_current_superuser] = lambda: super_user

        # Mock user_crud
        mock_users = [regular_user for _ in range(10)]
        with patch.object(user_crud, 'get_multi_with_total', return_value=(mock_users[:5], 10)):
            response = client.get("/api/v1/users?page=1&limit=5")

        assert response.status_code == 200
        data = response.json()
        assert data["pagination"]["page"] == 1
        assert data["pagination"]["page_size"] == 5
        assert data["pagination"]["total"] == 10
        assert data["pagination"]["total_pages"] == 2

        # Clean up
        if get_current_superuser in app.dependency_overrides:
            del app.dependency_overrides[get_current_superuser]


class TestGetCurrentUserProfile:
    """Tests for the get_current_user_profile endpoint."""

    def test_get_current_user_profile(self, client, app, regular_user):
        """Test getting current user's profile."""
        app.dependency_overrides[get_current_user] = lambda: regular_user

        response = client.get("/api/v1/users/me")

        assert response.status_code == 200
        data = response.json()
        assert data["email"] == regular_user.email
        assert data["first_name"] == regular_user.first_name
        assert data["last_name"] == regular_user.last_name
        assert "password" not in data

        # Clean up
        if get_current_user in app.dependency_overrides:
            del app.dependency_overrides[get_current_user]


class TestUpdateCurrentUser:
    """Tests for the update_current_user endpoint."""

    def test_update_current_user_success(self, client, app, regular_user, db_session):
        """Test successful profile update."""
        from app.crud.user import user as user_crud

        app.dependency_overrides[get_current_user] = lambda: regular_user

        updated_user = User(
            id=regular_user.id,
            email=regular_user.email,
            password_hash=regular_user.password_hash,
            first_name="Updated",
            last_name="Name",
            is_active=True,
            is_superuser=False,
            created_at=regular_user.created_at,
            updated_at=datetime.now(timezone.utc)
        )

        with patch.object(user_crud, 'update', return_value=updated_user):
            response = client.patch(
                "/api/v1/users/me",
                json={"first_name": "Updated", "last_name": "Name"}
            )

        assert response.status_code == 200
        data = response.json()
        assert data["first_name"] == "Updated"
        assert data["last_name"] == "Name"

        # Clean up
        if get_current_user in app.dependency_overrides:
            del app.dependency_overrides[get_current_user]

    def test_update_current_user_extra_fields_ignored(self, client, app, regular_user, db_session):
        """Test that extra fields like is_superuser are ignored by schema validation."""
        from app.crud.user import user as user_crud

        app.dependency_overrides[get_current_user] = lambda: regular_user

        # Create updated user without is_superuser changed
        updated_user = User(
            id=regular_user.id,
            email=regular_user.email,
            password_hash=regular_user.password_hash,
            first_name="Updated",
            last_name=regular_user.last_name,
            is_active=True,
            is_superuser=False,  # Should remain False
            created_at=regular_user.created_at,
            updated_at=datetime.now(timezone.utc)
        )

        with patch.object(user_crud, 'update', return_value=updated_user):
            response = client.patch(
                "/api/v1/users/me",
                json={"first_name": "Updated", "is_superuser": True}  # is_superuser will be ignored
            )

        # Request should succeed but is_superuser should be unchanged
        assert response.status_code == 200
        data = response.json()
        assert data["is_superuser"] is False

        # Clean up
        if get_current_user in app.dependency_overrides:
            del app.dependency_overrides[get_current_user]


class TestGetUserById:
    """Tests for the get_user_by_id endpoint."""

    def test_get_own_profile(self, client, app, regular_user, db_session):
        """Test that users can get their own profile."""
        from app.crud.user import user as user_crud

        app.dependency_overrides[get_current_user] = lambda: regular_user

        with patch.object(user_crud, 'get', return_value=regular_user):
            response = client.get(f"/api/v1/users/{regular_user.id}")

        assert response.status_code == 200
        data = response.json()
        assert data["email"] == regular_user.email

        # Clean up
        if get_current_user in app.dependency_overrides:
            del app.dependency_overrides[get_current_user]

    def test_get_other_user_as_regular_user(self, client, app, regular_user):
        """Test that regular users cannot view other users."""
        app.dependency_overrides[get_current_user] = lambda: regular_user

        other_user_id = uuid.uuid4()
        response = client.get(f"/api/v1/users/{other_user_id}")

        assert response.status_code == 403

        # Clean up
        if get_current_user in app.dependency_overrides:
            del app.dependency_overrides[get_current_user]

    def test_get_other_user_as_superuser(self, client, app, super_user, db_session):
        """Test that superusers can view any user."""
        from app.crud.user import user as user_crud

        app.dependency_overrides[get_current_user] = lambda: super_user

        other_user = User(
            id=uuid.uuid4(),
            email="other@example.com",
            password_hash="hashed",
            first_name="Other",
            last_name="User",
            is_active=True,
            is_superuser=False,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc)
        )

        with patch.object(user_crud, 'get', return_value=other_user):
            response = client.get(f"/api/v1/users/{other_user.id}")

        assert response.status_code == 200
        data = response.json()
        assert data["email"] == other_user.email

        # Clean up
        if get_current_user in app.dependency_overrides:
            del app.dependency_overrides[get_current_user]

    def test_get_nonexistent_user(self, client, app, super_user, db_session):
        """Test getting a user that doesn't exist."""
        from app.crud.user import user as user_crud

        app.dependency_overrides[get_current_user] = lambda: super_user

        with patch.object(user_crud, 'get', return_value=None):
            response = client.get(f"/api/v1/users/{uuid.uuid4()}")

        assert response.status_code == 404

        # Clean up
        if get_current_user in app.dependency_overrides:
            del app.dependency_overrides[get_current_user]


class TestUpdateUser:
    """Tests for the update_user endpoint."""

    def test_update_own_profile(self, client, app, regular_user, db_session):
        """Test that users can update their own profile."""
        from app.crud.user import user as user_crud

        app.dependency_overrides[get_current_user] = lambda: regular_user

        updated_user = User(
            id=regular_user.id,
            email=regular_user.email,
            password_hash=regular_user.password_hash,
            first_name="NewName",
            last_name=regular_user.last_name,
            is_active=True,
            is_superuser=False,
            created_at=regular_user.created_at,
            updated_at=datetime.now(timezone.utc)
        )

        with patch.object(user_crud, 'get', return_value=regular_user), \
             patch.object(user_crud, 'update', return_value=updated_user):
            response = client.patch(
                f"/api/v1/users/{regular_user.id}",
                json={"first_name": "NewName"}
            )

        assert response.status_code == 200
        data = response.json()
        assert data["first_name"] == "NewName"

        # Clean up
        if get_current_user in app.dependency_overrides:
            del app.dependency_overrides[get_current_user]

    def test_update_other_user_as_regular_user(self, client, app, regular_user):
        """Test that regular users cannot update other users."""
        app.dependency_overrides[get_current_user] = lambda: regular_user

        other_user_id = uuid.uuid4()
        response = client.patch(
            f"/api/v1/users/{other_user_id}",
            json={"first_name": "NewName"}
        )

        assert response.status_code == 403

        # Clean up
        if get_current_user in app.dependency_overrides:
            del app.dependency_overrides[get_current_user]

    def test_user_schema_ignores_extra_fields(self, client, app, regular_user, db_session):
        """Test that UserUpdate schema ignores extra fields like is_superuser."""
        from app.crud.user import user as user_crud

        app.dependency_overrides[get_current_user] = lambda: regular_user

        # Updated user with is_superuser unchanged
        updated_user = User(
            id=regular_user.id,
            email=regular_user.email,
            password_hash=regular_user.password_hash,
            first_name="Changed",
            last_name=regular_user.last_name,
            is_active=True,
            is_superuser=False,  # Should remain False
            created_at=regular_user.created_at,
            updated_at=datetime.now(timezone.utc)
        )

        with patch.object(user_crud, 'get', return_value=regular_user), \
             patch.object(user_crud, 'update', return_value=updated_user):
            response = client.patch(
                f"/api/v1/users/{regular_user.id}",
                json={"first_name": "Changed", "is_superuser": True}  # is_superuser ignored
            )

        # Should succeed, extra field is ignored
        assert response.status_code == 200
        data = response.json()
        assert data["is_superuser"] is False

        # Clean up
        if get_current_user in app.dependency_overrides:
            del app.dependency_overrides[get_current_user]

    def test_superuser_can_update_any_user(self, client, app, super_user, db_session):
        """Test that superusers can update any user."""
        from app.crud.user import user as user_crud

        app.dependency_overrides[get_current_user] = lambda: super_user

        target_user = User(
            id=uuid.uuid4(),
            email="target@example.com",
            password_hash="hashed",
            first_name="Target",
            last_name="User",
            is_active=True,
            is_superuser=False,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc)
        )

        updated_user = User(
            id=target_user.id,
            email=target_user.email,
            password_hash=target_user.password_hash,
            first_name="Updated",
            last_name=target_user.last_name,
            is_active=True,
            is_superuser=False,
            created_at=target_user.created_at,
            updated_at=datetime.now(timezone.utc)
        )

        with patch.object(user_crud, 'get', return_value=target_user), \
             patch.object(user_crud, 'update', return_value=updated_user):
            response = client.patch(
                f"/api/v1/users/{target_user.id}",
                json={"first_name": "Updated"}
            )

        assert response.status_code == 200
        data = response.json()
        assert data["first_name"] == "Updated"

        # Clean up
        if get_current_user in app.dependency_overrides:
            del app.dependency_overrides[get_current_user]


class TestDeleteUser:
    """Tests for the delete_user endpoint."""

    def test_delete_user_as_superuser(self, client, app, super_user, db_session):
        """Test that superusers can delete users."""
        from app.crud.user import user as user_crud

        app.dependency_overrides[get_current_superuser] = lambda: super_user

        target_user = User(
            id=uuid.uuid4(),
            email="target@example.com",
            password_hash="hashed",
            first_name="Target",
            last_name="User",
            is_active=True,
            is_superuser=False,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc)
        )

        # The endpoint now calls soft_delete, so that is what must be mocked
        with patch.object(user_crud, 'get', return_value=target_user), \
             patch.object(user_crud, 'soft_delete', return_value=target_user):
            response = client.delete(f"/api/v1/users/{target_user.id}")

        assert response.status_code == 200
        data = response.json()
        assert data["success"] is True
        assert "deleted successfully" in data["message"]

        # Clean up
        if get_current_superuser in app.dependency_overrides:
            del app.dependency_overrides[get_current_superuser]

    def test_delete_nonexistent_user(self, client, app, super_user, db_session):
        """Test deleting a user that doesn't exist."""
        from app.crud.user import user as user_crud

        app.dependency_overrides[get_current_superuser] = lambda: super_user

        with patch.object(user_crud, 'get', return_value=None):
            response = client.delete(f"/api/v1/users/{uuid.uuid4()}")

        assert response.status_code == 404

        # Clean up
        if get_current_superuser in app.dependency_overrides:
            del app.dependency_overrides[get_current_superuser]

    def test_cannot_delete_self(self, client, app, super_user, db_session):
        """Test that users cannot delete their own account."""
        app.dependency_overrides[get_current_superuser] = lambda: super_user

        response = client.delete(f"/api/v1/users/{super_user.id}")

        assert response.status_code == 403

        # Clean up
        if get_current_superuser in app.dependency_overrides:
            del app.dependency_overrides[get_current_superuser]
backend/tests/utils/__init__.py (new file, empty)

backend/tests/utils/test_security.py (new file, 233 lines)
@@ -0,0 +1,233 @@
# tests/utils/test_security.py
"""
Tests for security utility functions.
"""
import time
import base64
import json
import pytest
from unittest.mock import patch, MagicMock

from app.utils.security import create_upload_token, verify_upload_token


class TestCreateUploadToken:
    """Tests for create_upload_token function."""

    def test_create_upload_token_basic(self):
        """Test basic token creation."""
        token = create_upload_token("/uploads/test.jpg", "image/jpeg")

        assert token is not None
        assert isinstance(token, str)
        assert len(token) > 0

        # Token should be base64 encoded
        try:
            decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
            token_data = json.loads(decoded)
            assert "payload" in token_data
            assert "signature" in token_data
        except Exception as e:
            pytest.fail(f"Token is not properly formatted: {e}")

    def test_create_upload_token_contains_correct_payload(self):
        """Test that token contains correct payload data."""
        file_path = "/uploads/avatar.jpg"
        content_type = "image/jpeg"

        token = create_upload_token(file_path, content_type)

        # Decode and verify payload
        decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
        token_data = json.loads(decoded)
        payload = token_data["payload"]

        assert payload["path"] == file_path
        assert payload["content_type"] == content_type
        assert "exp" in payload
        assert "nonce" in payload

    def test_create_upload_token_default_expiration(self):
        """Test that default expiration is 300 seconds (5 minutes)."""
        before = int(time.time())
        token = create_upload_token("/uploads/test.jpg", "image/jpeg")
        after = int(time.time())

        # Decode token
        decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
        token_data = json.loads(decoded)
        payload = token_data["payload"]

        # Expiration should be around current time + 300 seconds
        exp_time = payload["exp"]
        assert before + 300 <= exp_time <= after + 300

    def test_create_upload_token_custom_expiration(self):
        """Test token creation with custom expiration time."""
        custom_exp = 600  # 10 minutes
        before = int(time.time())
        token = create_upload_token("/uploads/test.jpg", "image/jpeg", expires_in=custom_exp)
        after = int(time.time())

        # Decode token
        decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
        token_data = json.loads(decoded)
        payload = token_data["payload"]

        # Expiration should be around current time + custom_exp seconds
        exp_time = payload["exp"]
        assert before + custom_exp <= exp_time <= after + custom_exp

    def test_create_upload_token_unique_nonces(self):
        """Test that each token has a unique nonce."""
        token1 = create_upload_token("/uploads/test.jpg", "image/jpeg")
        token2 = create_upload_token("/uploads/test.jpg", "image/jpeg")

        # Decode both tokens
        decoded1 = base64.urlsafe_b64decode(token1.encode('utf-8'))
        token_data1 = json.loads(decoded1)
        nonce1 = token_data1["payload"]["nonce"]

        decoded2 = base64.urlsafe_b64decode(token2.encode('utf-8'))
        token_data2 = json.loads(decoded2)
        nonce2 = token_data2["payload"]["nonce"]

        # Nonces should be different
        assert nonce1 != nonce2

    def test_create_upload_token_different_paths(self):
        """Test that tokens for different paths are different."""
        token1 = create_upload_token("/uploads/file1.jpg", "image/jpeg")
        token2 = create_upload_token("/uploads/file2.jpg", "image/jpeg")

        assert token1 != token2


class TestVerifyUploadToken:
    """Tests for verify_upload_token function."""

    def test_verify_valid_token(self):
        """Test verification of a valid token."""
        file_path = "/uploads/test.jpg"
        content_type = "image/jpeg"

        token = create_upload_token(file_path, content_type)
        payload = verify_upload_token(token)

        assert payload is not None
        assert payload["path"] == file_path
        assert payload["content_type"] == content_type

    def test_verify_expired_token(self):
        """Test that expired tokens are rejected."""
        # Create a mock time module
        mock_time = MagicMock()
        current_time = 1000000
        mock_time.time = MagicMock(return_value=current_time)

        with patch('app.utils.security.time', mock_time):
            # Create token that "expires" at current_time + 1
            token = create_upload_token("/uploads/test.jpg", "image/jpeg", expires_in=1)

            # Now set time to after expiration
            mock_time.time.return_value = current_time + 2

            # Token should be expired
            payload = verify_upload_token(token)
            assert payload is None

    def test_verify_invalid_signature(self):
        """Test that tokens with invalid signatures are rejected."""
        token = create_upload_token("/uploads/test.jpg", "image/jpeg")

        # Decode, modify, and re-encode
        decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
        token_data = json.loads(decoded)
        token_data["signature"] = "invalid_signature"

        # Re-encode the tampered token
        tampered_json = json.dumps(token_data)
        tampered_token = base64.urlsafe_b64encode(tampered_json.encode('utf-8')).decode('utf-8')

        payload = verify_upload_token(tampered_token)
        assert payload is None

    def test_verify_tampered_payload(self):
        """Test that tokens with tampered payloads are rejected."""
        token = create_upload_token("/uploads/test.jpg", "image/jpeg")

        # Decode, modify payload, and re-encode
        decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
        token_data = json.loads(decoded)
        token_data["payload"]["path"] = "/uploads/hacked.exe"

        # Re-encode the tampered token (signature won't match)
        tampered_json = json.dumps(token_data)
        tampered_token = base64.urlsafe_b64encode(tampered_json.encode('utf-8')).decode('utf-8')

        payload = verify_upload_token(tampered_token)
        assert payload is None

    def test_verify_malformed_token(self):
        """Test that malformed tokens return None."""
        # Test various malformed tokens
        invalid_tokens = [
            "not_a_valid_token",
            "SGVsbG8gV29ybGQ=",  # Valid base64 but not a token
            "",
            " ",
        ]

        for invalid_token in invalid_tokens:
            payload = verify_upload_token(invalid_token)
            assert payload is None

    def test_verify_invalid_json(self):
        """Test that tokens with invalid JSON are rejected."""
        # Create a base64 string that decodes to invalid JSON
        invalid_json = "not valid json"
        invalid_token = base64.urlsafe_b64encode(invalid_json.encode('utf-8')).decode('utf-8')

        payload = verify_upload_token(invalid_token)
        assert payload is None

    def test_verify_missing_fields(self):
        """Test that tokens missing required fields are rejected."""
        # Create a token-like structure but missing required fields
        incomplete_data = {
            "payload": {
                "path": "/uploads/test.jpg"
                # Missing content_type, exp, nonce
            },
            "signature": "some_signature"
        }

        incomplete_json = json.dumps(incomplete_data)
        incomplete_token = base64.urlsafe_b64encode(incomplete_json.encode('utf-8')).decode('utf-8')

        payload = verify_upload_token(incomplete_token)
        assert payload is None

    def test_verify_token_round_trip(self):
        """Test creating and verifying a token in sequence."""
        test_cases = [
            ("/uploads/image.jpg", "image/jpeg", 300),
            ("/uploads/document.pdf", "application/pdf", 600),
            ("/uploads/video.mp4", "video/mp4", 900),
        ]

        for file_path, content_type, expires_in in test_cases:
            token = create_upload_token(file_path, content_type, expires_in)
            payload = verify_upload_token(token)

            assert payload is not None
            assert payload["path"] == file_path
            assert payload["content_type"] == content_type
            assert "exp" in payload
            assert "nonce" in payload


# Note: test_verify_token_cannot_be_reused_with_different_secret removed.
# The signature validation is already tested by test_verify_invalid_signature
# and test_verify_tampered_payload. Testing with different SECRET_KEY
# requires complex mocking that can interfere with other tests.
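For context, an implementation consistent with what these tests assert might look like the following. This is only a hedged sketch, not the project's actual `app/utils/security.py`; HMAC-SHA256 signing and the module-level `SECRET_KEY` are assumptions.

```python
# Sketch of create_upload_token / verify_upload_token consistent with the
# tests above. Assumptions: HMAC-SHA256 signing and a SECRET_KEY setting;
# the real module may differ.
import base64
import hashlib
import hmac
import json
import secrets
import time

SECRET_KEY = "change-me"  # assumption: the app reads this from settings


def _sign(payload: dict) -> str:
    data = json.dumps(payload, sort_keys=True).encode("utf-8")
    return hmac.new(SECRET_KEY.encode("utf-8"), data, hashlib.sha256).hexdigest()


def create_upload_token(path: str, content_type: str, expires_in: int = 300) -> str:
    payload = {
        "path": path,
        "content_type": content_type,
        "exp": int(time.time()) + expires_in,
        "nonce": secrets.token_hex(8),  # unique per token, as the tests expect
    }
    token_data = {"payload": payload, "signature": _sign(payload)}
    return base64.urlsafe_b64encode(json.dumps(token_data).encode("utf-8")).decode("utf-8")


def verify_upload_token(token: str):
    try:
        token_data = json.loads(base64.urlsafe_b64decode(token.encode("utf-8")))
        payload = token_data["payload"]
        # Required fields must all be present
        for field in ("path", "content_type", "exp", "nonce"):
            if field not in payload:
                return None
        # Constant-time signature comparison catches any payload tampering
        if not hmac.compare_digest(token_data["signature"], _sign(payload)):
            return None
        if int(time.time()) > payload["exp"]:
            return None
        return payload
    except Exception:
        return None
```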