Refactor database module and add testing utilities
Simplify database module by re-organizing engine creation, session handling, and removing redundant methods. Introduce SQLite compatibility for testing and add a utility module for test database setup and teardown. Integrate initial unit tests for user models and update dependencies for security and testing.
This commit is contained in:
@@ -14,7 +14,16 @@ class Settings(BaseSettings):
|
||||
POSTGRES_PORT: str = "5432"
|
||||
POSTGRES_DB: str = "eventspace"
|
||||
DATABASE_URL: Optional[str] = None
|
||||
db_pool_size: int = 20 # Default connection pool size
|
||||
db_max_overflow: int = 50 # Maximum overflow connections
|
||||
db_pool_timeout: int = 30 # Seconds to wait for a connection
|
||||
db_pool_recycle: int = 3600 # Recycle connections after 1 hour
|
||||
|
||||
# SQL debugging (disable in production)
|
||||
sql_echo: bool = False # Log SQL statements
|
||||
sql_echo_pool: bool = False # Log connection pool events
|
||||
sql_echo_timing: bool = False # Log query execution times
|
||||
slow_query_threshold: float = 0.5 # Log queries taking longer than this
|
||||
@property
|
||||
def database_url(self) -> str:
|
||||
"""
|
||||
|
||||
@@ -1,142 +1,60 @@
|
||||
import time
|
||||
# app/core/database.py
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator, Any
|
||||
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
from sqlalchemy.exc import SQLAlchemyError, DBAPIError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# PostgreSQL-specific engine configuration
|
||||
engine = create_engine(
|
||||
settings.database_url,
|
||||
# Connection pool settings
|
||||
pool_size=settings.db_pool_size, # Default number of connections to maintain
|
||||
max_overflow=settings.db_max_overflow, # Max extra connections when pool is fully used
|
||||
pool_timeout=settings.db_pool_timeout, # Seconds to wait before giving up on getting a connection
|
||||
pool_recycle=settings.db_pool_recycle, # Seconds after which a connection is recycled
|
||||
pool_pre_ping=True, # Test connections for liveness before using them
|
||||
# Query execution settings
|
||||
connect_args={
|
||||
"application_name": "eventspace", # Helps identify app in PostgreSQL logs
|
||||
"keepalives": 1, # Enable TCP keepalive
|
||||
"keepalives_idle": 60, # Seconds before sending keepalive probes
|
||||
"keepalives_interval": 10, # Seconds between keepalive probes
|
||||
"keepalives_count": 5, # Number of probes before dropping connection
|
||||
"options": "-c timezone=UTC", # Set timezone to UTC for consistency
|
||||
},
|
||||
# Performance tuning
|
||||
isolation_level="READ COMMITTED", # Default isolation level for transactions
|
||||
echo=settings.sql_echo, # Log SQL statements for debugging if enabled
|
||||
echo_pool=settings.sql_echo_pool, # Log pool events for debugging if enabled
|
||||
)
|
||||
# SQLite compatibility for testing
|
||||
@compiles(JSONB, 'sqlite')
|
||||
def compile_jsonb_sqlite(type_, compiler, **kw):
|
||||
return "TEXT"
|
||||
|
||||
# SQLAlchemy session factory
|
||||
SessionLocal = sessionmaker(
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
bind=engine,
|
||||
expire_on_commit=False # Prevents additional DB queries after commit
|
||||
)
|
||||
@compiles(UUID, 'sqlite')
|
||||
def compile_uuid_sqlite(type_, compiler, **kw):
|
||||
return "TEXT"
|
||||
|
||||
# Declarative base for models
|
||||
Base = declarative_base()
|
||||
|
||||
# Create engine with optimized settings for PostgreSQL
|
||||
def create_production_engine():
|
||||
return create_engine(
|
||||
settings.database_url,
|
||||
# Connection pool settings
|
||||
pool_size=settings.db_pool_size,
|
||||
max_overflow=settings.db_max_overflow,
|
||||
pool_timeout=settings.db_pool_timeout,
|
||||
pool_recycle=settings.db_pool_recycle,
|
||||
pool_pre_ping=True,
|
||||
# Query execution settings
|
||||
connect_args={
|
||||
"application_name": "eventspace",
|
||||
"keepalives": 1,
|
||||
"keepalives_idle": 60,
|
||||
"keepalives_interval": 10,
|
||||
"keepalives_count": 5,
|
||||
"options": "-c timezone=UTC",
|
||||
},
|
||||
isolation_level="READ COMMITTED",
|
||||
echo=settings.sql_echo,
|
||||
echo_pool=settings.sql_echo_pool,
|
||||
)
|
||||
|
||||
# Add performance metrics
|
||||
@event.listens_for(Engine, "before_cursor_execute")
|
||||
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
|
||||
conn.info.setdefault("query_start_time", []).append(time.time())
|
||||
if settings.sql_echo_timing:
|
||||
logger.debug("Start Query: %s", statement)
|
||||
# Default production engine and session factory
|
||||
engine = create_production_engine()
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
@event.listens_for(Engine, "after_cursor_execute")
|
||||
def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
|
||||
total = time.time() - conn.info["query_start_time"].pop(-1)
|
||||
if settings.sql_echo_timing:
|
||||
logger.debug("Query Complete in %.3f seconds: %s", total, statement)
|
||||
if total > settings.slow_query_threshold:
|
||||
logger.warning("Slow Query (%.3f seconds): %s", total, statement)
|
||||
|
||||
|
||||
# Database health check
|
||||
def check_database_connection() -> bool:
|
||||
"""Verify database connection is working properly."""
|
||||
try:
|
||||
# Execute a simple query
|
||||
with engine.connect() as connection:
|
||||
connection.execute("SELECT 1")
|
||||
return True
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Database connection check failed: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
# FastAPI dependency to get DB session
|
||||
def get_db() -> Generator[Session, None, None]:
|
||||
"""Dependency that provides a database session.
|
||||
|
||||
Usage:
|
||||
@app.get("/items/")
|
||||
def read_items(db: Session = Depends(get_db)):
|
||||
return db.query(Item).all()
|
||||
"""
|
||||
# FastAPI dependency
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
except DBAPIError as e:
|
||||
logger.error(f"Database error during request: {str(e)}")
|
||||
db.rollback() # Rollback in case of error
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
# Context manager for handling transactions
|
||||
@contextmanager
|
||||
def get_db_transaction() -> Generator[Session, None, None]:
|
||||
"""Context manager for database transactions.
|
||||
|
||||
Usage:
|
||||
with get_db_transaction() as db:
|
||||
db.add(obj)
|
||||
# Will automatically commit or rollback
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Transaction error: {str(e)}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
# Function to initialize database connections at startup
|
||||
def init_db() -> None:
|
||||
"""Initialize database connections pool at application startup."""
|
||||
logger.info("Initializing database connection pool")
|
||||
try:
|
||||
with engine.connect() as conn:
|
||||
conn.execute("SELECT 1")
|
||||
logger.info("Database connection successful")
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Database initialization failed: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
# Function to dispose of connections at shutdown
|
||||
def close_db() -> None:
|
||||
"""Close all database connections at application shutdown."""
|
||||
logger.info("Closing database connections")
|
||||
engine.dispose()
|
||||
db.close()
|
||||
0
backend/app/utils/__init__.py
Normal file
0
backend/app/utils/__init__.py
Normal file
45
backend/app/utils/test_utils.py
Normal file
45
backend/app/utils/test_utils.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import logging
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy.orm import sessionmaker, clear_mappers
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.core.database import Base
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_test_engine():
|
||||
"""Create an SQLite in-memory engine specifically for testing"""
|
||||
test_engine = create_engine(
|
||||
"sqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool, # Use static pool for in-memory testing
|
||||
echo=False
|
||||
)
|
||||
|
||||
return test_engine
|
||||
|
||||
def setup_test_db():
|
||||
"""Create a test database and session factory"""
|
||||
# Create a new engine for this test run
|
||||
test_engine = get_test_engine()
|
||||
|
||||
# Create tables
|
||||
Base.metadata.create_all(test_engine)
|
||||
|
||||
# Create session factory
|
||||
TestingSessionLocal = sessionmaker(
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
bind=test_engine,
|
||||
expire_on_commit=False
|
||||
)
|
||||
|
||||
return test_engine, TestingSessionLocal
|
||||
|
||||
def teardown_test_db(engine):
|
||||
"""Clean up after tests"""
|
||||
# Drop all tables
|
||||
Base.metadata.drop_all(engine)
|
||||
|
||||
# Dispose of engine
|
||||
engine.dispose()
|
||||
9
backend/pytest.ini
Normal file
9
backend/pytest.ini
Normal file
@@ -0,0 +1,9 @@
|
||||
[pytest]
|
||||
env =
|
||||
IS_TEST=True
|
||||
testpaths = tests
|
||||
python_files = test_*.py
|
||||
addopts = --disable-warnings
|
||||
markers =
|
||||
sqlite: marks tests that should run on SQLite (mocked).
|
||||
postgres: marks tests that require a real PostgreSQL database.
|
||||
@@ -41,4 +41,10 @@ requests>=2.32.0
|
||||
black>=24.3.0
|
||||
isort>=5.13.2
|
||||
flake8>=7.0.0
|
||||
mypy>=1.8.0
|
||||
mypy>=1.8.0
|
||||
|
||||
# Security
|
||||
python-jose==3.4.0
|
||||
bcrypt==4.2.1
|
||||
cryptography==44.0.1
|
||||
passlib==1.7.4
|
||||
29
backend/tests/models/test_user.py
Normal file
29
backend/tests/models/test_user.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# tests/models/test_user.py
|
||||
import uuid
|
||||
from app.models.user import User
|
||||
|
||||
def test_create_user(db_session):
|
||||
# Arrange
|
||||
user_id = uuid.uuid4()
|
||||
new_user = User(
|
||||
id=user_id,
|
||||
email="test@example.com",
|
||||
password_hash="hashedpassword",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="1234567890",
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
preferences={"theme": "dark"},
|
||||
)
|
||||
db_session.add(new_user)
|
||||
|
||||
# Act
|
||||
db_session.commit()
|
||||
created_user = db_session.query(User).filter_by(email="test@example.com").first()
|
||||
|
||||
# Assert
|
||||
assert created_user is not None
|
||||
assert created_user.email == "test@example.com"
|
||||
assert created_user.first_name == "Test"
|
||||
assert created_user.preferences == {"theme": "dark"}
|
||||
Reference in New Issue
Block a user