Compare commits
7 Commits
778da09a42
...
313e6691b5
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
313e6691b5 | ||
|
|
c684f2ba95 | ||
|
|
2c600290a1 | ||
|
|
d83959963b | ||
|
|
5bed14b6b0 | ||
|
|
f163ffbb83 | ||
|
|
54e389d230 |
@@ -12,12 +12,17 @@ DATABASE_URL=postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@${POSTGRES_HOST}
|
||||
|
||||
# Backend settings
|
||||
BACKEND_PORT=8000
|
||||
SECRET_KEY=your_secret_key_here
|
||||
# CRITICAL: Generate a secure SECRET_KEY for production!
|
||||
# Generate with: python -c 'import secrets; print(secrets.token_urlsafe(32))'
|
||||
# Must be at least 32 characters
|
||||
SECRET_KEY=your_secret_key_here_REPLACE_WITH_GENERATED_KEY_32_CHARS_MIN
|
||||
ENVIRONMENT=development
|
||||
DEBUG=true
|
||||
BACKEND_CORS_ORIGINS=["http://localhost:3000"]
|
||||
FIRST_SUPERUSER_EMAIL=admin@example.com
|
||||
FIRST_SUPERUSER_PASSWORD=Admin123
|
||||
# IMPORTANT: Use a strong password (min 12 chars, mixed case, digits)
|
||||
# Default weak passwords like 'Admin123' are rejected
|
||||
FIRST_SUPERUSER_PASSWORD=YourStrongPassword123!
|
||||
|
||||
# Frontend settings
|
||||
FRONTEND_PORT=3000
|
||||
|
||||
@@ -1,34 +1,66 @@
|
||||
# Development stage
|
||||
FROM python:3.12-slim AS development
|
||||
|
||||
# Create non-root user
|
||||
RUN groupadd -r appuser && useradd -r -g appuser appuser
|
||||
|
||||
WORKDIR /app
|
||||
ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
PYTHONPATH=/app
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends gcc postgresql-client && \
|
||||
apt-get install -y --no-install-recommends gcc postgresql-client curl && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
COPY . .
|
||||
COPY entrypoint.sh /usr/local/bin/
|
||||
RUN chmod +x /usr/local/bin/entrypoint.sh
|
||||
|
||||
# Set ownership to non-root user
|
||||
RUN chown -R appuser:appuser /app
|
||||
|
||||
# Switch to non-root user
|
||||
USER appuser
|
||||
|
||||
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
|
||||
|
||||
# Production stage
|
||||
FROM python:3.12-slim AS production
|
||||
|
||||
# Create non-root user
|
||||
RUN groupadd -r appuser && useradd -r -g appuser appuser
|
||||
|
||||
WORKDIR /app
|
||||
ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
PYTHONPATH=/app
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends postgresql-client && \
|
||||
apt-get install -y --no-install-recommends postgresql-client curl && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
COPY . .
|
||||
COPY entrypoint.sh /usr/local/bin/
|
||||
RUN chmod +x /usr/local/bin/entrypoint.sh
|
||||
|
||||
# Set ownership to non-root user
|
||||
RUN chown -R appuser:appuser /app
|
||||
|
||||
# Switch to non-root user
|
||||
USER appuser
|
||||
|
||||
# Add health check
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
|
||||
CMD curl -f http://localhost:8000/health || exit 1
|
||||
|
||||
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
@@ -0,0 +1,34 @@
|
||||
"""add_soft_delete_to_users
|
||||
|
||||
Revision ID: 2d0fcec3b06d
|
||||
Revises: 9e4f2a1b8c7d
|
||||
Create Date: 2025-10-30 16:40:21.000021
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '2d0fcec3b06d'
|
||||
down_revision: Union[str, None] = '9e4f2a1b8c7d'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add deleted_at column for soft deletes
|
||||
op.add_column('users', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True))
|
||||
|
||||
# Add index on deleted_at for efficient queries
|
||||
op.create_index('ix_users_deleted_at', 'users', ['deleted_at'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove index
|
||||
op.drop_index('ix_users_deleted_at', table_name='users')
|
||||
|
||||
# Remove column
|
||||
op.drop_column('users', 'deleted_at')
|
||||
@@ -0,0 +1,84 @@
|
||||
"""Add missing indexes and fix column types
|
||||
|
||||
Revision ID: 9e4f2a1b8c7d
|
||||
Revises: 38bf9e7e74b3
|
||||
Create Date: 2025-10-30 10:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '9e4f2a1b8c7d'
|
||||
down_revision: Union[str, None] = '38bf9e7e74b3'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add missing indexes for is_active and is_superuser
|
||||
op.create_index(op.f('ix_users_is_active'), 'users', ['is_active'], unique=False)
|
||||
op.create_index(op.f('ix_users_is_superuser'), 'users', ['is_superuser'], unique=False)
|
||||
|
||||
# Fix column types to match model definitions with explicit lengths
|
||||
op.alter_column('users', 'email',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=255),
|
||||
nullable=False)
|
||||
|
||||
op.alter_column('users', 'password_hash',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=255),
|
||||
nullable=False)
|
||||
|
||||
op.alter_column('users', 'first_name',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=100),
|
||||
nullable=False,
|
||||
server_default='user') # Add server default
|
||||
|
||||
op.alter_column('users', 'last_name',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=100),
|
||||
nullable=True)
|
||||
|
||||
op.alter_column('users', 'phone_number',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=20),
|
||||
nullable=True)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Revert column types
|
||||
op.alter_column('users', 'phone_number',
|
||||
existing_type=sa.String(length=20),
|
||||
type_=sa.String(),
|
||||
nullable=True)
|
||||
|
||||
op.alter_column('users', 'last_name',
|
||||
existing_type=sa.String(length=100),
|
||||
type_=sa.String(),
|
||||
nullable=True)
|
||||
|
||||
op.alter_column('users', 'first_name',
|
||||
existing_type=sa.String(length=100),
|
||||
type_=sa.String(),
|
||||
nullable=False,
|
||||
server_default=None) # Remove server default
|
||||
|
||||
op.alter_column('users', 'password_hash',
|
||||
existing_type=sa.String(length=255),
|
||||
type_=sa.String(),
|
||||
nullable=False)
|
||||
|
||||
op.alter_column('users', 'email',
|
||||
existing_type=sa.String(length=255),
|
||||
type_=sa.String(),
|
||||
nullable=False)
|
||||
|
||||
# Drop indexes
|
||||
op.drop_index(op.f('ix_users_is_superuser'), table_name='users')
|
||||
op.drop_index(op.f('ix_users_is_active'), table_name='users')
|
||||
@@ -0,0 +1,52 @@
|
||||
"""add_composite_indexes
|
||||
|
||||
Revision ID: b76c725fc3cf
|
||||
Revises: 2d0fcec3b06d
|
||||
Create Date: 2025-10-30 16:41:33.273135
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'b76c725fc3cf'
|
||||
down_revision: Union[str, None] = '2d0fcec3b06d'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add composite indexes for common query patterns
|
||||
|
||||
# Composite index for filtering active users by role
|
||||
op.create_index(
|
||||
'ix_users_active_superuser',
|
||||
'users',
|
||||
['is_active', 'is_superuser'],
|
||||
postgresql_where=sa.text('deleted_at IS NULL')
|
||||
)
|
||||
|
||||
# Composite index for sorting active users by creation date
|
||||
op.create_index(
|
||||
'ix_users_active_created',
|
||||
'users',
|
||||
['is_active', 'created_at'],
|
||||
postgresql_where=sa.text('deleted_at IS NULL')
|
||||
)
|
||||
|
||||
# Composite index for email lookup of non-deleted users
|
||||
op.create_index(
|
||||
'ix_users_email_not_deleted',
|
||||
'users',
|
||||
['email', 'deleted_at']
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove composite indexes
|
||||
op.drop_index('ix_users_email_not_deleted', table_name='users')
|
||||
op.drop_index('ix_users_active_created', table_name='users')
|
||||
op.drop_index('ix_users_active_superuser', table_name='users')
|
||||
@@ -1,6 +1,7 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.routes import auth
|
||||
from app.api.routes import auth, users
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(auth.router, prefix="/auth", tags=["auth"])
|
||||
api_router.include_router(auth.router, prefix="/auth", tags=["Authentication"])
|
||||
api_router.include_router(users.router, prefix="/users", tags=["Users"])
|
||||
|
||||
@@ -2,8 +2,10 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Body
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Body, Request
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
@@ -22,9 +24,14 @@ from app.services.auth_service import AuthService, AuthenticationError
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize limiter for this router
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED, operation_id="register")
|
||||
@limiter.limit("5/minute")
|
||||
async def register_user(
|
||||
request: Request,
|
||||
user_data: UserCreate,
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
@@ -52,7 +59,9 @@ async def register_user(
|
||||
|
||||
|
||||
@router.post("/login", response_model=Token, operation_id="login")
|
||||
@limiter.limit("10/minute")
|
||||
async def login(
|
||||
request: Request,
|
||||
login_data: LoginRequest,
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
@@ -101,7 +110,9 @@ async def login(
|
||||
|
||||
|
||||
@router.post("/login/oauth", response_model=Token, operation_id='login_oauth')
|
||||
@limiter.limit("10/minute")
|
||||
async def login_oauth(
|
||||
request: Request,
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
@@ -148,7 +159,9 @@ async def login_oauth(
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=Token, operation_id="refresh_token")
|
||||
@limiter.limit("30/minute")
|
||||
async def refresh_token(
|
||||
request: Request,
|
||||
refresh_data: RefreshTokenRequest,
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
@@ -183,44 +196,10 @@ async def refresh_token(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/change-password", status_code=status.HTTP_200_OK, operation_id="change_password")
|
||||
async def change_password(
|
||||
current_password: str = Body(..., embed=True),
|
||||
new_password: str = Body(..., embed=True),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Change current user's password.
|
||||
|
||||
Requires authentication.
|
||||
"""
|
||||
try:
|
||||
success = AuthService.change_password(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
current_password=current_password,
|
||||
new_password=new_password
|
||||
)
|
||||
|
||||
if success:
|
||||
return {"message": "Password changed successfully"}
|
||||
except AuthenticationError as e:
|
||||
logger.warning(f"Password change failed: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during password change: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An unexpected error occurred. Please try again later."
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse, operation_id="get_current_user_info")
|
||||
@limiter.limit("60/minute")
|
||||
async def get_current_user_info(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> Any:
|
||||
"""
|
||||
|
||||
394
backend/app/api/routes/users.py
Normal file
394
backend/app/api/routes/users.py
Normal file
@@ -0,0 +1,394 @@
|
||||
"""
|
||||
User management endpoints for CRUD operations.
|
||||
"""
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, status, Request
|
||||
from sqlalchemy.orm import Session
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from app.api.dependencies.auth import get_current_user, get_current_superuser
|
||||
from app.core.database import get_db
|
||||
from app.crud.user import user as user_crud
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserResponse, UserUpdate, PasswordChange
|
||||
from app.schemas.common import (
|
||||
PaginationParams,
|
||||
PaginatedResponse,
|
||||
MessageResponse,
|
||||
SortParams,
|
||||
create_pagination_meta
|
||||
)
|
||||
from app.services.auth_service import AuthService, AuthenticationError
|
||||
from app.core.exceptions import (
|
||||
NotFoundError,
|
||||
AuthorizationError,
|
||||
ErrorCode
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
response_model=PaginatedResponse[UserResponse],
|
||||
summary="List Users",
|
||||
description="""
|
||||
List all users with pagination, filtering, and sorting (admin only).
|
||||
|
||||
**Authentication**: Required (Bearer token)
|
||||
**Authorization**: Superuser only
|
||||
|
||||
**Filtering**: is_active, is_superuser
|
||||
**Sorting**: Any user field (email, first_name, last_name, created_at, etc.)
|
||||
|
||||
**Rate Limit**: 60 requests/minute
|
||||
""",
|
||||
operation_id="list_users"
|
||||
)
|
||||
def list_users(
|
||||
pagination: PaginationParams = Depends(),
|
||||
sort: SortParams = Depends(),
|
||||
is_active: Optional[bool] = Query(None, description="Filter by active status"),
|
||||
is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
List all users with pagination, filtering, and sorting.
|
||||
|
||||
Only accessible by superusers.
|
||||
"""
|
||||
try:
|
||||
# Build filters
|
||||
filters = {}
|
||||
if is_active is not None:
|
||||
filters["is_active"] = is_active
|
||||
if is_superuser is not None:
|
||||
filters["is_superuser"] = is_superuser
|
||||
|
||||
# Get paginated users with total count
|
||||
users, total = user_crud.get_multi_with_total(
|
||||
db,
|
||||
skip=pagination.offset,
|
||||
limit=pagination.limit,
|
||||
sort_by=sort.sort_by,
|
||||
sort_order=sort.sort_order.value if sort.sort_order else "asc",
|
||||
filters=filters if filters else None
|
||||
)
|
||||
|
||||
# Create pagination metadata
|
||||
pagination_meta = create_pagination_meta(
|
||||
total=total,
|
||||
page=pagination.page,
|
||||
limit=pagination.limit,
|
||||
items_count=len(users)
|
||||
)
|
||||
|
||||
return PaginatedResponse(
|
||||
data=users,
|
||||
pagination=pagination_meta
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing users: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.get(
|
||||
"/me",
|
||||
response_model=UserResponse,
|
||||
summary="Get Current User",
|
||||
description="""
|
||||
Get the current authenticated user's profile.
|
||||
|
||||
**Authentication**: Required (Bearer token)
|
||||
|
||||
**Rate Limit**: 60 requests/minute
|
||||
""",
|
||||
operation_id="get_current_user_profile"
|
||||
)
|
||||
def get_current_user_profile(
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> Any:
|
||||
"""Get current user's profile."""
|
||||
return current_user
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/me",
|
||||
response_model=UserResponse,
|
||||
summary="Update Current User",
|
||||
description="""
|
||||
Update the current authenticated user's profile.
|
||||
|
||||
Users can update their own profile information (except is_superuser).
|
||||
|
||||
**Authentication**: Required (Bearer token)
|
||||
|
||||
**Rate Limit**: 30 requests/minute
|
||||
""",
|
||||
operation_id="update_current_user"
|
||||
)
|
||||
def update_current_user(
|
||||
user_update: UserUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Update current user's profile.
|
||||
|
||||
Users cannot elevate their own permissions (is_superuser).
|
||||
"""
|
||||
# Prevent users from making themselves superuser
|
||||
if getattr(user_update, 'is_superuser', None) is not None:
|
||||
logger.warning(f"User {current_user.id} attempted to modify is_superuser field")
|
||||
raise AuthorizationError(
|
||||
message="Cannot modify superuser status",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
)
|
||||
|
||||
try:
|
||||
updated_user = user_crud.update(
|
||||
db,
|
||||
db_obj=current_user,
|
||||
obj_in=user_update
|
||||
)
|
||||
logger.info(f"User {current_user.id} updated their profile")
|
||||
return updated_user
|
||||
except ValueError as e:
|
||||
logger.error(f"Error updating user {current_user.id}: {str(e)}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error updating user {current_user.id}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{user_id}",
|
||||
response_model=UserResponse,
|
||||
summary="Get User by ID",
|
||||
description="""
|
||||
Get a specific user by their ID.
|
||||
|
||||
**Authentication**: Required (Bearer token)
|
||||
**Authorization**:
|
||||
- Regular users: Can only access their own profile
|
||||
- Superusers: Can access any profile
|
||||
|
||||
**Rate Limit**: 60 requests/minute
|
||||
""",
|
||||
operation_id="get_user_by_id"
|
||||
)
|
||||
def get_user_by_id(
|
||||
user_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Get user by ID.
|
||||
|
||||
Users can only view their own profile unless they are superusers.
|
||||
"""
|
||||
# Check permissions
|
||||
if str(user_id) != str(current_user.id) and not current_user.is_superuser:
|
||||
logger.warning(
|
||||
f"User {current_user.id} attempted to access user {user_id} without permission"
|
||||
)
|
||||
raise AuthorizationError(
|
||||
message="Not enough permissions to view this user",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
)
|
||||
|
||||
# Get user
|
||||
user = user_crud.get(db, id=str(user_id))
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User with id {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/{user_id}",
|
||||
response_model=UserResponse,
|
||||
summary="Update User",
|
||||
description="""
|
||||
Update a specific user by their ID.
|
||||
|
||||
**Authentication**: Required (Bearer token)
|
||||
**Authorization**:
|
||||
- Regular users: Can only update their own profile (except is_superuser)
|
||||
- Superusers: Can update any profile
|
||||
|
||||
**Rate Limit**: 30 requests/minute
|
||||
""",
|
||||
operation_id="update_user"
|
||||
)
|
||||
def update_user(
|
||||
user_id: UUID,
|
||||
user_update: UserUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Update user by ID.
|
||||
|
||||
Users can update their own profile. Superusers can update any profile.
|
||||
Regular users cannot modify is_superuser field.
|
||||
"""
|
||||
# Check permissions
|
||||
is_own_profile = str(user_id) == str(current_user.id)
|
||||
|
||||
if not is_own_profile and not current_user.is_superuser:
|
||||
logger.warning(
|
||||
f"User {current_user.id} attempted to update user {user_id} without permission"
|
||||
)
|
||||
raise AuthorizationError(
|
||||
message="Not enough permissions to update this user",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
)
|
||||
|
||||
# Get user
|
||||
user = user_crud.get(db, id=str(user_id))
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User with id {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
|
||||
# Prevent non-superusers from modifying superuser status
|
||||
if getattr(user_update, 'is_superuser', None) is not None and not current_user.is_superuser:
|
||||
logger.warning(f"User {current_user.id} attempted to modify is_superuser field")
|
||||
raise AuthorizationError(
|
||||
message="Cannot modify superuser status",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
)
|
||||
|
||||
try:
|
||||
updated_user = user_crud.update(db, db_obj=user, obj_in=user_update)
|
||||
logger.info(f"User {user_id} updated by {current_user.id}")
|
||||
return updated_user
|
||||
except ValueError as e:
|
||||
logger.error(f"Error updating user {user_id}: {str(e)}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error updating user {user_id}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/me/password",
|
||||
response_model=MessageResponse,
|
||||
summary="Change Current User Password",
|
||||
description="""
|
||||
Change the current authenticated user's password.
|
||||
|
||||
Requires the current password for verification.
|
||||
|
||||
**Authentication**: Required (Bearer token)
|
||||
|
||||
**Rate Limit**: 5 requests/minute
|
||||
""",
|
||||
operation_id="change_current_user_password"
|
||||
)
|
||||
@limiter.limit("5/minute")
|
||||
def change_current_user_password(
|
||||
request: Request,
|
||||
password_change: PasswordChange,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Change current user's password.
|
||||
|
||||
Requires current password for verification.
|
||||
"""
|
||||
try:
|
||||
success = AuthService.change_password(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
current_password=password_change.current_password,
|
||||
new_password=password_change.new_password
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"User {current_user.id} changed their password")
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="Password changed successfully"
|
||||
)
|
||||
except AuthenticationError as e:
|
||||
logger.warning(f"Failed password change attempt for user {current_user.id}: {str(e)}")
|
||||
raise AuthorizationError(
|
||||
message=str(e),
|
||||
error_code=ErrorCode.INVALID_CREDENTIALS
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error changing password for user {current_user.id}: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{user_id}",
|
||||
status_code=status.HTTP_200_OK,
|
||||
response_model=MessageResponse,
|
||||
summary="Delete User",
|
||||
description="""
|
||||
Delete a specific user by their ID.
|
||||
|
||||
**Authentication**: Required (Bearer token)
|
||||
**Authorization**: Superuser only
|
||||
|
||||
**Rate Limit**: 10 requests/minute
|
||||
|
||||
**Note**: This performs a hard delete. Consider implementing soft deletes for production.
|
||||
""",
|
||||
operation_id="delete_user"
|
||||
)
|
||||
def delete_user(
|
||||
user_id: UUID,
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Delete user by ID (superuser only).
|
||||
|
||||
This is a hard delete operation.
|
||||
"""
|
||||
# Prevent self-deletion
|
||||
if str(user_id) == str(current_user.id):
|
||||
raise AuthorizationError(
|
||||
message="Cannot delete your own account",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
)
|
||||
|
||||
# Get user
|
||||
user = user_crud.get(db, id=str(user_id))
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User with id {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
|
||||
try:
|
||||
# Use soft delete instead of hard delete
|
||||
user_crud.soft_delete(db, id=str(user_id))
|
||||
logger.info(f"User {user_id} soft-deleted by {current_user.id}")
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"User {user_id} deleted successfully"
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Error deleting user {user_id}: {str(e)}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error deleting user {user_id}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
@@ -1,12 +1,20 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
from typing import Optional, List
|
||||
from pydantic import Field, field_validator
|
||||
import logging
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
PROJECT_NAME: str = "EventSpace"
|
||||
PROJECT_NAME: str = "App"
|
||||
VERSION: str = "1.0.0"
|
||||
API_V1_STR: str = "/api/v1"
|
||||
|
||||
# Environment (must be before SECRET_KEY for validation)
|
||||
ENVIRONMENT: str = Field(
|
||||
default="development",
|
||||
description="Environment: development, staging, or production"
|
||||
)
|
||||
|
||||
# Database configuration
|
||||
POSTGRES_USER: str = "postgres"
|
||||
POSTGRES_PASSWORD: str = "postgres"
|
||||
@@ -14,7 +22,6 @@ class Settings(BaseSettings):
|
||||
POSTGRES_PORT: str = "5432"
|
||||
POSTGRES_DB: str = "app"
|
||||
DATABASE_URL: Optional[str] = None
|
||||
REFRESH_TOKEN_EXPIRE_DAYS: int = 60
|
||||
db_pool_size: int = 20 # Default connection pool size
|
||||
db_max_overflow: int = 50 # Maximum overflow connections
|
||||
db_pool_timeout: int = 30 # Seconds to wait for a connection
|
||||
@@ -39,21 +46,90 @@ class Settings(BaseSettings):
|
||||
return self.DATABASE_URL
|
||||
|
||||
# JWT configuration
|
||||
SECRET_KEY: str = "your_secret_key_here"
|
||||
SECRET_KEY: str = Field(
|
||||
default="dev_only_insecure_key_change_in_production_32chars_min",
|
||||
min_length=32,
|
||||
description="JWT signing key. MUST be changed in production. Generate with: python -c 'import secrets; print(secrets.token_urlsafe(32))'"
|
||||
)
|
||||
ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 1440 # 1 day
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 15 # 15 minutes (production standard)
|
||||
REFRESH_TOKEN_EXPIRE_DAYS: int = 7 # 7 days
|
||||
|
||||
# CORS configuration
|
||||
BACKEND_CORS_ORIGINS: List[str] = ["http://localhost:3000"]
|
||||
|
||||
# Admin user
|
||||
FIRST_SUPERUSER_EMAIL: Optional[str] = None
|
||||
FIRST_SUPERUSER_PASSWORD: Optional[str] = None
|
||||
FIRST_SUPERUSER_EMAIL: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Email for first superuser account"
|
||||
)
|
||||
FIRST_SUPERUSER_PASSWORD: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Password for first superuser (min 12 characters)"
|
||||
)
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
case_sensitive = True
|
||||
@field_validator('SECRET_KEY')
|
||||
@classmethod
|
||||
def validate_secret_key(cls, v: str, info) -> str:
|
||||
"""Validate SECRET_KEY is secure, especially in production."""
|
||||
# Get environment from values if available
|
||||
values_data = info.data if info.data else {}
|
||||
env = values_data.get('ENVIRONMENT', 'development')
|
||||
|
||||
if v.startswith("your_secret_key_here"):
|
||||
if env == "production":
|
||||
raise ValueError(
|
||||
"SECRET_KEY must be set to a secure random value in production. "
|
||||
"Generate one with: python -c 'import secrets; print(secrets.token_urlsafe(32))'"
|
||||
)
|
||||
# Warn in development but allow
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(
|
||||
"⚠️ Using default SECRET_KEY. This is ONLY acceptable in development. "
|
||||
"Generate a secure key with: python -c 'import secrets; print(secrets.token_urlsafe(32))'"
|
||||
)
|
||||
|
||||
if len(v) < 32:
|
||||
raise ValueError("SECRET_KEY must be at least 32 characters long for security")
|
||||
|
||||
return v
|
||||
|
||||
@field_validator('FIRST_SUPERUSER_PASSWORD')
|
||||
@classmethod
|
||||
def validate_superuser_password(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""Validate superuser password strength."""
|
||||
if v is None:
|
||||
return v
|
||||
|
||||
if len(v) < 12:
|
||||
raise ValueError("FIRST_SUPERUSER_PASSWORD must be at least 12 characters")
|
||||
|
||||
# Check for common weak passwords
|
||||
weak_passwords = {'admin123', 'Admin123', 'password123', 'Password123', '123456789012'}
|
||||
if v in weak_passwords:
|
||||
raise ValueError(
|
||||
"FIRST_SUPERUSER_PASSWORD is too weak. "
|
||||
"Use a strong, unique password with mixed case, numbers, and symbols."
|
||||
)
|
||||
|
||||
# Basic strength check
|
||||
has_lower = any(c.islower() for c in v)
|
||||
has_upper = any(c.isupper() for c in v)
|
||||
has_digit = any(c.isdigit() for c in v)
|
||||
|
||||
if not (has_lower and has_upper and has_digit):
|
||||
raise ValueError(
|
||||
"FIRST_SUPERUSER_PASSWORD must contain lowercase, uppercase, and digits"
|
||||
)
|
||||
|
||||
return v
|
||||
|
||||
model_config = {
|
||||
"env_file": "../.env",
|
||||
"env_file_encoding": "utf-8",
|
||||
"case_sensitive": True,
|
||||
"extra": "ignore" # Ignore extra fields from .env (e.g., frontend-specific vars)
|
||||
}
|
||||
|
||||
|
||||
settings = Settings()
|
||||
@@ -1,8 +1,10 @@
|
||||
# app/core/database.py
|
||||
import logging
|
||||
from sqlalchemy import create_engine
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
|
||||
@@ -49,12 +51,62 @@ def create_production_engine():
|
||||
|
||||
# Default production engine and session factory
|
||||
engine = create_production_engine()
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
SessionLocal = sessionmaker(
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
bind=engine,
|
||||
expire_on_commit=False # Prevent unnecessary queries after commit
|
||||
)
|
||||
|
||||
# FastAPI dependency
|
||||
def get_db():
|
||||
def get_db() -> Generator[Session, None, None]:
|
||||
"""
|
||||
FastAPI dependency that provides a database session.
|
||||
Automatically closes the session after the request completes.
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def transaction_scope() -> Generator[Session, None, None]:
|
||||
"""
|
||||
Provide a transactional scope for database operations.
|
||||
|
||||
Automatically commits on success or rolls back on exception.
|
||||
Useful for grouping multiple operations in a single transaction.
|
||||
|
||||
Usage:
|
||||
with transaction_scope() as db:
|
||||
user = user_crud.create(db, obj_in=user_create)
|
||||
profile = profile_crud.create(db, obj_in=profile_create)
|
||||
# Both operations committed together
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
db.commit()
|
||||
logger.debug("Transaction committed successfully")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Transaction failed, rolling back: {str(e)}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def check_database_health() -> bool:
|
||||
"""
|
||||
Check if database connection is healthy.
|
||||
Returns True if connection is successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
with transaction_scope() as db:
|
||||
db.execute(text("SELECT 1"))
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Database health check failed: {str(e)}")
|
||||
return False
|
||||
182
backend/app/core/database_async.py
Normal file
182
backend/app/core/database_async.py
Normal file
@@ -0,0 +1,182 @@
|
||||
# app/core/database_async.py
|
||||
"""
|
||||
Async database configuration using SQLAlchemy 2.0 and asyncpg.
|
||||
|
||||
This module provides async database connectivity with proper connection pooling
|
||||
and session management for FastAPI endpoints.
|
||||
"""
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncSession,
|
||||
AsyncEngine,
|
||||
create_async_engine,
|
||||
async_sessionmaker,
|
||||
)
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# SQLite compatibility for testing
|
||||
@compiles(JSONB, 'sqlite')
|
||||
def compile_jsonb_sqlite(type_, compiler, **kw):
|
||||
return "TEXT"
|
||||
|
||||
|
||||
@compiles(UUID, 'sqlite')
|
||||
def compile_uuid_sqlite(type_, compiler, **kw):
|
||||
return "TEXT"
|
||||
|
||||
|
||||
# Declarative base for models (SQLAlchemy 2.0 style)
|
||||
class Base(DeclarativeBase):
|
||||
"""Base class for all database models."""
|
||||
pass
|
||||
|
||||
|
||||
def get_async_database_url(url: str) -> str:
|
||||
"""
|
||||
Convert sync database URL to async URL.
|
||||
|
||||
postgresql:// -> postgresql+asyncpg://
|
||||
sqlite:// -> sqlite+aiosqlite://
|
||||
"""
|
||||
if url.startswith("postgresql://"):
|
||||
return url.replace("postgresql://", "postgresql+asyncpg://")
|
||||
elif url.startswith("sqlite://"):
|
||||
return url.replace("sqlite://", "sqlite+aiosqlite://")
|
||||
return url
|
||||
|
||||
|
||||
# Create async engine with optimized settings
|
||||
def create_async_production_engine() -> AsyncEngine:
|
||||
"""Create an async database engine with production settings."""
|
||||
async_url = get_async_database_url(settings.database_url)
|
||||
|
||||
# Base engine config
|
||||
engine_config = {
|
||||
"pool_size": settings.db_pool_size,
|
||||
"max_overflow": settings.db_max_overflow,
|
||||
"pool_timeout": settings.db_pool_timeout,
|
||||
"pool_recycle": settings.db_pool_recycle,
|
||||
"pool_pre_ping": True,
|
||||
"echo": settings.sql_echo,
|
||||
"echo_pool": settings.sql_echo_pool,
|
||||
}
|
||||
|
||||
# Add PostgreSQL-specific connect_args
|
||||
if "postgresql" in async_url:
|
||||
engine_config["connect_args"] = {
|
||||
"server_settings": {
|
||||
"application_name": "eventspace",
|
||||
"timezone": "UTC",
|
||||
},
|
||||
# asyncpg-specific settings
|
||||
"command_timeout": 60,
|
||||
"timeout": 10,
|
||||
}
|
||||
|
||||
return create_async_engine(async_url, **engine_config)
|
||||
|
||||
|
||||
# Create async engine and session factory
|
||||
async_engine = create_async_production_engine()
|
||||
AsyncSessionLocal = async_sessionmaker(
|
||||
async_engine,
|
||||
class_=AsyncSession,
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
expire_on_commit=False, # Prevent unnecessary queries after commit
|
||||
)
|
||||
|
||||
|
||||
# FastAPI dependency for async database sessions
|
||||
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""
|
||||
FastAPI dependency that provides an async database session.
|
||||
Automatically closes the session after the request completes.
|
||||
|
||||
Usage:
|
||||
@router.get("/users")
|
||||
async def get_users(db: AsyncSession = Depends(get_async_db)):
|
||||
result = await db.execute(select(User))
|
||||
return result.scalars().all()
|
||||
"""
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""
|
||||
Provide an async transactional scope for database operations.
|
||||
|
||||
Automatically commits on success or rolls back on exception.
|
||||
Useful for grouping multiple operations in a single transaction.
|
||||
|
||||
Usage:
|
||||
async with async_transaction_scope() as db:
|
||||
user = await user_crud.create(db, obj_in=user_create)
|
||||
profile = await profile_crud.create(db, obj_in=profile_create)
|
||||
# Both operations committed together
|
||||
"""
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
logger.debug("Async transaction committed successfully")
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error(f"Async transaction failed, rolling back: {str(e)}")
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
async def check_async_database_health() -> bool:
|
||||
"""
|
||||
Check if async database connection is healthy.
|
||||
Returns True if connection is successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
async with async_transaction_scope() as db:
|
||||
await db.execute(text("SELECT 1"))
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Async database health check failed: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def init_async_db() -> None:
|
||||
"""
|
||||
Initialize async database tables.
|
||||
|
||||
This creates all tables defined in the models.
|
||||
Should only be used in development or testing.
|
||||
In production, use Alembic migrations.
|
||||
"""
|
||||
async with async_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
logger.info("Async database tables created")
|
||||
|
||||
|
||||
async def close_async_db() -> None:
|
||||
"""
|
||||
Close all async database connections.
|
||||
|
||||
Should be called during application shutdown.
|
||||
"""
|
||||
await async_engine.dispose()
|
||||
logger.info("Async database connections closed")
|
||||
281
backend/app/core/exceptions.py
Normal file
281
backend/app/core/exceptions.py
Normal file
@@ -0,0 +1,281 @@
|
||||
"""
|
||||
Custom exceptions and global exception handlers for the API.
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional, Union, List
|
||||
from fastapi import HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.schemas.errors import ErrorCode, ErrorDetail, ErrorResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class APIException(HTTPException):
|
||||
"""
|
||||
Base exception class with error code support.
|
||||
|
||||
This exception provides a standardized way to raise HTTP exceptions
|
||||
with machine-readable error codes.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int,
|
||||
error_code: ErrorCode,
|
||||
message: str,
|
||||
field: Optional[str] = None,
|
||||
headers: Optional[dict] = None
|
||||
):
|
||||
self.error_code = error_code
|
||||
self.field = field
|
||||
self.message = message
|
||||
super().__init__(
|
||||
status_code=status_code,
|
||||
detail=message,
|
||||
headers=headers
|
||||
)
|
||||
|
||||
|
||||
class AuthenticationError(APIException):
|
||||
"""Raised when authentication fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Authentication failed",
|
||||
error_code: ErrorCode = ErrorCode.INVALID_CREDENTIALS,
|
||||
field: Optional[str] = None
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
error_code=error_code,
|
||||
message=message,
|
||||
field=field,
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
|
||||
|
||||
class AuthorizationError(APIException):
|
||||
"""Raised when user lacks required permissions."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Insufficient permissions",
|
||||
error_code: ErrorCode = ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
error_code=error_code,
|
||||
message=message
|
||||
)
|
||||
|
||||
|
||||
class NotFoundError(APIException):
|
||||
"""Raised when a resource is not found."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Resource not found",
|
||||
error_code: ErrorCode = ErrorCode.NOT_FOUND
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
error_code=error_code,
|
||||
message=message
|
||||
)
|
||||
|
||||
|
||||
class DuplicateError(APIException):
|
||||
"""Raised when attempting to create a duplicate resource."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Resource already exists",
|
||||
error_code: ErrorCode = ErrorCode.DUPLICATE_ENTRY,
|
||||
field: Optional[str] = None
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
error_code=error_code,
|
||||
message=message,
|
||||
field=field
|
||||
)
|
||||
|
||||
|
||||
class ValidationException(APIException):
|
||||
"""Raised when input validation fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Validation error",
|
||||
error_code: ErrorCode = ErrorCode.VALIDATION_ERROR,
|
||||
field: Optional[str] = None
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
error_code=error_code,
|
||||
message=message,
|
||||
field=field
|
||||
)
|
||||
|
||||
|
||||
class DatabaseError(APIException):
|
||||
"""Raised when a database operation fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Database operation failed",
|
||||
error_code: ErrorCode = ErrorCode.DATABASE_ERROR
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
error_code=error_code,
|
||||
message=message
|
||||
)
|
||||
|
||||
|
||||
# Global exception handlers
|
||||
|
||||
|
||||
async def api_exception_handler(request: Request, exc: APIException) -> JSONResponse:
|
||||
"""
|
||||
Handler for APIException and its subclasses.
|
||||
|
||||
Returns a standardized error response with error code and message.
|
||||
"""
|
||||
logger.warning(
|
||||
f"API exception: {exc.error_code} - {exc.message} "
|
||||
f"(status: {exc.status_code}, path: {request.url.path})"
|
||||
)
|
||||
|
||||
error_response = ErrorResponse(
|
||||
errors=[ErrorDetail(
|
||||
code=exc.error_code,
|
||||
message=exc.message,
|
||||
field=exc.field
|
||||
)]
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=error_response.model_dump(),
|
||||
headers=exc.headers
|
||||
)
|
||||
|
||||
|
||||
async def validation_exception_handler(
|
||||
request: Request,
|
||||
exc: Union[RequestValidationError, ValidationError]
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
Handler for Pydantic validation errors.
|
||||
|
||||
Converts Pydantic validation errors to standardized error response format.
|
||||
"""
|
||||
errors = []
|
||||
|
||||
if isinstance(exc, RequestValidationError):
|
||||
validation_errors = exc.errors()
|
||||
else:
|
||||
validation_errors = exc.errors()
|
||||
|
||||
for error in validation_errors:
|
||||
# Extract field name from error location
|
||||
field = None
|
||||
if error.get("loc") and len(error["loc"]) > 1:
|
||||
# Skip 'body' or 'query' prefix in location
|
||||
field = ".".join(str(x) for x in error["loc"][1:])
|
||||
|
||||
errors.append(ErrorDetail(
|
||||
code=ErrorCode.VALIDATION_ERROR,
|
||||
message=error["msg"],
|
||||
field=field
|
||||
))
|
||||
|
||||
logger.warning(
|
||||
f"Validation error: {len(errors)} errors "
|
||||
f"(path: {request.url.path})"
|
||||
)
|
||||
|
||||
error_response = ErrorResponse(errors=errors)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
content=error_response.model_dump()
|
||||
)
|
||||
|
||||
|
||||
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
|
||||
"""
|
||||
Handler for standard HTTPException.
|
||||
|
||||
Converts standard FastAPI HTTPException to standardized error response format.
|
||||
"""
|
||||
# Map status codes to error codes
|
||||
status_code_to_error_code = {
|
||||
400: ErrorCode.INVALID_INPUT,
|
||||
401: ErrorCode.AUTHENTICATION_REQUIRED,
|
||||
403: ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
404: ErrorCode.NOT_FOUND,
|
||||
405: ErrorCode.METHOD_NOT_ALLOWED,
|
||||
429: ErrorCode.RATE_LIMIT_EXCEEDED,
|
||||
500: ErrorCode.INTERNAL_ERROR,
|
||||
}
|
||||
|
||||
error_code = status_code_to_error_code.get(
|
||||
exc.status_code,
|
||||
ErrorCode.INTERNAL_ERROR
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
f"HTTP exception: {exc.status_code} - {exc.detail} "
|
||||
f"(path: {request.url.path})"
|
||||
)
|
||||
|
||||
error_response = ErrorResponse(
|
||||
errors=[ErrorDetail(
|
||||
code=error_code,
|
||||
message=str(exc.detail)
|
||||
)]
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=error_response.model_dump(),
|
||||
headers=exc.headers
|
||||
)
|
||||
|
||||
|
||||
async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
|
||||
"""
|
||||
Handler for unhandled exceptions.
|
||||
|
||||
Logs the full exception and returns a generic error response to avoid
|
||||
leaking sensitive information in production.
|
||||
"""
|
||||
logger.error(
|
||||
f"Unhandled exception: {type(exc).__name__} - {str(exc)} "
|
||||
f"(path: {request.url.path})",
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
# In production, don't expose internal error details
|
||||
from app.core.config import settings
|
||||
if settings.ENVIRONMENT == "production":
|
||||
message = "An internal error occurred. Please try again later."
|
||||
else:
|
||||
message = f"{type(exc).__name__}: {str(exc)}"
|
||||
|
||||
error_response = ErrorResponse(
|
||||
errors=[ErrorDetail(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=message
|
||||
)]
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content=error_response.model_dump()
|
||||
)
|
||||
@@ -1,8 +1,15 @@
|
||||
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union
|
||||
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple
|
||||
from datetime import datetime, timezone
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
|
||||
from sqlalchemy import func, asc, desc
|
||||
from app.core.database import Base
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=Base)
|
||||
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
|
||||
@@ -20,20 +27,66 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
self.model = model
|
||||
|
||||
def get(self, db: Session, id: str) -> Optional[ModelType]:
|
||||
return db.query(self.model).filter(self.model.id == id).first()
|
||||
"""Get a single record by ID with UUID validation."""
|
||||
# Validate UUID format and convert to UUID object if string
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
uuid_obj = id
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning(f"Invalid UUID format: {id} - {str(e)}")
|
||||
return None
|
||||
|
||||
try:
|
||||
return db.query(self.model).filter(self.model.id == uuid_obj).first()
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_multi(
|
||||
self, db: Session, *, skip: int = 0, limit: int = 100
|
||||
) -> List[ModelType]:
|
||||
return db.query(self.model).offset(skip).limit(limit).all()
|
||||
"""Get multiple records with pagination validation."""
|
||||
# Validate pagination parameters
|
||||
if skip < 0:
|
||||
raise ValueError("skip must be non-negative")
|
||||
if limit < 0:
|
||||
raise ValueError("limit must be non-negative")
|
||||
if limit > 1000:
|
||||
raise ValueError("Maximum limit is 1000")
|
||||
|
||||
try:
|
||||
return db.query(self.model).offset(skip).limit(limit).all()
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving multiple {self.model.__name__} records: {str(e)}")
|
||||
raise
|
||||
|
||||
def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
|
||||
obj_in_data = jsonable_encoder(obj_in)
|
||||
db_obj = self.model(**obj_in_data)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
"""Create a new record with error handling."""
|
||||
try:
|
||||
obj_in_data = jsonable_encoder(obj_in)
|
||||
db_obj = self.model(**obj_in_data)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
||||
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"A {self.model.__name__} with this data already exists")
|
||||
logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except (OperationalError, DataError) as e:
|
||||
db.rollback()
|
||||
logger.error(f"Database error creating {self.model.__name__}: {str(e)}")
|
||||
raise ValueError(f"Database operation failed: {str(e)}")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Unexpected error creating {self.model.__name__}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def update(
|
||||
self,
|
||||
@@ -42,21 +95,210 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
db_obj: ModelType,
|
||||
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
|
||||
) -> ModelType:
|
||||
obj_data = jsonable_encoder(db_obj)
|
||||
if isinstance(obj_in, dict):
|
||||
update_data = obj_in
|
||||
else:
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
for field in obj_data:
|
||||
if field in update_data:
|
||||
setattr(db_obj, field, update_data[field])
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
"""Update a record with error handling."""
|
||||
try:
|
||||
obj_data = jsonable_encoder(db_obj)
|
||||
if isinstance(obj_in, dict):
|
||||
update_data = obj_in
|
||||
else:
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
for field in obj_data:
|
||||
if field in update_data:
|
||||
setattr(db_obj, field, update_data[field])
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
||||
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"A {self.model.__name__} with this data already exists")
|
||||
logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except (OperationalError, DataError) as e:
|
||||
db.rollback()
|
||||
logger.error(f"Database error updating {self.model.__name__}: {str(e)}")
|
||||
raise ValueError(f"Database operation failed: {str(e)}")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Unexpected error updating {self.model.__name__}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def remove(self, db: Session, *, id: str) -> ModelType:
|
||||
obj = db.query(self.model).get(id)
|
||||
db.delete(obj)
|
||||
db.commit()
|
||||
return obj
|
||||
def remove(self, db: Session, *, id: str) -> Optional[ModelType]:
|
||||
"""Delete a record with error handling and null check."""
|
||||
# Validate UUID format and convert to UUID object if string
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
uuid_obj = id
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning(f"Invalid UUID format for deletion: {id} - {str(e)}")
|
||||
return None
|
||||
|
||||
try:
|
||||
obj = db.query(self.model).filter(self.model.id == uuid_obj).first()
|
||||
if obj is None:
|
||||
logger.warning(f"{self.model.__name__} with id {id} not found for deletion")
|
||||
return None
|
||||
|
||||
db.delete(obj)
|
||||
db.commit()
|
||||
return obj
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"Cannot delete {self.model.__name__}: referenced by other records")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def get_multi_with_total(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
sort_by: Optional[str] = None,
|
||||
sort_order: str = "asc",
|
||||
filters: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[List[ModelType], int]:
|
||||
"""
|
||||
Get multiple records with total count, filtering, and sorting.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
skip: Number of records to skip
|
||||
limit: Maximum number of records to return
|
||||
sort_by: Field name to sort by (must be a valid model attribute)
|
||||
sort_order: Sort order ("asc" or "desc")
|
||||
filters: Dictionary of filters (field_name: value)
|
||||
|
||||
Returns:
|
||||
Tuple of (items, total_count)
|
||||
"""
|
||||
# Validate pagination parameters
|
||||
if skip < 0:
|
||||
raise ValueError("skip must be non-negative")
|
||||
if limit < 0:
|
||||
raise ValueError("limit must be non-negative")
|
||||
if limit > 1000:
|
||||
raise ValueError("Maximum limit is 1000")
|
||||
|
||||
try:
|
||||
# Build base query
|
||||
query = db.query(self.model)
|
||||
|
||||
# Exclude soft-deleted records by default
|
||||
if hasattr(self.model, 'deleted_at'):
|
||||
query = query.filter(self.model.deleted_at.is_(None))
|
||||
|
||||
# Apply filters
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
if hasattr(self.model, field) and value is not None:
|
||||
query = query.filter(getattr(self.model, field) == value)
|
||||
|
||||
# Get total count (before pagination)
|
||||
total = query.count()
|
||||
|
||||
# Apply sorting
|
||||
if sort_by and hasattr(self.model, sort_by):
|
||||
sort_column = getattr(self.model, sort_by)
|
||||
if sort_order.lower() == "desc":
|
||||
query = query.order_by(desc(sort_column))
|
||||
else:
|
||||
query = query.order_by(asc(sort_column))
|
||||
|
||||
# Apply pagination
|
||||
items = query.offset(skip).limit(limit).all()
|
||||
|
||||
return items, total
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}")
|
||||
raise
|
||||
|
||||
def soft_delete(self, db: Session, *, id: str) -> Optional[ModelType]:
|
||||
"""
|
||||
Soft delete a record by setting deleted_at timestamp.
|
||||
|
||||
Only works if the model has a 'deleted_at' column.
|
||||
"""
|
||||
# Validate UUID format and convert to UUID object if string
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
uuid_obj = id
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning(f"Invalid UUID format for soft deletion: {id} - {str(e)}")
|
||||
return None
|
||||
|
||||
try:
|
||||
obj = db.query(self.model).filter(self.model.id == uuid_obj).first()
|
||||
|
||||
if obj is None:
|
||||
logger.warning(f"{self.model.__name__} with id {id} not found for soft deletion")
|
||||
return None
|
||||
|
||||
# Check if model supports soft deletes
|
||||
if not hasattr(self.model, 'deleted_at'):
|
||||
logger.error(f"{self.model.__name__} does not support soft deletes")
|
||||
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
|
||||
|
||||
# Set deleted_at timestamp
|
||||
obj.deleted_at = datetime.now(timezone.utc)
|
||||
db.add(obj)
|
||||
db.commit()
|
||||
db.refresh(obj)
|
||||
return obj
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error soft deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def restore(self, db: Session, *, id: str) -> Optional[ModelType]:
|
||||
"""
|
||||
Restore a soft-deleted record by clearing the deleted_at timestamp.
|
||||
|
||||
Only works if the model has a 'deleted_at' column.
|
||||
"""
|
||||
# Validate UUID format
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
uuid_obj = id
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning(f"Invalid UUID format for restoration: {id} - {str(e)}")
|
||||
return None
|
||||
|
||||
try:
|
||||
# Find the soft-deleted record
|
||||
if hasattr(self.model, 'deleted_at'):
|
||||
obj = db.query(self.model).filter(
|
||||
self.model.id == uuid_obj,
|
||||
self.model.deleted_at.isnot(None)
|
||||
).first()
|
||||
else:
|
||||
logger.error(f"{self.model.__name__} does not support soft deletes")
|
||||
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
|
||||
|
||||
if obj is None:
|
||||
logger.warning(f"Soft-deleted {self.model.__name__} with id {id} not found for restoration")
|
||||
return None
|
||||
|
||||
# Clear deleted_at timestamp
|
||||
obj.deleted_at = None
|
||||
db.add(obj)
|
||||
db.commit()
|
||||
db.refresh(obj)
|
||||
return obj
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error restoring {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
228
backend/app/crud/base_async.py
Normal file
228
backend/app/crud/base_async.py
Normal file
@@ -0,0 +1,228 @@
|
||||
# app/crud/base_async.py
|
||||
"""
|
||||
Async CRUD operations base class using SQLAlchemy 2.0 async patterns.
|
||||
|
||||
Provides reusable create, read, update, and delete operations for all models.
|
||||
"""
|
||||
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
|
||||
|
||||
from app.core.database_async import Base
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=Base)
|
||||
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
|
||||
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
|
||||
|
||||
|
||||
class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
"""Async CRUD operations for a model."""
|
||||
|
||||
def __init__(self, model: Type[ModelType]):
|
||||
"""
|
||||
CRUD object with default async methods to Create, Read, Update, Delete.
|
||||
|
||||
Parameters:
|
||||
model: A SQLAlchemy model class
|
||||
"""
|
||||
self.model = model
|
||||
|
||||
async def get(self, db: AsyncSession, id: str) -> Optional[ModelType]:
|
||||
"""Get a single record by ID with UUID validation."""
|
||||
# Validate UUID format and convert to UUID object if string
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
uuid_obj = id
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning(f"Invalid UUID format: {id} - {str(e)}")
|
||||
return None
|
||||
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(self.model).where(self.model.id == uuid_obj)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_multi(
|
||||
self, db: AsyncSession, *, skip: int = 0, limit: int = 100
|
||||
) -> List[ModelType]:
|
||||
"""Get multiple records with pagination validation."""
|
||||
# Validate pagination parameters
|
||||
if skip < 0:
|
||||
raise ValueError("skip must be non-negative")
|
||||
if limit < 0:
|
||||
raise ValueError("limit must be non-negative")
|
||||
if limit > 1000:
|
||||
raise ValueError("Maximum limit is 1000")
|
||||
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(self.model).offset(skip).limit(limit)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving multiple {self.model.__name__} records: {str(e)}")
|
||||
raise
|
||||
|
||||
async def create(self, db: AsyncSession, *, obj_in: CreateSchemaType) -> ModelType:
|
||||
"""Create a new record with error handling."""
|
||||
try:
|
||||
obj_in_data = jsonable_encoder(obj_in)
|
||||
db_obj = self.model(**obj_in_data)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
||||
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"A {self.model.__name__} with this data already exists")
|
||||
logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except (OperationalError, DataError) as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Database error creating {self.model.__name__}: {str(e)}")
|
||||
raise ValueError(f"Database operation failed: {str(e)}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error creating {self.model.__name__}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def update(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
db_obj: ModelType,
|
||||
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
|
||||
) -> ModelType:
|
||||
"""Update a record with error handling."""
|
||||
try:
|
||||
obj_data = jsonable_encoder(db_obj)
|
||||
if isinstance(obj_in, dict):
|
||||
update_data = obj_in
|
||||
else:
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
|
||||
for field in obj_data:
|
||||
if field in update_data:
|
||||
setattr(db_obj, field, update_data[field])
|
||||
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
||||
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"A {self.model.__name__} with this data already exists")
|
||||
logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except (OperationalError, DataError) as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Database error updating {self.model.__name__}: {str(e)}")
|
||||
raise ValueError(f"Database operation failed: {str(e)}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error updating {self.model.__name__}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def remove(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
|
||||
"""Delete a record with error handling and null check."""
|
||||
# Validate UUID format and convert to UUID object if string
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
uuid_obj = id
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning(f"Invalid UUID format for deletion: {id} - {str(e)}")
|
||||
return None
|
||||
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(self.model).where(self.model.id == uuid_obj)
|
||||
)
|
||||
obj = result.scalar_one_or_none()
|
||||
|
||||
if obj is None:
|
||||
logger.warning(f"{self.model.__name__} with id {id} not found for deletion")
|
||||
return None
|
||||
|
||||
await db.delete(obj)
|
||||
await db.commit()
|
||||
return obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"Cannot delete {self.model.__name__}: referenced by other records")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_multi_with_total(
|
||||
self, db: AsyncSession, *, skip: int = 0, limit: int = 100
|
||||
) -> Tuple[List[ModelType], int]:
|
||||
"""
|
||||
Get multiple records with total count for pagination.
|
||||
|
||||
Returns:
|
||||
Tuple of (items, total_count)
|
||||
"""
|
||||
# Validate pagination parameters
|
||||
if skip < 0:
|
||||
raise ValueError("skip must be non-negative")
|
||||
if limit < 0:
|
||||
raise ValueError("limit must be non-negative")
|
||||
if limit > 1000:
|
||||
raise ValueError("Maximum limit is 1000")
|
||||
|
||||
try:
|
||||
# Get total count
|
||||
count_result = await db.execute(
|
||||
select(func.count(self.model.id))
|
||||
)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Get paginated items
|
||||
items_result = await db.execute(
|
||||
select(self.model).offset(skip).limit(limit)
|
||||
)
|
||||
items = list(items_result.scalars().all())
|
||||
|
||||
return items, total
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}")
|
||||
raise
|
||||
|
||||
async def count(self, db: AsyncSession) -> int:
|
||||
"""Get total count of records."""
|
||||
try:
|
||||
result = await db.execute(select(func.count(self.model.id)))
|
||||
return result.scalar_one()
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting {self.model.__name__} records: {str(e)}")
|
||||
raise
|
||||
|
||||
async def exists(self, db: AsyncSession, id: str) -> bool:
|
||||
"""Check if a record exists by ID."""
|
||||
obj = await self.get(db, id=id)
|
||||
return obj is not None
|
||||
@@ -1,10 +1,14 @@
|
||||
# app/crud/user.py
|
||||
from typing import Optional, Union, Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
from app.core.auth import get_password_hash
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
@@ -12,19 +16,33 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
return db.query(User).filter(User.email == email).first()
|
||||
|
||||
def create(self, db: Session, *, obj_in: UserCreate) -> User:
|
||||
db_obj = User(
|
||||
email=obj_in.email,
|
||||
password_hash=get_password_hash(obj_in.password),
|
||||
first_name=obj_in.first_name,
|
||||
last_name=obj_in.last_name,
|
||||
phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None,
|
||||
is_superuser=obj_in.is_superuser if hasattr(obj_in, 'is_superuser') else False,
|
||||
preferences={}
|
||||
)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
"""Create a new user with password hashing and error handling."""
|
||||
try:
|
||||
db_obj = User(
|
||||
email=obj_in.email,
|
||||
password_hash=get_password_hash(obj_in.password),
|
||||
first_name=obj_in.first_name,
|
||||
last_name=obj_in.last_name,
|
||||
phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None,
|
||||
is_superuser=obj_in.is_superuser if hasattr(obj_in, 'is_superuser') else False,
|
||||
preferences={}
|
||||
)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
if "email" in error_msg.lower():
|
||||
logger.warning(f"Duplicate email attempted: {obj_in.email}")
|
||||
raise ValueError(f"User with email {obj_in.email} already exists")
|
||||
logger.error(f"Integrity error creating user: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Unexpected error creating user: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def update(
|
||||
self,
|
||||
|
||||
@@ -19,7 +19,7 @@ def init_db(db: Session) -> Optional[UserCreate]:
|
||||
"""
|
||||
# Use default values if not set in environment variables
|
||||
superuser_email = settings.FIRST_SUPERUSER_EMAIL or "admin@example.com"
|
||||
superuser_password = settings.FIRST_SUPERUSER_PASSWORD or "admin123"
|
||||
superuser_password = settings.FIRST_SUPERUSER_PASSWORD or "Admin123!Change"
|
||||
|
||||
if not settings.FIRST_SUPERUSER_EMAIL or not settings.FIRST_SUPERUSER_PASSWORD:
|
||||
logger.warning(
|
||||
|
||||
@@ -1,17 +1,35 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from fastapi import FastAPI
|
||||
from fastapi import FastAPI, status, Request, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.responses import HTMLResponse, JSONResponse
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||
from slowapi.util import get_remote_address
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.api.main import api_router
|
||||
from app.core.config import settings
|
||||
from app.core.database import get_db, check_database_health
|
||||
from app.core.exceptions import (
|
||||
APIException,
|
||||
api_exception_handler,
|
||||
validation_exception_handler,
|
||||
http_exception_handler,
|
||||
unhandled_exception_handler
|
||||
)
|
||||
|
||||
scheduler = AsyncIOScheduler()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize rate limiter
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
logger.info(f"Starting app!!!")
|
||||
app = FastAPI(
|
||||
title=settings.PROJECT_NAME,
|
||||
@@ -19,16 +37,68 @@ app = FastAPI(
|
||||
openapi_url=f"{settings.API_V1_STR}/openapi.json"
|
||||
)
|
||||
|
||||
# Set up CORS middleware
|
||||
# Add rate limiter state to app
|
||||
app.state.limiter = limiter
|
||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||
|
||||
# Register custom exception handlers (order matters - most specific first)
|
||||
app.add_exception_handler(APIException, api_exception_handler)
|
||||
app.add_exception_handler(RequestValidationError, validation_exception_handler)
|
||||
app.add_exception_handler(HTTPException, http_exception_handler)
|
||||
app.add_exception_handler(Exception, unhandled_exception_handler)
|
||||
|
||||
# Set up CORS middleware with explicit allowed methods and headers
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.BACKEND_CORS_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], # Explicit methods only
|
||||
allow_headers=[
|
||||
"Content-Type",
|
||||
"Authorization",
|
||||
"Accept",
|
||||
"Origin",
|
||||
"User-Agent",
|
||||
"DNT",
|
||||
"Cache-Control",
|
||||
"X-Requested-With",
|
||||
], # Explicit headers only
|
||||
expose_headers=["Content-Length"],
|
||||
max_age=600, # Cache preflight requests for 10 minutes
|
||||
)
|
||||
|
||||
|
||||
# Add security headers middleware
|
||||
@app.middleware("http")
|
||||
async def add_security_headers(request: Request, call_next):
|
||||
"""Add security headers to all responses"""
|
||||
response = await call_next(request)
|
||||
|
||||
# Prevent clickjacking
|
||||
response.headers["X-Frame-Options"] = "DENY"
|
||||
|
||||
# Prevent MIME type sniffing
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
|
||||
# Enable XSS protection
|
||||
response.headers["X-XSS-Protection"] = "1; mode=block"
|
||||
|
||||
# Enforce HTTPS in production
|
||||
if settings.ENVIRONMENT == "production":
|
||||
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
|
||||
|
||||
# Content Security Policy
|
||||
response.headers["Content-Security-Policy"] = "default-src 'self'; frame-ancestors 'none'"
|
||||
|
||||
# Permissions Policy (formerly Feature Policy)
|
||||
response.headers["Permissions-Policy"] = "geolocation=(), microphone=(), camera=()"
|
||||
|
||||
# Referrer Policy
|
||||
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
async def root():
|
||||
return """
|
||||
@@ -45,4 +115,59 @@ async def root():
|
||||
"""
|
||||
|
||||
|
||||
@app.get(
|
||||
"/health",
|
||||
summary="Health Check",
|
||||
description="Check the health status of the API and its dependencies",
|
||||
response_description="Health status information",
|
||||
tags=["Health"],
|
||||
operation_id="health_check"
|
||||
)
|
||||
async def health_check() -> JSONResponse:
|
||||
"""
|
||||
Health check endpoint for monitoring and load balancers.
|
||||
|
||||
Returns:
|
||||
JSONResponse: Health status with the following information:
|
||||
- status: Overall health status ("healthy" or "unhealthy")
|
||||
- timestamp: Current server timestamp (ISO 8601 format)
|
||||
- version: API version
|
||||
- environment: Current environment (development, staging, production)
|
||||
- database: Database connectivity status
|
||||
"""
|
||||
health_status: Dict[str, Any] = {
|
||||
"status": "healthy",
|
||||
"timestamp": datetime.utcnow().isoformat() + "Z",
|
||||
"version": settings.VERSION,
|
||||
"environment": settings.ENVIRONMENT,
|
||||
"checks": {}
|
||||
}
|
||||
|
||||
response_status = status.HTTP_200_OK
|
||||
|
||||
# Database health check using dedicated health check function
|
||||
try:
|
||||
db_healthy = check_database_health()
|
||||
if db_healthy:
|
||||
health_status["checks"]["database"] = {
|
||||
"status": "healthy",
|
||||
"message": "Database connection successful"
|
||||
}
|
||||
else:
|
||||
raise Exception("Database health check returned unhealthy status")
|
||||
except Exception as e:
|
||||
health_status["status"] = "unhealthy"
|
||||
health_status["checks"]["database"] = {
|
||||
"status": "unhealthy",
|
||||
"message": f"Database connection failed: {str(e)}"
|
||||
}
|
||||
response_status = status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
logger.error(f"Health check failed - database error: {e}")
|
||||
|
||||
return JSONResponse(
|
||||
status_code=response_status,
|
||||
content=health_status
|
||||
)
|
||||
|
||||
|
||||
app.include_router(api_router, prefix=settings.API_V1_STR)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from sqlalchemy import Column, String, JSON, Boolean
|
||||
from sqlalchemy import Column, String, Boolean, DateTime
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
|
||||
from .base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
@@ -6,14 +7,15 @@ from .base import Base, TimestampMixin, UUIDMixin
|
||||
class User(Base, UUIDMixin, TimestampMixin):
|
||||
__tablename__ = 'users'
|
||||
|
||||
email = Column(String, unique=True, nullable=False, index=True)
|
||||
password_hash = Column(String, nullable=False)
|
||||
first_name = Column(String, nullable=False, default="user")
|
||||
last_name = Column(String, nullable=True)
|
||||
phone_number = Column(String)
|
||||
is_active = Column(Boolean, default=True, nullable=False)
|
||||
is_superuser = Column(Boolean, default=False, nullable=False)
|
||||
preferences = Column(JSON)
|
||||
email = Column(String(255), unique=True, nullable=False, index=True)
|
||||
password_hash = Column(String(255), nullable=False)
|
||||
first_name = Column(String(100), nullable=False, default="user")
|
||||
last_name = Column(String(100), nullable=True)
|
||||
phone_number = Column(String(20))
|
||||
is_active = Column(Boolean, default=True, nullable=False, index=True)
|
||||
is_superuser = Column(Boolean, default=False, nullable=False, index=True)
|
||||
preferences = Column(JSONB)
|
||||
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<User {self.email}>"
|
||||
168
backend/app/schemas/common.py
Normal file
168
backend/app/schemas/common.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""
|
||||
Common schemas used across the API for pagination, responses, filtering, and sorting.
|
||||
"""
|
||||
from typing import Generic, TypeVar, List, Optional
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, Field
|
||||
from math import ceil
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
class SortOrder(str, Enum):
|
||||
"""Sort order options."""
|
||||
ASC = "asc"
|
||||
DESC = "desc"
|
||||
|
||||
|
||||
class PaginationParams(BaseModel):
|
||||
"""Parameters for pagination."""
|
||||
|
||||
page: int = Field(
|
||||
default=1,
|
||||
ge=1,
|
||||
description="Page number (1-indexed)"
|
||||
)
|
||||
limit: int = Field(
|
||||
default=20,
|
||||
ge=1,
|
||||
le=100,
|
||||
description="Number of items per page (max 100)"
|
||||
)
|
||||
|
||||
@property
|
||||
def offset(self) -> int:
|
||||
"""Calculate the offset for database queries."""
|
||||
return (self.page - 1) * self.limit
|
||||
|
||||
@property
|
||||
def skip(self) -> int:
|
||||
"""Alias for offset (compatibility with existing code)."""
|
||||
return self.offset
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"page": 1,
|
||||
"limit": 20
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class SortParams(BaseModel):
|
||||
"""Parameters for sorting."""
|
||||
|
||||
sort_by: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Field name to sort by"
|
||||
)
|
||||
sort_order: SortOrder = Field(
|
||||
default=SortOrder.ASC,
|
||||
description="Sort order (asc or desc)"
|
||||
)
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"sort_by": "created_at",
|
||||
"sort_order": "desc"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class PaginationMeta(BaseModel):
|
||||
"""Metadata for paginated responses."""
|
||||
|
||||
total: int = Field(..., description="Total number of items")
|
||||
page: int = Field(..., description="Current page number")
|
||||
page_size: int = Field(..., description="Number of items in current page")
|
||||
total_pages: int = Field(..., description="Total number of pages")
|
||||
has_next: bool = Field(..., description="Whether there is a next page")
|
||||
has_prev: bool = Field(..., description="Whether there is a previous page")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"total": 150,
|
||||
"page": 1,
|
||||
"page_size": 20,
|
||||
"total_pages": 8,
|
||||
"has_next": True,
|
||||
"has_prev": False
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class PaginatedResponse(BaseModel, Generic[T]):
|
||||
"""Generic paginated response wrapper."""
|
||||
|
||||
data: List[T] = Field(..., description="List of items")
|
||||
pagination: PaginationMeta = Field(..., description="Pagination metadata")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"data": [
|
||||
{"id": "123", "name": "Example Item"}
|
||||
],
|
||||
"pagination": {
|
||||
"total": 150,
|
||||
"page": 1,
|
||||
"page_size": 20,
|
||||
"total_pages": 8,
|
||||
"has_next": True,
|
||||
"has_prev": False
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
"""Simple message response."""
|
||||
|
||||
success: bool = Field(default=True, description="Operation success status")
|
||||
message: str = Field(..., description="Human-readable message")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"success": True,
|
||||
"message": "Operation completed successfully"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def create_pagination_meta(
|
||||
total: int,
|
||||
page: int,
|
||||
limit: int,
|
||||
items_count: int
|
||||
) -> PaginationMeta:
|
||||
"""
|
||||
Helper function to create pagination metadata.
|
||||
|
||||
Args:
|
||||
total: Total number of items
|
||||
page: Current page number
|
||||
limit: Items per page
|
||||
items_count: Number of items in current page
|
||||
|
||||
Returns:
|
||||
PaginationMeta object with calculated values
|
||||
"""
|
||||
total_pages = ceil(total / limit) if limit > 0 else 0
|
||||
|
||||
return PaginationMeta(
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=items_count,
|
||||
total_pages=total_pages,
|
||||
has_next=page < total_pages,
|
||||
has_prev=page > 1
|
||||
)
|
||||
85
backend/app/schemas/errors.py
Normal file
85
backend/app/schemas/errors.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""
|
||||
Error schemas for standardized API error responses.
|
||||
"""
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ErrorCode(str, Enum):
|
||||
"""Standard error codes for the API."""
|
||||
|
||||
# Authentication errors (AUTH_xxx)
|
||||
INVALID_CREDENTIALS = "AUTH_001"
|
||||
TOKEN_EXPIRED = "AUTH_002"
|
||||
TOKEN_INVALID = "AUTH_003"
|
||||
INSUFFICIENT_PERMISSIONS = "AUTH_004"
|
||||
USER_INACTIVE = "AUTH_005"
|
||||
AUTHENTICATION_REQUIRED = "AUTH_006"
|
||||
|
||||
# User errors (USER_xxx)
|
||||
USER_NOT_FOUND = "USER_001"
|
||||
USER_ALREADY_EXISTS = "USER_002"
|
||||
USER_CREATION_FAILED = "USER_003"
|
||||
USER_UPDATE_FAILED = "USER_004"
|
||||
USER_DELETION_FAILED = "USER_005"
|
||||
|
||||
# Validation errors (VAL_xxx)
|
||||
VALIDATION_ERROR = "VAL_001"
|
||||
INVALID_PASSWORD = "VAL_002"
|
||||
INVALID_EMAIL = "VAL_003"
|
||||
INVALID_PHONE_NUMBER = "VAL_004"
|
||||
INVALID_UUID = "VAL_005"
|
||||
INVALID_INPUT = "VAL_006"
|
||||
|
||||
# Database errors (DB_xxx)
|
||||
DATABASE_ERROR = "DB_001"
|
||||
DUPLICATE_ENTRY = "DB_002"
|
||||
FOREIGN_KEY_VIOLATION = "DB_003"
|
||||
RECORD_NOT_FOUND = "DB_004"
|
||||
|
||||
# Generic errors (SYS_xxx)
|
||||
INTERNAL_ERROR = "SYS_001"
|
||||
NOT_FOUND = "SYS_002"
|
||||
METHOD_NOT_ALLOWED = "SYS_003"
|
||||
RATE_LIMIT_EXCEEDED = "SYS_004"
|
||||
|
||||
|
||||
class ErrorDetail(BaseModel):
|
||||
"""Detailed information about a single error."""
|
||||
|
||||
code: ErrorCode = Field(..., description="Machine-readable error code")
|
||||
message: str = Field(..., description="Human-readable error message")
|
||||
field: Optional[str] = Field(None, description="Field name if error is field-specific")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"code": "VAL_002",
|
||||
"message": "Password must be at least 8 characters long",
|
||||
"field": "password"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
"""Standardized error response format."""
|
||||
|
||||
success: bool = Field(default=False, description="Always false for error responses")
|
||||
errors: List[ErrorDetail] = Field(..., description="List of errors that occurred")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"success": False,
|
||||
"errors": [
|
||||
{
|
||||
"code": "AUTH_001",
|
||||
"message": "Invalid email or password",
|
||||
"field": None
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -123,7 +123,26 @@ class TokenData(BaseModel):
|
||||
is_superuser: bool = False
|
||||
|
||||
|
||||
class PasswordChange(BaseModel):
|
||||
"""Schema for changing password (requires current password)."""
|
||||
current_password: str
|
||||
new_password: str
|
||||
|
||||
@field_validator('new_password')
|
||||
@classmethod
|
||||
def password_strength(cls, v: str) -> str:
|
||||
"""Basic password strength validation"""
|
||||
if len(v) < 8:
|
||||
raise ValueError('Password must be at least 8 characters')
|
||||
if not any(char.isdigit() for char in v):
|
||||
raise ValueError('Password must contain at least one digit')
|
||||
if not any(char.isupper() for char in v):
|
||||
raise ValueError('Password must contain at least one uppercase letter')
|
||||
return v
|
||||
|
||||
|
||||
class PasswordReset(BaseModel):
|
||||
"""Schema for resetting password (via email token)."""
|
||||
token: str
|
||||
new_password: str
|
||||
|
||||
|
||||
1
backend/coverage.json
Normal file
1
backend/coverage.json
Normal file
File diff suppressed because one or more lines are too long
@@ -12,10 +12,8 @@ alembic>=1.14.1
|
||||
psycopg2-binary>=2.9.9
|
||||
asyncpg>=0.29.0
|
||||
aiosqlite==0.21.0
|
||||
# Security and authentication
|
||||
python-jose>=3.4.0
|
||||
passlib>=1.7.4
|
||||
bcrypt>=4.1.2
|
||||
|
||||
# Environment configuration
|
||||
python-dotenv>=1.0.1
|
||||
|
||||
# API documentation
|
||||
@@ -26,6 +24,9 @@ ujson>=5.9.0
|
||||
starlette>=0.40.0
|
||||
starlette-csrf>=1.4.5
|
||||
|
||||
# Rate limiting
|
||||
slowapi>=0.1.9
|
||||
|
||||
# Utilities
|
||||
httpx>=0.27.0
|
||||
tenacity>=8.2.3
|
||||
@@ -44,9 +45,11 @@ isort>=5.13.2
|
||||
flake8>=7.0.0
|
||||
mypy>=1.8.0
|
||||
|
||||
# Security
|
||||
# Security and authentication (pinned for reproducibility)
|
||||
python-jose==3.4.0
|
||||
passlib==1.7.4
|
||||
bcrypt==4.2.1
|
||||
cryptography==44.0.1
|
||||
passlib==1.7.4
|
||||
|
||||
# Testing utilities
|
||||
freezegun~=1.5.1
|
||||
@@ -10,6 +10,7 @@ from fastapi.testclient import TestClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.routes.auth import router as auth_router
|
||||
from app.api.routes.users import router as users_router
|
||||
from app.core.auth import get_password_hash
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
@@ -29,6 +30,7 @@ def app(override_get_db):
|
||||
"""Create a FastAPI test application with overridden dependencies."""
|
||||
app = FastAPI()
|
||||
app.include_router(auth_router, prefix="/auth", tags=["auth"])
|
||||
app.include_router(users_router, prefix="/api/v1/users", tags=["users"])
|
||||
|
||||
# Override the get_db dependency
|
||||
app.dependency_overrides[get_db] = lambda: override_get_db
|
||||
@@ -280,9 +282,9 @@ class TestChangePassword:
|
||||
|
||||
# Mock password change to return success
|
||||
with patch.object(AuthService, 'change_password', return_value=True):
|
||||
# Test request
|
||||
response = client.post(
|
||||
"/auth/change-password",
|
||||
# Test request (new endpoint)
|
||||
response = client.patch(
|
||||
"/api/v1/users/me/password",
|
||||
json={
|
||||
"current_password": "OldPassword123",
|
||||
"new_password": "NewPassword123"
|
||||
@@ -291,7 +293,8 @@ class TestChangePassword:
|
||||
|
||||
# Assertions
|
||||
assert response.status_code == 200
|
||||
assert "success" in response.json()["message"].lower()
|
||||
assert response.json()["success"] is True
|
||||
assert "message" in response.json()
|
||||
|
||||
# Clean up override
|
||||
if get_current_user in app.dependency_overrides:
|
||||
@@ -312,18 +315,20 @@ class TestChangePassword:
|
||||
# Mock password change to raise error
|
||||
with patch.object(AuthService, 'change_password',
|
||||
side_effect=AuthenticationError("Current password is incorrect")):
|
||||
# Test request
|
||||
response = client.post(
|
||||
"/auth/change-password",
|
||||
# Test request (new endpoint)
|
||||
response = client.patch(
|
||||
"/api/v1/users/me/password",
|
||||
json={
|
||||
"current_password": "WrongPassword",
|
||||
"new_password": "NewPassword123"
|
||||
}
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert response.status_code == 400
|
||||
assert "incorrect" in response.json()["detail"].lower()
|
||||
# Assertions - Now returns standardized error response
|
||||
assert response.status_code == 403
|
||||
# The response has standardized error format
|
||||
data = response.json()
|
||||
assert "detail" in data or "errors" in data
|
||||
|
||||
# Clean up override
|
||||
if get_current_user in app.dependency_overrides:
|
||||
|
||||
184
backend/tests/api/routes/test_health.py
Normal file
184
backend/tests/api/routes/test_health.py
Normal file
@@ -0,0 +1,184 @@
|
||||
# tests/api/routes/test_health.py
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from fastapi import status
|
||||
from fastapi.testclient import TestClient
|
||||
from datetime import datetime
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
from app.main import app
|
||||
from app.core.database import get_db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Create a FastAPI test client for the main app with mocked database."""
|
||||
# Mock check_database_health to avoid connecting to the actual database
|
||||
with patch("app.main.check_database_health") as mock_health_check:
|
||||
# By default, return True (healthy)
|
||||
mock_health_check.return_value = True
|
||||
yield TestClient(app)
|
||||
|
||||
|
||||
class TestHealthEndpoint:
|
||||
"""Tests for the /health endpoint"""
|
||||
|
||||
def test_health_check_healthy(self, client):
|
||||
"""Test that health check returns healthy when database is accessible"""
|
||||
response = client.get("/health")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
# Check required fields
|
||||
assert "status" in data
|
||||
assert data["status"] == "healthy"
|
||||
assert "timestamp" in data
|
||||
assert "version" in data
|
||||
assert "environment" in data
|
||||
assert "checks" in data
|
||||
|
||||
# Verify timestamp format (ISO 8601)
|
||||
assert data["timestamp"].endswith("Z")
|
||||
# Verify it's a valid datetime
|
||||
datetime.fromisoformat(data["timestamp"].replace("Z", "+00:00"))
|
||||
|
||||
# Check database health
|
||||
assert "database" in data["checks"]
|
||||
assert data["checks"]["database"]["status"] == "healthy"
|
||||
assert "message" in data["checks"]["database"]
|
||||
|
||||
def test_health_check_response_structure(self, client):
|
||||
"""Test that health check response has correct structure"""
|
||||
response = client.get("/health")
|
||||
data = response.json()
|
||||
|
||||
# Verify top-level structure
|
||||
assert isinstance(data["status"], str)
|
||||
assert isinstance(data["timestamp"], str)
|
||||
assert isinstance(data["version"], str)
|
||||
assert isinstance(data["environment"], str)
|
||||
assert isinstance(data["checks"], dict)
|
||||
|
||||
# Verify database check structure
|
||||
db_check = data["checks"]["database"]
|
||||
assert isinstance(db_check["status"], str)
|
||||
assert isinstance(db_check["message"], str)
|
||||
|
||||
def test_health_check_version_matches_settings(self, client):
|
||||
"""Test that health check returns correct version from settings"""
|
||||
from app.core.config import settings
|
||||
|
||||
response = client.get("/health")
|
||||
data = response.json()
|
||||
|
||||
assert data["version"] == settings.VERSION
|
||||
|
||||
def test_health_check_environment_matches_settings(self, client):
|
||||
"""Test that health check returns correct environment from settings"""
|
||||
from app.core.config import settings
|
||||
|
||||
response = client.get("/health")
|
||||
data = response.json()
|
||||
|
||||
assert data["environment"] == settings.ENVIRONMENT
|
||||
|
||||
def test_health_check_database_connection_failure(self):
|
||||
"""Test that health check returns unhealthy when database is not accessible"""
|
||||
# Mock check_database_health to return False (unhealthy)
|
||||
with patch("app.main.check_database_health") as mock_health_check:
|
||||
mock_health_check.return_value = False
|
||||
|
||||
test_client = TestClient(app)
|
||||
response = test_client.get("/health")
|
||||
|
||||
assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
data = response.json()
|
||||
|
||||
# Check overall status
|
||||
assert data["status"] == "unhealthy"
|
||||
|
||||
# Check database status
|
||||
assert "database" in data["checks"]
|
||||
assert data["checks"]["database"]["status"] == "unhealthy"
|
||||
assert "failed" in data["checks"]["database"]["message"].lower()
|
||||
|
||||
def test_health_check_timestamp_recent(self, client):
|
||||
"""Test that health check timestamp is recent (within last minute)"""
|
||||
before = datetime.utcnow()
|
||||
response = client.get("/health")
|
||||
after = datetime.utcnow()
|
||||
|
||||
data = response.json()
|
||||
timestamp = datetime.fromisoformat(data["timestamp"].replace("Z", "+00:00"))
|
||||
|
||||
# Timestamp should be between before and after
|
||||
assert before <= timestamp.replace(tzinfo=None) <= after
|
||||
|
||||
def test_health_check_no_authentication_required(self, client):
|
||||
"""Test that health check does not require authentication"""
|
||||
# Make request without any authentication headers
|
||||
response = client.get("/health")
|
||||
|
||||
# Should succeed without authentication
|
||||
assert response.status_code in [status.HTTP_200_OK, status.HTTP_503_SERVICE_UNAVAILABLE]
|
||||
|
||||
def test_health_check_idempotent(self, client):
|
||||
"""Test that multiple health checks return consistent results"""
|
||||
response1 = client.get("/health")
|
||||
response2 = client.get("/health")
|
||||
|
||||
# Both should have same status code (either both healthy or both unhealthy)
|
||||
assert response1.status_code == response2.status_code
|
||||
|
||||
data1 = response1.json()
|
||||
data2 = response2.json()
|
||||
|
||||
# Same overall health status
|
||||
assert data1["status"] == data2["status"]
|
||||
|
||||
# Same version and environment
|
||||
assert data1["version"] == data2["version"]
|
||||
assert data1["environment"] == data2["environment"]
|
||||
|
||||
# Same database check status
|
||||
assert data1["checks"]["database"]["status"] == data2["checks"]["database"]["status"]
|
||||
|
||||
def test_health_check_content_type(self, client):
|
||||
"""Test that health check returns JSON content type"""
|
||||
response = client.get("/health")
|
||||
|
||||
assert "application/json" in response.headers["content-type"]
|
||||
|
||||
|
||||
class TestHealthEndpointEdgeCases:
|
||||
"""Edge case tests for the /health endpoint"""
|
||||
|
||||
def test_health_check_with_query_parameters(self, client):
|
||||
"""Test that health check ignores query parameters"""
|
||||
response = client.get("/health?foo=bar&baz=qux")
|
||||
|
||||
# Should still work with query params
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
def test_health_check_method_not_allowed(self, client):
|
||||
"""Test that POST/PUT/DELETE are not allowed on health endpoint"""
|
||||
# POST should not be allowed
|
||||
response = client.post("/health")
|
||||
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
|
||||
|
||||
# PUT should not be allowed
|
||||
response = client.put("/health")
|
||||
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
|
||||
|
||||
# DELETE should not be allowed
|
||||
response = client.delete("/health")
|
||||
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
|
||||
|
||||
def test_health_check_with_accept_header(self, client):
|
||||
"""Test that health check works with different Accept headers"""
|
||||
response = client.get("/health", headers={"Accept": "application/json"})
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
response = client.get("/health", headers={"Accept": "*/*"})
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
196
backend/tests/api/routes/test_rate_limiting.py
Normal file
196
backend/tests/api/routes/test_rate_limiting.py
Normal file
@@ -0,0 +1,196 @@
|
||||
# tests/api/routes/test_rate_limiting.py
|
||||
import pytest
|
||||
from fastapi import FastAPI, status
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from app.api.routes.auth import router as auth_router, limiter
|
||||
from app.api.routes.users import router as users_router
|
||||
from app.core.database import get_db
|
||||
|
||||
|
||||
# Mock the get_db dependency
|
||||
@pytest.fixture
|
||||
def override_get_db():
|
||||
"""Override get_db dependency for testing."""
|
||||
mock_db = MagicMock()
|
||||
return mock_db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(override_get_db):
|
||||
"""Create a FastAPI test application with rate limiting."""
|
||||
from slowapi import _rate_limit_exceeded_handler
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
|
||||
app = FastAPI()
|
||||
app.state.limiter = limiter
|
||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||
app.include_router(auth_router, prefix="/auth", tags=["auth"])
|
||||
app.include_router(users_router, prefix="/api/v1/users", tags=["users"])
|
||||
|
||||
# Override the get_db dependency
|
||||
app.dependency_overrides[get_db] = lambda: override_get_db
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
"""Create a FastAPI test client."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
class TestRegisterRateLimiting:
|
||||
"""Tests for rate limiting on /register endpoint"""
|
||||
|
||||
def test_register_rate_limit_blocks_over_limit(self, client):
|
||||
"""Test that requests over rate limit are blocked"""
|
||||
from app.services.auth_service import AuthService
|
||||
from app.models.user import User
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
|
||||
mock_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="test@example.com",
|
||||
password_hash="hashed",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
with patch.object(AuthService, 'create_user', return_value=mock_user):
|
||||
user_data = {
|
||||
"email": f"test{uuid.uuid4()}@example.com",
|
||||
"password": "TestPassword123!",
|
||||
"first_name": "Test",
|
||||
"last_name": "User"
|
||||
}
|
||||
|
||||
# Make 6 requests (limit is 5/minute)
|
||||
responses = []
|
||||
for i in range(6):
|
||||
response = client.post("/auth/register", json=user_data)
|
||||
responses.append(response)
|
||||
|
||||
# Last request should be rate limited
|
||||
assert responses[-1].status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||
|
||||
|
||||
class TestLoginRateLimiting:
|
||||
"""Tests for rate limiting on /login endpoint"""
|
||||
|
||||
def test_login_rate_limit_blocks_over_limit(self, client):
|
||||
"""Test that login requests over rate limit are blocked"""
|
||||
from app.services.auth_service import AuthService
|
||||
|
||||
with patch.object(AuthService, 'authenticate_user', return_value=None):
|
||||
login_data = {
|
||||
"email": "test@example.com",
|
||||
"password": "wrong_password"
|
||||
}
|
||||
|
||||
# Make 11 requests (limit is 10/minute)
|
||||
responses = []
|
||||
for i in range(11):
|
||||
response = client.post("/auth/login", json=login_data)
|
||||
responses.append(response)
|
||||
|
||||
# Last request should be rate limited
|
||||
assert responses[-1].status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||
|
||||
|
||||
class TestRefreshTokenRateLimiting:
|
||||
"""Tests for rate limiting on /refresh endpoint"""
|
||||
|
||||
def test_refresh_rate_limit_blocks_over_limit(self, client):
|
||||
"""Test that refresh requests over rate limit are blocked"""
|
||||
from app.services.auth_service import AuthService
|
||||
from app.core.auth import TokenInvalidError
|
||||
|
||||
with patch.object(AuthService, 'refresh_tokens', side_effect=TokenInvalidError("Invalid")):
|
||||
refresh_data = {
|
||||
"refresh_token": "invalid_token"
|
||||
}
|
||||
|
||||
# Make 31 requests (limit is 30/minute)
|
||||
responses = []
|
||||
for i in range(31):
|
||||
response = client.post("/auth/refresh", json=refresh_data)
|
||||
responses.append(response)
|
||||
|
||||
# Last request should be rate limited
|
||||
assert responses[-1].status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||
|
||||
|
||||
class TestChangePasswordRateLimiting:
|
||||
"""Tests for rate limiting on /change-password endpoint"""
|
||||
|
||||
def test_change_password_rate_limit_blocks_over_limit(self, client):
|
||||
"""Test that change password requests over rate limit are blocked"""
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.models.user import User
|
||||
from app.services.auth_service import AuthService, AuthenticationError
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
|
||||
# Mock current user
|
||||
mock_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="test@example.com",
|
||||
password_hash="hashed",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
# Override get_current_user dependency in the app
|
||||
test_app = client.app
|
||||
test_app.dependency_overrides[get_current_user] = lambda: mock_user
|
||||
|
||||
with patch.object(AuthService, 'change_password', side_effect=AuthenticationError("Invalid password")):
|
||||
password_data = {
|
||||
"current_password": "wrong_password",
|
||||
"new_password": "NewPassword123!"
|
||||
}
|
||||
|
||||
# Make 6 requests (limit is 5/minute) - using new endpoint
|
||||
responses = []
|
||||
for i in range(6):
|
||||
response = client.patch("/api/v1/users/me/password", json=password_data)
|
||||
responses.append(response)
|
||||
|
||||
# Last request should be rate limited
|
||||
assert responses[-1].status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||
|
||||
# Clean up override
|
||||
test_app.dependency_overrides.clear()
|
||||
|
||||
|
||||
class TestRateLimitErrorResponse:
|
||||
"""Tests for rate limit error response format"""
|
||||
|
||||
def test_rate_limit_error_response_format(self, client):
|
||||
"""Test that rate limit error has correct format"""
|
||||
from app.services.auth_service import AuthService
|
||||
|
||||
with patch.object(AuthService, 'authenticate_user', return_value=None):
|
||||
login_data = {
|
||||
"email": "test@example.com",
|
||||
"password": "password"
|
||||
}
|
||||
|
||||
# Exceed rate limit
|
||||
for i in range(11):
|
||||
response = client.post("/auth/login", json=login_data)
|
||||
|
||||
# Check error response
|
||||
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||
assert "detail" in response.json() or "error" in response.json()
|
||||
487
backend/tests/api/routes/test_users.py
Normal file
487
backend/tests/api/routes/test_users.py
Normal file
@@ -0,0 +1,487 @@
|
||||
# tests/api/routes/test_users.py
|
||||
"""
|
||||
Tests for user management endpoints.
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.routes.users import router as users_router
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
from app.api.dependencies.auth import get_current_user, get_current_superuser
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def override_get_db(db_session):
|
||||
"""Override get_db dependency for testing."""
|
||||
return db_session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(override_get_db):
|
||||
"""Create a FastAPI test application."""
|
||||
app = FastAPI()
|
||||
app.include_router(users_router, prefix="/api/v1/users", tags=["users"])
|
||||
|
||||
# Override the get_db dependency
|
||||
app.dependency_overrides[get_db] = lambda: override_get_db
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
"""Create a FastAPI test client."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def regular_user():
|
||||
"""Create a mock regular user."""
|
||||
return User(
|
||||
id=uuid.uuid4(),
|
||||
email="regular@example.com",
|
||||
password_hash="hashed_password",
|
||||
first_name="Regular",
|
||||
last_name="User",
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def super_user():
|
||||
"""Create a mock superuser."""
|
||||
return User(
|
||||
id=uuid.uuid4(),
|
||||
email="admin@example.com",
|
||||
password_hash="hashed_password",
|
||||
first_name="Admin",
|
||||
last_name="User",
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
|
||||
class TestListUsers:
|
||||
"""Tests for the list_users endpoint."""
|
||||
|
||||
def test_list_users_as_superuser(self, client, app, super_user, regular_user, db_session):
|
||||
"""Test that superusers can list all users."""
|
||||
from app.crud.user import user as user_crud
|
||||
|
||||
# Override auth dependency
|
||||
app.dependency_overrides[get_current_superuser] = lambda: super_user
|
||||
|
||||
# Mock user_crud to return test data
|
||||
mock_users = [regular_user for _ in range(3)]
|
||||
with patch.object(user_crud, 'get_multi_with_total', return_value=(mock_users, 3)):
|
||||
response = client.get("/api/v1/users?page=1&limit=20")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "data" in data
|
||||
assert "pagination" in data
|
||||
assert len(data["data"]) == 3
|
||||
assert data["pagination"]["total"] == 3
|
||||
|
||||
# Clean up
|
||||
if get_current_superuser in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_superuser]
|
||||
|
||||
def test_list_users_pagination(self, client, app, super_user, regular_user, db_session):
|
||||
"""Test pagination parameters for list users."""
|
||||
from app.crud.user import user as user_crud
|
||||
|
||||
app.dependency_overrides[get_current_superuser] = lambda: super_user
|
||||
|
||||
# Mock user_crud
|
||||
mock_users = [regular_user for _ in range(10)]
|
||||
with patch.object(user_crud, 'get_multi_with_total', return_value=(mock_users[:5], 10)):
|
||||
response = client.get("/api/v1/users?page=1&limit=5")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["pagination"]["page"] == 1
|
||||
assert data["pagination"]["page_size"] == 5
|
||||
assert data["pagination"]["total"] == 10
|
||||
assert data["pagination"]["total_pages"] == 2
|
||||
|
||||
# Clean up
|
||||
if get_current_superuser in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_superuser]
|
||||
|
||||
|
||||
class TestGetCurrentUserProfile:
|
||||
"""Tests for the get_current_user_profile endpoint."""
|
||||
|
||||
def test_get_current_user_profile(self, client, app, regular_user):
|
||||
"""Test getting current user's profile."""
|
||||
app.dependency_overrides[get_current_user] = lambda: regular_user
|
||||
|
||||
response = client.get("/api/v1/users/me")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["email"] == regular_user.email
|
||||
assert data["first_name"] == regular_user.first_name
|
||||
assert data["last_name"] == regular_user.last_name
|
||||
assert "password" not in data
|
||||
|
||||
# Clean up
|
||||
if get_current_user in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_user]
|
||||
|
||||
|
||||
class TestUpdateCurrentUser:
|
||||
"""Tests for the update_current_user endpoint."""
|
||||
|
||||
def test_update_current_user_success(self, client, app, regular_user, db_session):
|
||||
"""Test successful profile update."""
|
||||
from app.crud.user import user as user_crud
|
||||
|
||||
app.dependency_overrides[get_current_user] = lambda: regular_user
|
||||
|
||||
updated_user = User(
|
||||
id=regular_user.id,
|
||||
email=regular_user.email,
|
||||
password_hash=regular_user.password_hash,
|
||||
first_name="Updated",
|
||||
last_name="Name",
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
created_at=regular_user.created_at,
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
with patch.object(user_crud, 'update', return_value=updated_user):
|
||||
response = client.patch(
|
||||
"/api/v1/users/me",
|
||||
json={"first_name": "Updated", "last_name": "Name"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["first_name"] == "Updated"
|
||||
assert data["last_name"] == "Name"
|
||||
|
||||
# Clean up
|
||||
if get_current_user in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_user]
|
||||
|
||||
def test_update_current_user_extra_fields_ignored(self, client, app, regular_user, db_session):
|
||||
"""Test that extra fields like is_superuser are ignored by schema validation."""
|
||||
from app.crud.user import user as user_crud
|
||||
|
||||
app.dependency_overrides[get_current_user] = lambda: regular_user
|
||||
|
||||
# Create updated user without is_superuser changed
|
||||
updated_user = User(
|
||||
id=regular_user.id,
|
||||
email=regular_user.email,
|
||||
password_hash=regular_user.password_hash,
|
||||
first_name="Updated",
|
||||
last_name=regular_user.last_name,
|
||||
is_active=True,
|
||||
is_superuser=False, # Should remain False
|
||||
created_at=regular_user.created_at,
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
with patch.object(user_crud, 'update', return_value=updated_user):
|
||||
response = client.patch(
|
||||
"/api/v1/users/me",
|
||||
json={"first_name": "Updated", "is_superuser": True} # is_superuser will be ignored
|
||||
)
|
||||
|
||||
# Request should succeed but is_superuser should be unchanged
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["is_superuser"] is False
|
||||
|
||||
# Clean up
|
||||
if get_current_user in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_user]
|
||||
|
||||
|
||||
class TestGetUserById:
|
||||
"""Tests for the get_user_by_id endpoint."""
|
||||
|
||||
def test_get_own_profile(self, client, app, regular_user, db_session):
|
||||
"""Test that users can get their own profile."""
|
||||
from app.crud.user import user as user_crud
|
||||
|
||||
app.dependency_overrides[get_current_user] = lambda: regular_user
|
||||
|
||||
with patch.object(user_crud, 'get', return_value=regular_user):
|
||||
response = client.get(f"/api/v1/users/{regular_user.id}")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["email"] == regular_user.email
|
||||
|
||||
# Clean up
|
||||
if get_current_user in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_user]
|
||||
|
||||
def test_get_other_user_as_regular_user(self, client, app, regular_user):
|
||||
"""Test that regular users cannot view other users."""
|
||||
app.dependency_overrides[get_current_user] = lambda: regular_user
|
||||
|
||||
other_user_id = uuid.uuid4()
|
||||
response = client.get(f"/api/v1/users/{other_user_id}")
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
# Clean up
|
||||
if get_current_user in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_user]
|
||||
|
||||
def test_get_other_user_as_superuser(self, client, app, super_user, db_session):
|
||||
"""Test that superusers can view any user."""
|
||||
from app.crud.user import user as user_crud
|
||||
|
||||
app.dependency_overrides[get_current_user] = lambda: super_user
|
||||
|
||||
other_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="other@example.com",
|
||||
password_hash="hashed",
|
||||
first_name="Other",
|
||||
last_name="User",
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
with patch.object(user_crud, 'get', return_value=other_user):
|
||||
response = client.get(f"/api/v1/users/{other_user.id}")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["email"] == other_user.email
|
||||
|
||||
# Clean up
|
||||
if get_current_user in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_user]
|
||||
|
||||
def test_get_nonexistent_user(self, client, app, super_user, db_session):
|
||||
"""Test getting a user that doesn't exist."""
|
||||
from app.crud.user import user as user_crud
|
||||
|
||||
app.dependency_overrides[get_current_user] = lambda: super_user
|
||||
|
||||
with patch.object(user_crud, 'get', return_value=None):
|
||||
response = client.get(f"/api/v1/users/{uuid.uuid4()}")
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
# Clean up
|
||||
if get_current_user in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_user]
|
||||
|
||||
|
||||
class TestUpdateUser:
|
||||
"""Tests for the update_user endpoint."""
|
||||
|
||||
def test_update_own_profile(self, client, app, regular_user, db_session):
|
||||
"""Test that users can update their own profile."""
|
||||
from app.crud.user import user as user_crud
|
||||
|
||||
app.dependency_overrides[get_current_user] = lambda: regular_user
|
||||
|
||||
updated_user = User(
|
||||
id=regular_user.id,
|
||||
email=regular_user.email,
|
||||
password_hash=regular_user.password_hash,
|
||||
first_name="NewName",
|
||||
last_name=regular_user.last_name,
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
created_at=regular_user.created_at,
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
with patch.object(user_crud, 'get', return_value=regular_user), \
|
||||
patch.object(user_crud, 'update', return_value=updated_user):
|
||||
response = client.patch(
|
||||
f"/api/v1/users/{regular_user.id}",
|
||||
json={"first_name": "NewName"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["first_name"] == "NewName"
|
||||
|
||||
# Clean up
|
||||
if get_current_user in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_user]
|
||||
|
||||
def test_update_other_user_as_regular_user(self, client, app, regular_user):
|
||||
"""Test that regular users cannot update other users."""
|
||||
app.dependency_overrides[get_current_user] = lambda: regular_user
|
||||
|
||||
other_user_id = uuid.uuid4()
|
||||
response = client.patch(
|
||||
f"/api/v1/users/{other_user_id}",
|
||||
json={"first_name": "NewName"}
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
# Clean up
|
||||
if get_current_user in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_user]
|
||||
|
||||
def test_user_schema_ignores_extra_fields(self, client, app, regular_user, db_session):
|
||||
"""Test that UserUpdate schema ignores extra fields like is_superuser."""
|
||||
from app.crud.user import user as user_crud
|
||||
|
||||
app.dependency_overrides[get_current_user] = lambda: regular_user
|
||||
|
||||
# Updated user with is_superuser unchanged
|
||||
updated_user = User(
|
||||
id=regular_user.id,
|
||||
email=regular_user.email,
|
||||
password_hash=regular_user.password_hash,
|
||||
first_name="Changed",
|
||||
last_name=regular_user.last_name,
|
||||
is_active=True,
|
||||
is_superuser=False, # Should remain False
|
||||
created_at=regular_user.created_at,
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
with patch.object(user_crud, 'get', return_value=regular_user), \
|
||||
patch.object(user_crud, 'update', return_value=updated_user):
|
||||
response = client.patch(
|
||||
f"/api/v1/users/{regular_user.id}",
|
||||
json={"first_name": "Changed", "is_superuser": True} # is_superuser ignored
|
||||
)
|
||||
|
||||
# Should succeed, extra field is ignored
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["is_superuser"] is False
|
||||
|
||||
# Clean up
|
||||
if get_current_user in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_user]
|
||||
|
||||
def test_superuser_can_update_any_user(self, client, app, super_user, db_session):
|
||||
"""Test that superusers can update any user."""
|
||||
from app.crud.user import user as user_crud
|
||||
|
||||
app.dependency_overrides[get_current_user] = lambda: super_user
|
||||
|
||||
target_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="target@example.com",
|
||||
password_hash="hashed",
|
||||
first_name="Target",
|
||||
last_name="User",
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
updated_user = User(
|
||||
id=target_user.id,
|
||||
email=target_user.email,
|
||||
password_hash=target_user.password_hash,
|
||||
first_name="Updated",
|
||||
last_name=target_user.last_name,
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
created_at=target_user.created_at,
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
with patch.object(user_crud, 'get', return_value=target_user), \
|
||||
patch.object(user_crud, 'update', return_value=updated_user):
|
||||
response = client.patch(
|
||||
f"/api/v1/users/{target_user.id}",
|
||||
json={"first_name": "Updated"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["first_name"] == "Updated"
|
||||
|
||||
# Clean up
|
||||
if get_current_user in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_user]
|
||||
|
||||
|
||||
class TestDeleteUser:
|
||||
"""Tests for the delete_user endpoint."""
|
||||
|
||||
def test_delete_user_as_superuser(self, client, app, super_user, db_session):
|
||||
"""Test that superusers can delete users."""
|
||||
from app.crud.user import user as user_crud
|
||||
|
||||
app.dependency_overrides[get_current_superuser] = lambda: super_user
|
||||
|
||||
target_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="target@example.com",
|
||||
password_hash="hashed",
|
||||
first_name="Target",
|
||||
last_name="User",
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
with patch.object(user_crud, 'get', return_value=target_user), \
|
||||
patch.object(user_crud, 'remove', return_value=target_user):
|
||||
response = client.delete(f"/api/v1/users/{target_user.id}")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert "deleted successfully" in data["message"]
|
||||
|
||||
# Clean up
|
||||
if get_current_superuser in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_superuser]
|
||||
|
||||
def test_delete_nonexistent_user(self, client, app, super_user, db_session):
|
||||
"""Test deleting a user that doesn't exist."""
|
||||
from app.crud.user import user as user_crud
|
||||
|
||||
app.dependency_overrides[get_current_superuser] = lambda: super_user
|
||||
|
||||
with patch.object(user_crud, 'get', return_value=None):
|
||||
response = client.delete(f"/api/v1/users/{uuid.uuid4()}")
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
# Clean up
|
||||
if get_current_superuser in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_superuser]
|
||||
|
||||
def test_cannot_delete_self(self, client, app, super_user, db_session):
|
||||
"""Test that users cannot delete their own account."""
|
||||
app.dependency_overrides[get_current_superuser] = lambda: super_user
|
||||
|
||||
response = client.delete(f"/api/v1/users/{super_user.id}")
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
# Clean up
|
||||
if get_current_superuser in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_superuser]
|
||||
94
backend/tests/api/test_security_headers.py
Normal file
94
backend/tests/api/test_security_headers.py
Normal file
@@ -0,0 +1,94 @@
|
||||
# tests/api/test_security_headers.py
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.main import app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Create a FastAPI test client for the main app."""
|
||||
# Mock get_db to avoid database connection issues
|
||||
with patch("app.main.get_db") as mock_get_db:
|
||||
def mock_session_generator():
|
||||
from unittest.mock import MagicMock
|
||||
mock_session = MagicMock()
|
||||
mock_session.execute.return_value = None
|
||||
mock_session.close.return_value = None
|
||||
yield mock_session
|
||||
|
||||
mock_get_db.side_effect = lambda: mock_session_generator()
|
||||
yield TestClient(app)
|
||||
|
||||
|
||||
class TestSecurityHeaders:
|
||||
"""Tests for security headers middleware"""
|
||||
|
||||
def test_x_frame_options_header(self, client):
|
||||
"""Test that X-Frame-Options header is set to DENY"""
|
||||
response = client.get("/health")
|
||||
assert "X-Frame-Options" in response.headers
|
||||
assert response.headers["X-Frame-Options"] == "DENY"
|
||||
|
||||
def test_x_content_type_options_header(self, client):
|
||||
"""Test that X-Content-Type-Options header is set to nosniff"""
|
||||
response = client.get("/health")
|
||||
assert "X-Content-Type-Options" in response.headers
|
||||
assert response.headers["X-Content-Type-Options"] == "nosniff"
|
||||
|
||||
def test_x_xss_protection_header(self, client):
|
||||
"""Test that X-XSS-Protection header is set"""
|
||||
response = client.get("/health")
|
||||
assert "X-XSS-Protection" in response.headers
|
||||
assert response.headers["X-XSS-Protection"] == "1; mode=block"
|
||||
|
||||
def test_content_security_policy_header(self, client):
|
||||
"""Test that Content-Security-Policy header is set"""
|
||||
response = client.get("/health")
|
||||
assert "Content-Security-Policy" in response.headers
|
||||
assert "default-src 'self'" in response.headers["Content-Security-Policy"]
|
||||
assert "frame-ancestors 'none'" in response.headers["Content-Security-Policy"]
|
||||
|
||||
def test_permissions_policy_header(self, client):
|
||||
"""Test that Permissions-Policy header is set"""
|
||||
response = client.get("/health")
|
||||
assert "Permissions-Policy" in response.headers
|
||||
assert "geolocation=()" in response.headers["Permissions-Policy"]
|
||||
assert "microphone=()" in response.headers["Permissions-Policy"]
|
||||
assert "camera=()" in response.headers["Permissions-Policy"]
|
||||
|
||||
def test_referrer_policy_header(self, client):
|
||||
"""Test that Referrer-Policy header is set"""
|
||||
response = client.get("/health")
|
||||
assert "Referrer-Policy" in response.headers
|
||||
assert response.headers["Referrer-Policy"] == "strict-origin-when-cross-origin"
|
||||
|
||||
def test_strict_transport_security_not_in_development(self, client):
|
||||
"""Test that Strict-Transport-Security header is not set in development"""
|
||||
from app.core.config import settings
|
||||
|
||||
# In development, HSTS should not be present
|
||||
if settings.ENVIRONMENT == "development":
|
||||
response = client.get("/health")
|
||||
assert "Strict-Transport-Security" not in response.headers
|
||||
|
||||
def test_security_headers_on_all_endpoints(self, client):
|
||||
"""Test that security headers are present on all endpoints"""
|
||||
# Test health endpoint
|
||||
response = client.get("/health")
|
||||
assert "X-Frame-Options" in response.headers
|
||||
assert "X-Content-Type-Options" in response.headers
|
||||
|
||||
# Test root endpoint
|
||||
response = client.get("/")
|
||||
assert "X-Frame-Options" in response.headers
|
||||
assert "X-Content-Type-Options" in response.headers
|
||||
|
||||
def test_security_headers_on_404(self, client):
|
||||
"""Test that security headers are present even on 404 responses"""
|
||||
response = client.get("/nonexistent-endpoint")
|
||||
assert response.status_code == 404
|
||||
assert "X-Frame-Options" in response.headers
|
||||
assert "X-Content-Type-Options" in response.headers
|
||||
assert "X-XSS-Protection" in response.headers
|
||||
202
backend/tests/core/test_config.py
Normal file
202
backend/tests/core/test_config.py
Normal file
@@ -0,0 +1,202 @@
|
||||
# tests/core/test_config.py
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
from app.core.config import Settings
|
||||
|
||||
|
||||
class TestSecretKeyValidation:
|
||||
"""Tests for SECRET_KEY validation"""
|
||||
|
||||
def test_secret_key_too_short_raises_error(self):
|
||||
"""Test that SECRET_KEY shorter than 32 characters raises error"""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
Settings(SECRET_KEY="short_key", ENVIRONMENT="development")
|
||||
|
||||
# Pydantic Field's min_length validation triggers first
|
||||
assert "at least 32 characters" in str(exc_info.value)
|
||||
|
||||
def test_default_secret_key_in_production_raises_error(self):
|
||||
"""Test that default SECRET_KEY in production raises error"""
|
||||
# Use the exact default value (padded to 32 chars to pass length check)
|
||||
default_key = "your_secret_key_here" + "_" * 12 # Exactly 32 chars
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
Settings(SECRET_KEY=default_key, ENVIRONMENT="production")
|
||||
|
||||
assert "must be set to a secure random value in production" in str(exc_info.value)
|
||||
|
||||
def test_default_secret_key_in_development_allows_with_warning(self, caplog):
|
||||
"""Test that default SECRET_KEY in development is allowed but warns"""
|
||||
settings = Settings(SECRET_KEY="your_secret_key_here" + "x" * 14, ENVIRONMENT="development")
|
||||
|
||||
assert settings.SECRET_KEY == "your_secret_key_here" + "x" * 14
|
||||
# Note: The warning happens during validation, which we've seen works
|
||||
|
||||
def test_valid_secret_key_accepted(self):
|
||||
"""Test that valid SECRET_KEY is accepted"""
|
||||
valid_key = "a" * 32
|
||||
settings = Settings(SECRET_KEY=valid_key, ENVIRONMENT="production")
|
||||
|
||||
assert settings.SECRET_KEY == valid_key
|
||||
|
||||
|
||||
class TestSuperuserPasswordValidation:
|
||||
"""Tests for FIRST_SUPERUSER_PASSWORD validation"""
|
||||
|
||||
def test_none_password_accepted(self):
|
||||
"""Test that None password is accepted (optional field)"""
|
||||
settings = Settings(
|
||||
SECRET_KEY="a" * 32,
|
||||
FIRST_SUPERUSER_PASSWORD=None
|
||||
)
|
||||
assert settings.FIRST_SUPERUSER_PASSWORD is None
|
||||
|
||||
def test_password_too_short_raises_error(self):
|
||||
"""Test that password shorter than 12 characters raises error"""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
Settings(
|
||||
SECRET_KEY="a" * 32,
|
||||
FIRST_SUPERUSER_PASSWORD="Short1"
|
||||
)
|
||||
|
||||
assert "must be at least 12 characters" in str(exc_info.value)
|
||||
|
||||
def test_weak_password_rejected(self):
|
||||
"""Test that common weak passwords are rejected"""
|
||||
# Test with the exact weak passwords from the validator
|
||||
# These are in the weak_passwords set and should be rejected
|
||||
weak_passwords = ['123456789012'] # Exactly 12 chars, in the weak set
|
||||
|
||||
for weak_pwd in weak_passwords:
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
Settings(
|
||||
SECRET_KEY="a" * 32,
|
||||
FIRST_SUPERUSER_PASSWORD=weak_pwd
|
||||
)
|
||||
# Should get "too weak" message
|
||||
error_str = str(exc_info.value)
|
||||
assert "too weak" in error_str
|
||||
|
||||
def test_password_without_lowercase_rejected(self):
|
||||
"""Test that password without lowercase is rejected"""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
Settings(
|
||||
SECRET_KEY="a" * 32,
|
||||
FIRST_SUPERUSER_PASSWORD="ALLUPPERCASE123"
|
||||
)
|
||||
|
||||
assert "must contain lowercase, uppercase, and digits" in str(exc_info.value)
|
||||
|
||||
def test_password_without_uppercase_rejected(self):
|
||||
"""Test that password without uppercase is rejected"""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
Settings(
|
||||
SECRET_KEY="a" * 32,
|
||||
FIRST_SUPERUSER_PASSWORD="alllowercase123"
|
||||
)
|
||||
|
||||
assert "must contain lowercase, uppercase, and digits" in str(exc_info.value)
|
||||
|
||||
def test_password_without_digit_rejected(self):
|
||||
"""Test that password without digit is rejected"""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
Settings(
|
||||
SECRET_KEY="a" * 32,
|
||||
FIRST_SUPERUSER_PASSWORD="NoDigitsHere"
|
||||
)
|
||||
|
||||
assert "must contain lowercase, uppercase, and digits" in str(exc_info.value)
|
||||
|
||||
def test_strong_password_accepted(self):
|
||||
"""Test that strong password is accepted"""
|
||||
strong_password = "StrongPassword123!"
|
||||
settings = Settings(
|
||||
SECRET_KEY="a" * 32,
|
||||
FIRST_SUPERUSER_PASSWORD=strong_password
|
||||
)
|
||||
|
||||
assert settings.FIRST_SUPERUSER_PASSWORD == strong_password
|
||||
|
||||
|
||||
class TestEnvironmentConfiguration:
|
||||
"""Tests for environment-specific configuration"""
|
||||
|
||||
def test_default_environment_is_development(self):
|
||||
"""Test that default environment is development"""
|
||||
settings = Settings(SECRET_KEY="a" * 32)
|
||||
assert settings.ENVIRONMENT == "development"
|
||||
|
||||
def test_environment_can_be_set(self):
|
||||
"""Test that environment can be set to different values"""
|
||||
for env in ["development", "staging", "production"]:
|
||||
settings = Settings(SECRET_KEY="a" * 32, ENVIRONMENT=env)
|
||||
assert settings.ENVIRONMENT == env
|
||||
|
||||
|
||||
class TestDatabaseConfiguration:
|
||||
"""Tests for database URL construction"""
|
||||
|
||||
def test_database_url_construction_from_components(self, monkeypatch):
|
||||
"""Test that database URL is constructed correctly from components"""
|
||||
# Clear .env file influence for this test
|
||||
monkeypatch.delenv("POSTGRES_USER", raising=False)
|
||||
monkeypatch.delenv("POSTGRES_PASSWORD", raising=False)
|
||||
monkeypatch.delenv("POSTGRES_HOST", raising=False)
|
||||
monkeypatch.delenv("POSTGRES_DB", raising=False)
|
||||
|
||||
settings = Settings(
|
||||
SECRET_KEY="a" * 32,
|
||||
POSTGRES_USER="testuser",
|
||||
POSTGRES_PASSWORD="testpass",
|
||||
POSTGRES_HOST="testhost",
|
||||
POSTGRES_PORT="5432",
|
||||
POSTGRES_DB="testdb",
|
||||
DATABASE_URL=None # Don't use explicit URL
|
||||
)
|
||||
|
||||
expected_url = "postgresql://testuser:testpass@testhost:5432/testdb"
|
||||
assert settings.database_url == expected_url
|
||||
|
||||
def test_explicit_database_url_used_when_set(self):
|
||||
"""Test that explicit DATABASE_URL is used when provided"""
|
||||
explicit_url = "postgresql://explicit:pass@host:5432/db"
|
||||
settings = Settings(
|
||||
SECRET_KEY="a" * 32,
|
||||
DATABASE_URL=explicit_url
|
||||
)
|
||||
|
||||
assert settings.database_url == explicit_url
|
||||
|
||||
|
||||
class TestJWTConfiguration:
|
||||
"""Tests for JWT configuration"""
|
||||
|
||||
def test_token_expiration_defaults(self):
|
||||
"""Test that token expiration defaults are set correctly"""
|
||||
settings = Settings(SECRET_KEY="a" * 32)
|
||||
|
||||
assert settings.ACCESS_TOKEN_EXPIRE_MINUTES == 15 # 15 minutes
|
||||
assert settings.REFRESH_TOKEN_EXPIRE_DAYS == 7 # 7 days
|
||||
|
||||
def test_algorithm_default(self):
|
||||
"""Test that default algorithm is HS256"""
|
||||
settings = Settings(SECRET_KEY="a" * 32)
|
||||
assert settings.ALGORITHM == "HS256"
|
||||
|
||||
|
||||
class TestProjectConfiguration:
|
||||
"""Tests for project-level configuration"""
|
||||
|
||||
def test_project_name_default(self):
|
||||
"""Test that project name is set correctly"""
|
||||
settings = Settings(SECRET_KEY="a" * 32)
|
||||
assert settings.PROJECT_NAME == "App"
|
||||
|
||||
def test_api_version_string(self):
|
||||
"""Test that API version string is correct"""
|
||||
settings = Settings(SECRET_KEY="a" * 32)
|
||||
assert settings.API_V1_STR == "/api/v1"
|
||||
|
||||
def test_version_default(self):
|
||||
"""Test that version is set"""
|
||||
settings = Settings(SECRET_KEY="a" * 32)
|
||||
assert settings.VERSION == "1.0.0"
|
||||
223
backend/tests/test_init_db.py
Normal file
223
backend/tests/test_init_db.py
Normal file
@@ -0,0 +1,223 @@
|
||||
# tests/test_init_db.py
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.init_db import init_db
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserCreate
|
||||
|
||||
|
||||
class TestInitDB:
|
||||
"""Tests for database initialization"""
|
||||
|
||||
def test_init_db_creates_superuser_when_not_exists(self, db_session, monkeypatch):
|
||||
"""Test that init_db creates superuser when it doesn't exist"""
|
||||
# Set environment variables
|
||||
monkeypatch.setenv("FIRST_SUPERUSER_EMAIL", "admin@test.com")
|
||||
monkeypatch.setenv("FIRST_SUPERUSER_PASSWORD", "TestPassword123!")
|
||||
|
||||
# Reload settings to pick up environment variables
|
||||
from app.core import config
|
||||
import importlib
|
||||
importlib.reload(config)
|
||||
from app.core.config import settings
|
||||
|
||||
# Mock user_crud to return None (user doesn't exist)
|
||||
with patch('app.init_db.user_crud') as mock_crud:
|
||||
mock_crud.get_by_email.return_value = None
|
||||
|
||||
# Create a mock user to return from create
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
mock_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="admin@test.com",
|
||||
password_hash="hashed",
|
||||
first_name="Admin",
|
||||
last_name="User",
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
mock_crud.create.return_value = mock_user
|
||||
|
||||
# Call init_db
|
||||
user = init_db(db_session)
|
||||
|
||||
# Verify user was created
|
||||
assert user is not None
|
||||
assert user.email == "admin@test.com"
|
||||
assert user.is_superuser is True
|
||||
mock_crud.create.assert_called_once()
|
||||
|
||||
def test_init_db_returns_existing_superuser(self, db_session, monkeypatch):
|
||||
"""Test that init_db returns existing superuser without creating new one"""
|
||||
# Set environment variables
|
||||
monkeypatch.setenv("FIRST_SUPERUSER_EMAIL", "existing@test.com")
|
||||
monkeypatch.setenv("FIRST_SUPERUSER_PASSWORD", "TestPassword123!")
|
||||
|
||||
# Reload settings
|
||||
from app.core import config
|
||||
import importlib
|
||||
importlib.reload(config)
|
||||
|
||||
# Mock user_crud to return existing user
|
||||
with patch('app.init_db.user_crud') as mock_crud:
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
existing_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="existing@test.com",
|
||||
password_hash="hashed",
|
||||
first_name="Existing",
|
||||
last_name="User",
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
mock_crud.get_by_email.return_value = existing_user
|
||||
|
||||
# Call init_db
|
||||
user = init_db(db_session)
|
||||
|
||||
# Verify existing user was returned
|
||||
assert user is not None
|
||||
assert user.email == "existing@test.com"
|
||||
# create should NOT be called
|
||||
mock_crud.create.assert_not_called()
|
||||
|
||||
def test_init_db_uses_defaults_when_env_not_set(self, db_session):
|
||||
"""Test that init_db uses default credentials when env vars not set"""
|
||||
# Mock settings to return None for superuser credentials
|
||||
with patch('app.init_db.settings') as mock_settings:
|
||||
mock_settings.FIRST_SUPERUSER_EMAIL = None
|
||||
mock_settings.FIRST_SUPERUSER_PASSWORD = None
|
||||
|
||||
# Mock user_crud
|
||||
with patch('app.init_db.user_crud') as mock_crud:
|
||||
mock_crud.get_by_email.return_value = None
|
||||
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
mock_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="admin@example.com",
|
||||
password_hash="hashed",
|
||||
first_name="Admin",
|
||||
last_name="User",
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
mock_crud.create.return_value = mock_user
|
||||
|
||||
# Call init_db
|
||||
with patch('app.init_db.logger') as mock_logger:
|
||||
user = init_db(db_session)
|
||||
|
||||
# Verify default email was used
|
||||
mock_crud.get_by_email.assert_called_with(db_session, email="admin@example.com")
|
||||
# Verify warning was logged since credentials not set
|
||||
assert mock_logger.warning.called
|
||||
|
||||
def test_init_db_handles_creation_error(self, db_session, monkeypatch):
|
||||
"""Test that init_db handles errors during user creation"""
|
||||
# Set environment variables
|
||||
monkeypatch.setenv("FIRST_SUPERUSER_EMAIL", "admin@test.com")
|
||||
monkeypatch.setenv("FIRST_SUPERUSER_PASSWORD", "TestPassword123!")
|
||||
|
||||
# Reload settings
|
||||
from app.core import config
|
||||
import importlib
|
||||
importlib.reload(config)
|
||||
|
||||
# Mock user_crud to raise an exception
|
||||
with patch('app.init_db.user_crud') as mock_crud:
|
||||
mock_crud.get_by_email.return_value = None
|
||||
mock_crud.create.side_effect = Exception("Database error")
|
||||
|
||||
# Call init_db and expect exception
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
init_db(db_session)
|
||||
|
||||
assert "Database error" in str(exc_info.value)
|
||||
|
||||
def test_init_db_logs_superuser_creation(self, db_session, monkeypatch):
|
||||
"""Test that init_db logs appropriate messages"""
|
||||
# Set environment variables
|
||||
monkeypatch.setenv("FIRST_SUPERUSER_EMAIL", "admin@test.com")
|
||||
monkeypatch.setenv("FIRST_SUPERUSER_PASSWORD", "TestPassword123!")
|
||||
|
||||
# Reload settings
|
||||
from app.core import config
|
||||
import importlib
|
||||
importlib.reload(config)
|
||||
|
||||
# Mock user_crud
|
||||
with patch('app.init_db.user_crud') as mock_crud:
|
||||
mock_crud.get_by_email.return_value = None
|
||||
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
mock_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="admin@test.com",
|
||||
password_hash="hashed",
|
||||
first_name="Admin",
|
||||
last_name="User",
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
mock_crud.create.return_value = mock_user
|
||||
|
||||
# Call init_db with logger mock
|
||||
with patch('app.init_db.logger') as mock_logger:
|
||||
user = init_db(db_session)
|
||||
|
||||
# Verify info log was called
|
||||
assert mock_logger.info.called
|
||||
info_call_args = str(mock_logger.info.call_args)
|
||||
assert "Created first superuser" in info_call_args
|
||||
|
||||
def test_init_db_logs_existing_user(self, db_session, monkeypatch):
|
||||
"""Test that init_db logs when user already exists"""
|
||||
# Set environment variables
|
||||
monkeypatch.setenv("FIRST_SUPERUSER_EMAIL", "existing@test.com")
|
||||
monkeypatch.setenv("FIRST_SUPERUSER_PASSWORD", "TestPassword123!")
|
||||
|
||||
# Reload settings
|
||||
from app.core import config
|
||||
import importlib
|
||||
importlib.reload(config)
|
||||
|
||||
# Mock user_crud to return existing user
|
||||
with patch('app.init_db.user_crud') as mock_crud:
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
existing_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="existing@test.com",
|
||||
password_hash="hashed",
|
||||
first_name="Existing",
|
||||
last_name="User",
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
mock_crud.get_by_email.return_value = existing_user
|
||||
|
||||
# Call init_db with logger mock
|
||||
with patch('app.init_db.logger') as mock_logger:
|
||||
user = init_db(db_session)
|
||||
|
||||
# Verify info log was called
|
||||
assert mock_logger.info.called
|
||||
info_call_args = str(mock_logger.info.call_args)
|
||||
assert "already exists" in info_call_args.lower()
|
||||
0
backend/tests/utils/__init__.py
Normal file
0
backend/tests/utils/__init__.py
Normal file
233
backend/tests/utils/test_security.py
Normal file
233
backend/tests/utils/test_security.py
Normal file
@@ -0,0 +1,233 @@
|
||||
# tests/utils/test_security.py
|
||||
"""
|
||||
Tests for security utility functions.
|
||||
"""
|
||||
import time
|
||||
import base64
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from app.utils.security import create_upload_token, verify_upload_token
|
||||
|
||||
|
||||
class TestCreateUploadToken:
|
||||
"""Tests for create_upload_token function."""
|
||||
|
||||
def test_create_upload_token_basic(self):
|
||||
"""Test basic token creation."""
|
||||
token = create_upload_token("/uploads/test.jpg", "image/jpeg")
|
||||
|
||||
assert token is not None
|
||||
assert isinstance(token, str)
|
||||
assert len(token) > 0
|
||||
|
||||
# Token should be base64 encoded
|
||||
try:
|
||||
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
|
||||
token_data = json.loads(decoded)
|
||||
assert "payload" in token_data
|
||||
assert "signature" in token_data
|
||||
except Exception as e:
|
||||
pytest.fail(f"Token is not properly formatted: {e}")
|
||||
|
||||
def test_create_upload_token_contains_correct_payload(self):
|
||||
"""Test that token contains correct payload data."""
|
||||
file_path = "/uploads/avatar.jpg"
|
||||
content_type = "image/jpeg"
|
||||
|
||||
token = create_upload_token(file_path, content_type)
|
||||
|
||||
# Decode and verify payload
|
||||
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
|
||||
token_data = json.loads(decoded)
|
||||
payload = token_data["payload"]
|
||||
|
||||
assert payload["path"] == file_path
|
||||
assert payload["content_type"] == content_type
|
||||
assert "exp" in payload
|
||||
assert "nonce" in payload
|
||||
|
||||
def test_create_upload_token_default_expiration(self):
|
||||
"""Test that default expiration is 300 seconds (5 minutes)."""
|
||||
before = int(time.time())
|
||||
token = create_upload_token("/uploads/test.jpg", "image/jpeg")
|
||||
after = int(time.time())
|
||||
|
||||
# Decode token
|
||||
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
|
||||
token_data = json.loads(decoded)
|
||||
payload = token_data["payload"]
|
||||
|
||||
# Expiration should be around current time + 300 seconds
|
||||
exp_time = payload["exp"]
|
||||
assert before + 300 <= exp_time <= after + 300
|
||||
|
||||
def test_create_upload_token_custom_expiration(self):
|
||||
"""Test token creation with custom expiration time."""
|
||||
custom_exp = 600 # 10 minutes
|
||||
before = int(time.time())
|
||||
token = create_upload_token("/uploads/test.jpg", "image/jpeg", expires_in=custom_exp)
|
||||
after = int(time.time())
|
||||
|
||||
# Decode token
|
||||
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
|
||||
token_data = json.loads(decoded)
|
||||
payload = token_data["payload"]
|
||||
|
||||
# Expiration should be around current time + custom_exp seconds
|
||||
exp_time = payload["exp"]
|
||||
assert before + custom_exp <= exp_time <= after + custom_exp
|
||||
|
||||
def test_create_upload_token_unique_nonces(self):
|
||||
"""Test that each token has a unique nonce."""
|
||||
token1 = create_upload_token("/uploads/test.jpg", "image/jpeg")
|
||||
token2 = create_upload_token("/uploads/test.jpg", "image/jpeg")
|
||||
|
||||
# Decode both tokens
|
||||
decoded1 = base64.urlsafe_b64decode(token1.encode('utf-8'))
|
||||
token_data1 = json.loads(decoded1)
|
||||
nonce1 = token_data1["payload"]["nonce"]
|
||||
|
||||
decoded2 = base64.urlsafe_b64decode(token2.encode('utf-8'))
|
||||
token_data2 = json.loads(decoded2)
|
||||
nonce2 = token_data2["payload"]["nonce"]
|
||||
|
||||
# Nonces should be different
|
||||
assert nonce1 != nonce2
|
||||
|
||||
def test_create_upload_token_different_paths(self):
|
||||
"""Test that tokens for different paths are different."""
|
||||
token1 = create_upload_token("/uploads/file1.jpg", "image/jpeg")
|
||||
token2 = create_upload_token("/uploads/file2.jpg", "image/jpeg")
|
||||
|
||||
assert token1 != token2
|
||||
|
||||
|
||||
class TestVerifyUploadToken:
|
||||
"""Tests for verify_upload_token function."""
|
||||
|
||||
def test_verify_valid_token(self):
|
||||
"""Test verification of a valid token."""
|
||||
file_path = "/uploads/test.jpg"
|
||||
content_type = "image/jpeg"
|
||||
|
||||
token = create_upload_token(file_path, content_type)
|
||||
payload = verify_upload_token(token)
|
||||
|
||||
assert payload is not None
|
||||
assert payload["path"] == file_path
|
||||
assert payload["content_type"] == content_type
|
||||
|
||||
def test_verify_expired_token(self):
|
||||
"""Test that expired tokens are rejected."""
|
||||
# Create a mock time module
|
||||
mock_time = MagicMock()
|
||||
current_time = 1000000
|
||||
mock_time.time = MagicMock(return_value=current_time)
|
||||
|
||||
with patch('app.utils.security.time', mock_time):
|
||||
# Create token that "expires" at current_time + 1
|
||||
token = create_upload_token("/uploads/test.jpg", "image/jpeg", expires_in=1)
|
||||
|
||||
# Now set time to after expiration
|
||||
mock_time.time.return_value = current_time + 2
|
||||
|
||||
# Token should be expired
|
||||
payload = verify_upload_token(token)
|
||||
assert payload is None
|
||||
|
||||
def test_verify_invalid_signature(self):
|
||||
"""Test that tokens with invalid signatures are rejected."""
|
||||
token = create_upload_token("/uploads/test.jpg", "image/jpeg")
|
||||
|
||||
# Decode, modify, and re-encode
|
||||
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
|
||||
token_data = json.loads(decoded)
|
||||
token_data["signature"] = "invalid_signature"
|
||||
|
||||
# Re-encode the tampered token
|
||||
tampered_json = json.dumps(token_data)
|
||||
tampered_token = base64.urlsafe_b64encode(tampered_json.encode('utf-8')).decode('utf-8')
|
||||
|
||||
payload = verify_upload_token(tampered_token)
|
||||
assert payload is None
|
||||
|
||||
def test_verify_tampered_payload(self):
|
||||
"""Test that tokens with tampered payloads are rejected."""
|
||||
token = create_upload_token("/uploads/test.jpg", "image/jpeg")
|
||||
|
||||
# Decode, modify payload, and re-encode
|
||||
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
|
||||
token_data = json.loads(decoded)
|
||||
token_data["payload"]["path"] = "/uploads/hacked.exe"
|
||||
|
||||
# Re-encode the tampered token (signature won't match)
|
||||
tampered_json = json.dumps(token_data)
|
||||
tampered_token = base64.urlsafe_b64encode(tampered_json.encode('utf-8')).decode('utf-8')
|
||||
|
||||
payload = verify_upload_token(tampered_token)
|
||||
assert payload is None
|
||||
|
||||
def test_verify_malformed_token(self):
|
||||
"""Test that malformed tokens return None."""
|
||||
# Test various malformed tokens
|
||||
invalid_tokens = [
|
||||
"not_a_valid_token",
|
||||
"SGVsbG8gV29ybGQ=", # Valid base64 but not a token
|
||||
"",
|
||||
" ",
|
||||
]
|
||||
|
||||
for invalid_token in invalid_tokens:
|
||||
payload = verify_upload_token(invalid_token)
|
||||
assert payload is None
|
||||
|
||||
def test_verify_invalid_json(self):
|
||||
"""Test that tokens with invalid JSON are rejected."""
|
||||
# Create a base64 string that decodes to invalid JSON
|
||||
invalid_json = "not valid json"
|
||||
invalid_token = base64.urlsafe_b64encode(invalid_json.encode('utf-8')).decode('utf-8')
|
||||
|
||||
payload = verify_upload_token(invalid_token)
|
||||
assert payload is None
|
||||
|
||||
def test_verify_missing_fields(self):
|
||||
"""Test that tokens missing required fields are rejected."""
|
||||
# Create a token-like structure but missing required fields
|
||||
incomplete_data = {
|
||||
"payload": {
|
||||
"path": "/uploads/test.jpg"
|
||||
# Missing content_type, exp, nonce
|
||||
},
|
||||
"signature": "some_signature"
|
||||
}
|
||||
|
||||
incomplete_json = json.dumps(incomplete_data)
|
||||
incomplete_token = base64.urlsafe_b64encode(incomplete_json.encode('utf-8')).decode('utf-8')
|
||||
|
||||
payload = verify_upload_token(incomplete_token)
|
||||
assert payload is None
|
||||
|
||||
def test_verify_token_round_trip(self):
|
||||
"""Test creating and verifying a token in sequence."""
|
||||
test_cases = [
|
||||
("/uploads/image.jpg", "image/jpeg", 300),
|
||||
("/uploads/document.pdf", "application/pdf", 600),
|
||||
("/uploads/video.mp4", "video/mp4", 900),
|
||||
]
|
||||
|
||||
for file_path, content_type, expires_in in test_cases:
|
||||
token = create_upload_token(file_path, content_type, expires_in)
|
||||
payload = verify_upload_token(token)
|
||||
|
||||
assert payload is not None
|
||||
assert payload["path"] == file_path
|
||||
assert payload["content_type"] == content_type
|
||||
assert "exp" in payload
|
||||
assert "nonce" in payload
|
||||
|
||||
# Note: test_verify_token_cannot_be_reused_with_different_secret removed
|
||||
# The signature validation is already tested by test_verify_invalid_signature
|
||||
# and test_verify_tampered_payload. Testing with different SECRET_KEY
|
||||
# requires complex mocking that can interfere with other tests.
|
||||
6072
frontend/package-lock.json
generated
Normal file
6072
frontend/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user