- **Middleware & Security Enhancements:** Add request size limit middleware to prevent DoS attacks via large payloads (10MB max).

- **Authentication Refactor:** Introduce `_create_login_session` utility to streamline session creation for login and OAuth flows.
- **Configurations:** Dynamically set app name in PostgreSQL connection (`application_name`) and adjust token expiration settings (`expires_in`) based on system configuration.
This commit is contained in:
2025-11-02 13:25:53 +01:00
parent df299e3e45
commit 68e7ebc4e0
4 changed files with 84 additions and 103 deletions

View File

@@ -49,6 +49,55 @@ IS_TEST = os.getenv("IS_TEST", "False") == "True"
RATE_MULTIPLIER = 100 if IS_TEST else 1 RATE_MULTIPLIER = 100 if IS_TEST else 1
async def _create_login_session(
db: AsyncSession,
request: Request,
user: User,
tokens: Token,
login_type: str = "login"
) -> None:
"""
Create a session record for successful login.
This is a best-effort operation - login succeeds even if session creation fails.
Args:
db: Database session
request: FastAPI request object for device info extraction
user: Authenticated user
tokens: Token object containing refresh token with JTI
login_type: Type of login for logging ("login" or "oauth")
"""
try:
device_info = extract_device_info(request)
# Decode refresh token to get JTI and expiration
refresh_payload = decode_token(tokens.refresh_token, verify_type="refresh")
session_data = SessionCreate(
user_id=user.id,
refresh_token_jti=refresh_payload.jti,
device_name=device_info.device_name or "API Client",
device_id=device_info.device_id,
ip_address=device_info.ip_address,
user_agent=device_info.user_agent,
last_used_at=datetime.now(timezone.utc),
expires_at=datetime.fromtimestamp(refresh_payload.exp, tz=timezone.utc),
location_city=device_info.location_city,
location_country=device_info.location_country,
)
await session_crud.create_session(db, obj_in=session_data)
logger.info(
f"{login_type.capitalize()} successful: {user.email} from {device_info.device_name} "
f"(IP: {device_info.ip_address})"
)
except Exception as session_err:
# Log but don't fail login if session creation fails
logger.error(f"Failed to create session for {user.email}: {str(session_err)}", exc_info=True)
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED, operation_id="register") @router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED, operation_id="register")
@limiter.limit(f"{5 * RATE_MULTIPLIER}/minute") @limiter.limit(f"{5 * RATE_MULTIPLIER}/minute")
async def register_user( async def register_user(
@@ -110,36 +159,8 @@ async def login(
# User is authenticated, generate tokens # User is authenticated, generate tokens
tokens = AuthService.create_tokens(user) tokens = AuthService.create_tokens(user)
# Extract device information and create session record # Create session record (best-effort, doesn't fail login)
# Session creation is best-effort - we don't fail login if it fails await _create_login_session(db, request, user, tokens, login_type="login")
try:
device_info = extract_device_info(request)
# Decode refresh token to get JTI and expiration
refresh_payload = decode_token(tokens.refresh_token, verify_type="refresh")
session_data = SessionCreate(
user_id=user.id,
refresh_token_jti=refresh_payload.jti,
device_name=device_info.device_name,
device_id=device_info.device_id,
ip_address=device_info.ip_address,
user_agent=device_info.user_agent,
last_used_at=datetime.now(timezone.utc),
expires_at=datetime.fromtimestamp(refresh_payload.exp, tz=timezone.utc),
location_city=device_info.location_city,
location_country=device_info.location_country,
)
await session_crud.create_session(db, obj_in=session_data)
logger.info(
f"User login successful: {user.email} from {device_info.device_name} "
f"(IP: {device_info.ip_address})"
)
except Exception as session_err:
# Log but don't fail login if session creation fails
logger.error(f"Failed to create session for {user.email}: {str(session_err)}", exc_info=True)
return tokens return tokens
@@ -189,32 +210,8 @@ async def login_oauth(
# Generate tokens # Generate tokens
tokens = AuthService.create_tokens(user) tokens = AuthService.create_tokens(user)
# Extract device information and create session record # Create session record (best-effort, doesn't fail login)
# Session creation is best-effort - we don't fail login if it fails await _create_login_session(db, request, user, tokens, login_type="oauth")
try:
device_info = extract_device_info(request)
# Decode refresh token to get JTI and expiration
refresh_payload = decode_token(tokens.refresh_token, verify_type="refresh")
session_data = SessionCreate(
user_id=user.id,
refresh_token_jti=refresh_payload.jti,
device_name=device_info.device_name or "API Client",
device_id=device_info.device_id,
ip_address=device_info.ip_address,
user_agent=device_info.user_agent,
last_used_at=datetime.now(timezone.utc),
expires_at=datetime.fromtimestamp(refresh_payload.exp, tz=timezone.utc),
location_city=device_info.location_city,
location_country=device_info.location_country,
)
await session_crud.create_session(db, obj_in=session_data)
logger.info(f"OAuth login successful: {user.email} from {device_info.device_name}")
except Exception as session_err:
logger.error(f"Failed to create session for {user.email}: {str(session_err)}", exc_info=True)
# Return full token response with user data # Return full token response with user data
return tokens return tokens

View File

@@ -77,7 +77,7 @@ def create_async_production_engine() -> AsyncEngine:
if "postgresql" in async_url: if "postgresql" in async_url:
engine_config["connect_args"] = { engine_config["connect_args"] = {
"server_settings": { "server_settings": {
"application_name": "eventspace", "application_name": settings.PROJECT_NAME,
"timezone": "UTC", "timezone": "UTC",
}, },
# asyncpg-specific settings # asyncpg-specific settings

View File

@@ -113,6 +113,34 @@ app.add_middleware(
) )
# Add request size limit middleware
@app.middleware("http")
async def limit_request_size(request: Request, call_next):
"""
Limit request body size to prevent DoS attacks via large payloads.
Max size: 10MB for file uploads and large payloads.
"""
MAX_REQUEST_SIZE = 10 * 1024 * 1024 # 10MB in bytes
content_length = request.headers.get("content-length")
if content_length and int(content_length) > MAX_REQUEST_SIZE:
return JSONResponse(
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
content={
"success": False,
"errors": [{
"code": "REQUEST_TOO_LARGE",
"message": f"Request body too large. Maximum size is {MAX_REQUEST_SIZE // (1024 * 1024)}MB",
"field": None
}]
}
)
response = await call_next(request)
return response
# Add security headers middleware # Add security headers middleware
@app.middleware("http") @app.middleware("http")
async def add_security_headers(request: Request, call_next): async def add_security_headers(request: Request, call_next):
@@ -286,48 +314,3 @@ async def health_check() -> JSONResponse:
app.include_router(api_router, prefix=settings.API_V1_STR) app.include_router(api_router, prefix=settings.API_V1_STR)
@app.on_event("startup")
async def startup_event():
"""
Application startup event.
Sets up background jobs and scheduled tasks.
"""
import os
# Skip scheduler in test environment
if os.getenv("IS_TEST", "False") == "True":
logger.info("Test environment detected - skipping scheduler")
return
from app.services.session_cleanup import cleanup_expired_sessions
# Schedule session cleanup job
# Runs daily at 2:00 AM server time
scheduler.add_job(
cleanup_expired_sessions,
'cron',
hour=2,
minute=0,
id='cleanup_expired_sessions',
replace_existing=True
)
scheduler.start()
logger.info("Scheduled jobs started: session cleanup (daily at 2 AM)")
@app.on_event("shutdown")
async def shutdown_event():
"""
Application shutdown event.
Cleans up resources and stops background jobs.
"""
import os
if os.getenv("IS_TEST", "False") != "True":
scheduler.shutdown()
logger.info("Scheduled jobs stopped")

View File

@@ -14,6 +14,7 @@ from app.core.auth import (
TokenExpiredError, TokenExpiredError,
TokenInvalidError TokenInvalidError
) )
from app.core.config import settings
from app.core.exceptions import AuthenticationError from app.core.exceptions import AuthenticationError
from app.models.user import User from app.models.user import User
from app.schemas.users import Token, UserCreate, UserResponse from app.schemas.users import Token, UserCreate, UserResponse
@@ -140,7 +141,7 @@ class AuthService:
access_token=access_token, access_token=access_token,
refresh_token=refresh_token, refresh_token=refresh_token,
user=user_response, user=user_response,
expires_in=900 # 15 minutes in seconds (matching ACCESS_TOKEN_EXPIRE_MINUTES) expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60 # Convert minutes to seconds
) )
@staticmethod @staticmethod