Files
pragma-stack/backend/app/repositories/base.py
Felipe Cardoso a3f78dc801 refactor(tests): replace crud references with repo across repository test files
- Updated import statements and test logic to align with `repositories` naming changes.
- Adjusted documentation and test names for consistency with the updated naming convention.
- Improved test descriptions to reflect the repository-based structure.
2026-03-01 19:22:16 +01:00

421 lines
15 KiB
Python

# app/repositories/base.py
"""
Base repository class for async database operations using SQLAlchemy 2.0 async patterns.
Provides reusable create, read, update, and delete operations for all models.
"""
import logging
import uuid
from datetime import UTC
from typing import Any, TypeVar
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy import func, select
from sqlalchemy.exc import DataError, IntegrityError, OperationalError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Load
from app.core.database import Base
from app.core.repository_exceptions import (
DuplicateEntryError,
IntegrityConstraintError,
InvalidInputError,
)
logger = logging.getLogger(__name__)
ModelType = TypeVar("ModelType", bound=Base)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
class BaseRepository[
ModelType: Base,
CreateSchemaType: BaseModel,
UpdateSchemaType: BaseModel,
]:
"""Async repository operations for a model."""
def __init__(self, model: type[ModelType]):
"""
Repository object with default async methods to Create, Read, Update, Delete.
Parameters:
model: A SQLAlchemy model class
"""
self.model = model
async def get(
self, db: AsyncSession, id: str, options: list[Load] | None = None
) -> ModelType | None:
"""
Get a single record by ID with UUID validation and optional eager loading.
Args:
db: Database session
id: Record UUID
options: Optional list of SQLAlchemy load options (e.g., joinedload, selectinload)
for eager loading relationships to prevent N+1 queries
Returns:
Model instance or None if not found
"""
try:
if isinstance(id, uuid.UUID):
uuid_obj = id
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning("Invalid UUID format: %s - %s", id, e)
return None
try:
query = select(self.model).where(self.model.id == uuid_obj)
if options:
for option in options:
query = query.options(option)
result = await db.execute(query)
return result.scalar_one_or_none()
except Exception as e:
logger.error(
"Error retrieving %s with id %s: %s", self.model.__name__, id, e
)
raise
async def get_multi(
self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
options: list[Load] | None = None,
) -> list[ModelType]:
"""
Get multiple records with pagination validation and optional eager loading.
"""
if skip < 0:
raise InvalidInputError("skip must be non-negative")
if limit < 0:
raise InvalidInputError("limit must be non-negative")
if limit > 1000:
raise InvalidInputError("Maximum limit is 1000")
try:
query = select(self.model).order_by(self.model.id).offset(skip).limit(limit)
if options:
for option in options:
query = query.options(option)
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(
"Error retrieving multiple %s records: %s", self.model.__name__, e
)
raise
async def create(
self, db: AsyncSession, *, obj_in: CreateSchemaType
) -> ModelType: # pragma: no cover
"""Create a new record with error handling.
NOTE: This method is defensive code that's never called in practice.
All repository subclasses override this method with their own implementations.
Marked as pragma: no cover to avoid false coverage gaps.
"""
try: # pragma: no cover
obj_in_data = jsonable_encoder(obj_in)
db_obj = self.model(**obj_in_data)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj
except IntegrityError as e: # pragma: no cover
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
logger.warning(
"Duplicate entry attempted for %s: %s",
self.model.__name__,
error_msg,
)
raise DuplicateEntryError(
f"A {self.model.__name__} with this data already exists"
)
logger.error(
"Integrity error creating %s: %s", self.model.__name__, error_msg
)
raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
except (OperationalError, DataError) as e: # pragma: no cover
await db.rollback()
logger.error("Database error creating %s: %s", self.model.__name__, e)
raise IntegrityConstraintError(f"Database operation failed: {e!s}")
except Exception as e: # pragma: no cover
await db.rollback()
logger.exception("Unexpected error creating %s: %s", self.model.__name__, e)
raise
async def update(
self,
db: AsyncSession,
*,
db_obj: ModelType,
obj_in: UpdateSchemaType | dict[str, Any],
) -> ModelType:
"""Update a record with error handling."""
try:
obj_data = jsonable_encoder(db_obj)
if isinstance(obj_in, dict):
update_data = obj_in
else:
update_data = obj_in.model_dump(exclude_unset=True)
for field in obj_data:
if field in update_data:
setattr(db_obj, field, update_data[field])
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
logger.warning(
"Duplicate entry attempted for %s: %s",
self.model.__name__,
error_msg,
)
raise DuplicateEntryError(
f"A {self.model.__name__} with this data already exists"
)
logger.error(
"Integrity error updating %s: %s", self.model.__name__, error_msg
)
raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
except (OperationalError, DataError) as e:
await db.rollback()
logger.error("Database error updating %s: %s", self.model.__name__, e)
raise IntegrityConstraintError(f"Database operation failed: {e!s}")
except Exception as e:
await db.rollback()
logger.exception("Unexpected error updating %s: %s", self.model.__name__, e)
raise
async def remove(self, db: AsyncSession, *, id: str) -> ModelType | None:
"""Delete a record with error handling and null check."""
try:
if isinstance(id, uuid.UUID):
uuid_obj = id
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning("Invalid UUID format for deletion: %s - %s", id, e)
return None
try:
result = await db.execute(
select(self.model).where(self.model.id == uuid_obj)
)
obj = result.scalar_one_or_none()
if obj is None:
logger.warning(
"%s with id %s not found for deletion", self.model.__name__, id
)
return None
await db.delete(obj)
await db.commit()
return obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
logger.error(
"Integrity error deleting %s: %s", self.model.__name__, error_msg
)
raise IntegrityConstraintError(
f"Cannot delete {self.model.__name__}: referenced by other records"
)
except Exception as e:
await db.rollback()
logger.exception(
"Error deleting %s with id %s: %s", self.model.__name__, id, e
)
raise
async def get_multi_with_total(
self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
sort_by: str | None = None,
sort_order: str = "asc",
filters: dict[str, Any] | None = None,
) -> tuple[list[ModelType], int]: # pragma: no cover
"""
Get multiple records with total count, filtering, and sorting.
NOTE: This method is defensive code that's never called in practice.
All repository subclasses override this method with their own implementations.
Marked as pragma: no cover to avoid false coverage gaps.
"""
if skip < 0:
raise InvalidInputError("skip must be non-negative")
if limit < 0:
raise InvalidInputError("limit must be non-negative")
if limit > 1000:
raise InvalidInputError("Maximum limit is 1000")
try:
query = select(self.model)
if hasattr(self.model, "deleted_at"):
query = query.where(self.model.deleted_at.is_(None))
if filters:
for field, value in filters.items():
if hasattr(self.model, field) and value is not None:
query = query.where(getattr(self.model, field) == value)
count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query)
total = count_result.scalar_one()
if sort_by and hasattr(self.model, sort_by):
sort_column = getattr(self.model, sort_by)
if sort_order.lower() == "desc":
query = query.order_by(sort_column.desc())
else:
query = query.order_by(sort_column.asc())
else:
query = query.order_by(self.model.id)
query = query.offset(skip).limit(limit)
items_result = await db.execute(query)
items = list(items_result.scalars().all())
return items, total
except Exception as e: # pragma: no cover
logger.error(
"Error retrieving paginated %s records: %s", self.model.__name__, e
)
raise
async def count(self, db: AsyncSession) -> int:
"""Get total count of records."""
try:
result = await db.execute(select(func.count(self.model.id)))
return result.scalar_one()
except Exception as e:
logger.error("Error counting %s records: %s", self.model.__name__, e)
raise
async def exists(self, db: AsyncSession, id: str) -> bool:
"""Check if a record exists by ID."""
obj = await self.get(db, id=id)
return obj is not None
async def soft_delete(self, db: AsyncSession, *, id: str) -> ModelType | None:
"""
Soft delete a record by setting deleted_at timestamp.
Only works if the model has a 'deleted_at' column.
"""
from datetime import datetime
try:
if isinstance(id, uuid.UUID):
uuid_obj = id
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning("Invalid UUID format for soft deletion: %s - %s", id, e)
return None
try:
result = await db.execute(
select(self.model).where(self.model.id == uuid_obj)
)
obj = result.scalar_one_or_none()
if obj is None:
logger.warning(
"%s with id %s not found for soft deletion", self.model.__name__, id
)
return None
if not hasattr(self.model, "deleted_at"):
logger.error("%s does not support soft deletes", self.model.__name__)
raise InvalidInputError(
f"{self.model.__name__} does not have a deleted_at column"
)
obj.deleted_at = datetime.now(UTC)
db.add(obj)
await db.commit()
await db.refresh(obj)
return obj
except Exception as e:
await db.rollback()
logger.exception(
"Error soft deleting %s with id %s: %s", self.model.__name__, id, e
)
raise
async def restore(self, db: AsyncSession, *, id: str) -> ModelType | None:
"""
Restore a soft-deleted record by clearing the deleted_at timestamp.
Only works if the model has a 'deleted_at' column.
"""
try:
if isinstance(id, uuid.UUID):
uuid_obj = id
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning("Invalid UUID format for restoration: %s - %s", id, e)
return None
try:
if hasattr(self.model, "deleted_at"):
result = await db.execute(
select(self.model).where(
self.model.id == uuid_obj, self.model.deleted_at.isnot(None)
)
)
obj = result.scalar_one_or_none()
else:
logger.error("%s does not support soft deletes", self.model.__name__)
raise InvalidInputError(
f"{self.model.__name__} does not have a deleted_at column"
)
if obj is None:
logger.warning(
"Soft-deleted %s with id %s not found for restoration",
self.model.__name__,
id,
)
return None
obj.deleted_at = None
db.add(obj)
await db.commit()
await db.refresh(obj)
return obj
except Exception as e:
await db.rollback()
logger.exception(
"Error restoring %s with id %s: %s", self.model.__name__, id, e
)
raise