Refactor database module and add testing utilities

Simplify database module by re-organizing engine creation, session handling, and removing redundant methods. Introduce SQLite compatibility for testing and add a utility module for test database setup and teardown. Integrate initial unit tests for user models and update dependencies for security and testing.
This commit is contained in:
2025-02-28 12:31:10 +01:00
parent 5cd38c82e0
commit 5f9a63dd07
7 changed files with 140 additions and 124 deletions

View File

@@ -14,7 +14,16 @@ class Settings(BaseSettings):
POSTGRES_PORT: str = "5432"
POSTGRES_DB: str = "eventspace"
DATABASE_URL: Optional[str] = None
db_pool_size: int = 20 # Default connection pool size
db_max_overflow: int = 50 # Maximum overflow connections
db_pool_timeout: int = 30 # Seconds to wait for a connection
db_pool_recycle: int = 3600 # Recycle connections after 1 hour
# SQL debugging (disable in production)
sql_echo: bool = False # Log SQL statements
sql_echo_pool: bool = False # Log connection pool events
sql_echo_timing: bool = False # Log query execution times
slow_query_threshold: float = 0.5 # Log queries taking longer than this
@property
def database_url(self) -> str:
"""

View File

@@ -1,142 +1,60 @@
import time
# app/core/database.py
import logging
from contextlib import contextmanager
from typing import Generator, Any
from sqlalchemy import create_engine, event
from sqlalchemy.engine import Engine
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.exc import SQLAlchemyError, DBAPIError
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.dialects.postgresql import JSONB, UUID
from app.core.config import settings
# Configure logging
logger = logging.getLogger(__name__)
# PostgreSQL-specific engine configuration
engine = create_engine(
settings.database_url,
# Connection pool settings
pool_size=settings.db_pool_size, # Default number of connections to maintain
max_overflow=settings.db_max_overflow, # Max extra connections when pool is fully used
pool_timeout=settings.db_pool_timeout, # Seconds to wait before giving up on getting a connection
pool_recycle=settings.db_pool_recycle, # Seconds after which a connection is recycled
pool_pre_ping=True, # Test connections for liveness before using them
# Query execution settings
connect_args={
"application_name": "eventspace", # Helps identify app in PostgreSQL logs
"keepalives": 1, # Enable TCP keepalive
"keepalives_idle": 60, # Seconds before sending keepalive probes
"keepalives_interval": 10, # Seconds between keepalive probes
"keepalives_count": 5, # Number of probes before dropping connection
"options": "-c timezone=UTC", # Set timezone to UTC for consistency
},
# Performance tuning
isolation_level="READ COMMITTED", # Default isolation level for transactions
echo=settings.sql_echo, # Log SQL statements for debugging if enabled
echo_pool=settings.sql_echo_pool, # Log pool events for debugging if enabled
)
# SQLite compatibility for testing: PostgreSQL's JSONB type has no SQLite
# equivalent, so render it as TEXT when DDL is compiled for the sqlite dialect.
@compiles(JSONB, 'sqlite')
def compile_jsonb_sqlite(type_, compiler, **kw):
    """Compile the PostgreSQL JSONB column type as TEXT on SQLite."""
    return "TEXT"
# SQLAlchemy session factory
SessionLocal = sessionmaker(
autocommit=False,
autoflush=False,
bind=engine,
expire_on_commit=False # Prevents additional DB queries after commit
)
# PostgreSQL's native UUID type also has no SQLite equivalent; store it as TEXT.
@compiles(UUID, 'sqlite')
def compile_uuid_sqlite(type_, compiler, **kw):
    """Compile the PostgreSQL UUID column type as TEXT on SQLite."""
    return "TEXT"
# Declarative base for models
Base = declarative_base()
# Create engine with optimized settings for PostgreSQL
def create_production_engine():
    """Build the production SQLAlchemy engine from application settings.

    Pool sizing, timeouts, TCP keepalives and echo flags are all read from
    ``settings`` so they can be tuned per environment without code changes.

    Returns:
        A configured SQLAlchemy engine bound to ``settings.database_url``.
    """
    return create_engine(
        settings.database_url,
        # Connection pool settings
        pool_size=settings.db_pool_size,
        max_overflow=settings.db_max_overflow,
        pool_timeout=settings.db_pool_timeout,
        pool_recycle=settings.db_pool_recycle,
        pool_pre_ping=True,  # probe connections for liveness before handing them out
        # Query execution settings
        connect_args={
            "application_name": "eventspace",  # identifies this app in PostgreSQL logs
            "keepalives": 1,  # enable TCP keepalive
            "keepalives_idle": 60,  # seconds of idle before the first keepalive probe
            "keepalives_interval": 10,  # seconds between probes
            "keepalives_count": 5,  # failed probes before the connection is dropped
            "options": "-c timezone=UTC",  # force UTC on every session for consistency
        },
        isolation_level="READ COMMITTED",
        echo=settings.sql_echo,  # log SQL statements when enabled
        echo_pool=settings.sql_echo_pool,  # log pool events when enabled
    )
# Add performance metrics
@event.listens_for(Engine, "before_cursor_execute")
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
    """Record the wall-clock start time of each statement on the connection.

    The timestamp is appended to a per-connection list in ``conn.info`` so the
    paired ``after_cursor_execute`` hook can pop it and compute elapsed time.
    """
    conn.info.setdefault("query_start_time", []).append(time.time())
    if settings.sql_echo_timing:
        logger.debug("Start Query: %s", statement)
# Default production engine and session factory
engine = create_production_engine()
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
@event.listens_for(Engine, "after_cursor_execute")
def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
    """Log statement execution time and warn when it exceeds the slow threshold.

    Pops the start time pushed by ``before_cursor_execute``; emits a warning
    for any statement slower than ``settings.slow_query_threshold`` seconds.
    """
    total = time.time() - conn.info["query_start_time"].pop(-1)
    if settings.sql_echo_timing:
        logger.debug("Query Complete in %.3f seconds: %s", total, statement)
    if total > settings.slow_query_threshold:
        logger.warning("Slow Query (%.3f seconds): %s", total, statement)
# Database health check
def check_database_connection() -> bool:
    """Verify database connection is working properly.

    Returns:
        True when a trivial round-trip query succeeds, False otherwise
        (the failure is logged, never raised).
    """
    # Local import keeps the module-level import list unchanged.
    from sqlalchemy import text

    try:
        # Execute a simple query. text() is required for raw SQL strings on
        # SQLAlchemy 1.4+ (passing a bare string was removed in 2.0) and is
        # equally valid on 1.x.
        with engine.connect() as connection:
            connection.execute(text("SELECT 1"))
        return True
    except SQLAlchemyError as e:
        logger.error(f"Database connection check failed: {str(e)}")
        return False
# FastAPI dependency to get DB session
def get_db() -> Generator[Session, None, None]:
"""Dependency that provides a database session.
Usage:
@app.get("/items/")
def read_items(db: Session = Depends(get_db)):
return db.query(Item).all()
"""
# FastAPI dependency
def get_db():
    """Yield a database session scoped to a single request.

    On driver-level errors the session is rolled back and the exception
    re-raised for FastAPI's error handling; the session is always closed
    when the request finishes.
    """
    db = SessionLocal()
    try:
        yield db
    except DBAPIError as e:
        logger.error(f"Database error during request: {str(e)}")
        db.rollback()  # Rollback in case of error
        raise
    finally:
        db.close()
# Context manager for handling transactions
@contextmanager
def get_db_transaction() -> Generator[Session, None, None]:
    """Yield a session wrapped in an auto-commit/rollback transaction.

    Usage:
        with get_db_transaction() as db:
            db.add(obj)
            # commits on success, rolls back on any exception
    """
    session = SessionLocal()
    try:
        # Commit stays inside the try so a failed commit also rolls back.
        yield session
        session.commit()
    except Exception as e:
        session.rollback()
        logger.error(f"Transaction error: {str(e)}")
        raise
    finally:
        session.close()
# Function to initialize database connections at startup
def init_db() -> None:
    """Initialize database connections pool at application startup.

    Runs a trivial round-trip query so misconfiguration surfaces at startup
    rather than on the first request.

    Raises:
        SQLAlchemyError: when the database cannot be reached (logged first).
    """
    # Local import keeps the module-level import list unchanged.
    from sqlalchemy import text

    logger.info("Initializing database connection pool")
    try:
        # text() is required for raw SQL strings on SQLAlchemy 1.4+
        # (bare strings were removed in 2.0) and works on 1.x as well.
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
        logger.info("Database connection successful")
    except SQLAlchemyError as e:
        logger.error(f"Database initialization failed: {str(e)}")
        raise
# Function to dispose of connections at shutdown
def close_db() -> None:
    """Close all database connections at application shutdown.

    Disposes of the engine, which releases its pooled connections.
    """
    logger.info("Closing database connections")
    engine.dispose()
db.close()

View File

View File

@@ -0,0 +1,45 @@
import logging
from sqlalchemy import create_engine, event
from sqlalchemy.orm import sessionmaker, clear_mappers
from sqlalchemy.pool import StaticPool
from app.core.database import Base
logger = logging.getLogger(__name__)
def get_test_engine():
    """Build an in-memory SQLite engine for use by the test suite."""
    return create_engine(
        "sqlite:///:memory:",
        # Allow the connection to be shared across threads (pytest helpers).
        connect_args={"check_same_thread": False},
        # StaticPool reuses one connection, keeping the in-memory DB alive
        # for the whole test run instead of vanishing per checkout.
        poolclass=StaticPool,
        echo=False,
    )
def setup_test_db():
    """Create a fresh test database plus a session factory bound to it.

    Returns:
        Tuple of ``(engine, session_factory)`` for test fixtures to use.
    """
    fresh_engine = get_test_engine()
    # Materialize the full model schema on the new engine.
    Base.metadata.create_all(fresh_engine)
    session_factory = sessionmaker(
        bind=fresh_engine,
        autocommit=False,
        autoflush=False,
        expire_on_commit=False,
    )
    return fresh_engine, session_factory
def teardown_test_db(engine):
    """Drop every table on *engine* and release its connection pool."""
    Base.metadata.drop_all(engine)
    engine.dispose()

9
backend/pytest.ini Normal file
View File

@@ -0,0 +1,9 @@
[pytest]
env =
IS_TEST=True
testpaths = tests
python_files = test_*.py
addopts = --disable-warnings
markers =
sqlite: marks tests that should run on SQLite (mocked).
postgres: marks tests that require a real PostgreSQL database.

View File

@@ -41,4 +41,10 @@ requests>=2.32.0
black>=24.3.0
isort>=5.13.2
flake8>=7.0.0
mypy>=1.8.0
mypy>=1.8.0
# Security
python-jose==3.4.0
bcrypt==4.2.1
cryptography==44.0.1
passlib==1.7.4

View File

@@ -0,0 +1,29 @@
# tests/models/test_user.py
import uuid
from app.models.user import User
def test_create_user(db_session):
    """A fully-populated User row round-trips through the session intact."""
    # Arrange: build a user with every column populated.
    attrs = dict(
        id=uuid.uuid4(),
        email="test@example.com",
        password_hash="hashedpassword",
        first_name="Test",
        last_name="User",
        phone_number="1234567890",
        is_active=True,
        is_superuser=False,
        preferences={"theme": "dark"},
    )
    db_session.add(User(**attrs))
    # Act: persist, then read the row back by email.
    db_session.commit()
    fetched = db_session.query(User).filter_by(email="test@example.com").first()
    # Assert: key fields survived the round trip.
    assert fetched is not None
    assert fetched.email == "test@example.com"
    assert fetched.first_name == "Test"
    assert fetched.preferences == {"theme": "dark"}