- **Middleware & Security Enhancements:** Add request size limit middleware to prevent DoS attacks via large payloads (10MB max).
- **Authentication Refactor:** Introduce `_create_login_session` utility to streamline session creation for login and OAuth flows. - **Configurations:** Dynamically set app name in PostgreSQL connection (`application_name`) and adjust token expiration settings (`expires_in`) based on system configuration.
This commit is contained in:
@@ -49,6 +49,55 @@ IS_TEST = os.getenv("IS_TEST", "False") == "True"
|
|||||||
RATE_MULTIPLIER = 100 if IS_TEST else 1
|
RATE_MULTIPLIER = 100 if IS_TEST else 1
|
||||||
|
|
||||||
|
|
||||||
|
async def _create_login_session(
|
||||||
|
db: AsyncSession,
|
||||||
|
request: Request,
|
||||||
|
user: User,
|
||||||
|
tokens: Token,
|
||||||
|
login_type: str = "login"
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Create a session record for successful login.
|
||||||
|
|
||||||
|
This is a best-effort operation - login succeeds even if session creation fails.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
request: FastAPI request object for device info extraction
|
||||||
|
user: Authenticated user
|
||||||
|
tokens: Token object containing refresh token with JTI
|
||||||
|
login_type: Type of login for logging ("login" or "oauth")
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
device_info = extract_device_info(request)
|
||||||
|
|
||||||
|
# Decode refresh token to get JTI and expiration
|
||||||
|
refresh_payload = decode_token(tokens.refresh_token, verify_type="refresh")
|
||||||
|
|
||||||
|
session_data = SessionCreate(
|
||||||
|
user_id=user.id,
|
||||||
|
refresh_token_jti=refresh_payload.jti,
|
||||||
|
device_name=device_info.device_name or "API Client",
|
||||||
|
device_id=device_info.device_id,
|
||||||
|
ip_address=device_info.ip_address,
|
||||||
|
user_agent=device_info.user_agent,
|
||||||
|
last_used_at=datetime.now(timezone.utc),
|
||||||
|
expires_at=datetime.fromtimestamp(refresh_payload.exp, tz=timezone.utc),
|
||||||
|
location_city=device_info.location_city,
|
||||||
|
location_country=device_info.location_country,
|
||||||
|
)
|
||||||
|
|
||||||
|
await session_crud.create_session(db, obj_in=session_data)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"{login_type.capitalize()} successful: {user.email} from {device_info.device_name} "
|
||||||
|
f"(IP: {device_info.ip_address})"
|
||||||
|
)
|
||||||
|
except Exception as session_err:
|
||||||
|
# Log but don't fail login if session creation fails
|
||||||
|
logger.error(f"Failed to create session for {user.email}: {str(session_err)}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED, operation_id="register")
|
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED, operation_id="register")
|
||||||
@limiter.limit(f"{5 * RATE_MULTIPLIER}/minute")
|
@limiter.limit(f"{5 * RATE_MULTIPLIER}/minute")
|
||||||
async def register_user(
|
async def register_user(
|
||||||
@@ -110,36 +159,8 @@ async def login(
|
|||||||
# User is authenticated, generate tokens
|
# User is authenticated, generate tokens
|
||||||
tokens = AuthService.create_tokens(user)
|
tokens = AuthService.create_tokens(user)
|
||||||
|
|
||||||
# Extract device information and create session record
|
# Create session record (best-effort, doesn't fail login)
|
||||||
# Session creation is best-effort - we don't fail login if it fails
|
await _create_login_session(db, request, user, tokens, login_type="login")
|
||||||
try:
|
|
||||||
device_info = extract_device_info(request)
|
|
||||||
|
|
||||||
# Decode refresh token to get JTI and expiration
|
|
||||||
refresh_payload = decode_token(tokens.refresh_token, verify_type="refresh")
|
|
||||||
|
|
||||||
session_data = SessionCreate(
|
|
||||||
user_id=user.id,
|
|
||||||
refresh_token_jti=refresh_payload.jti,
|
|
||||||
device_name=device_info.device_name,
|
|
||||||
device_id=device_info.device_id,
|
|
||||||
ip_address=device_info.ip_address,
|
|
||||||
user_agent=device_info.user_agent,
|
|
||||||
last_used_at=datetime.now(timezone.utc),
|
|
||||||
expires_at=datetime.fromtimestamp(refresh_payload.exp, tz=timezone.utc),
|
|
||||||
location_city=device_info.location_city,
|
|
||||||
location_country=device_info.location_country,
|
|
||||||
)
|
|
||||||
|
|
||||||
await session_crud.create_session(db, obj_in=session_data)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"User login successful: {user.email} from {device_info.device_name} "
|
|
||||||
f"(IP: {device_info.ip_address})"
|
|
||||||
)
|
|
||||||
except Exception as session_err:
|
|
||||||
# Log but don't fail login if session creation fails
|
|
||||||
logger.error(f"Failed to create session for {user.email}: {str(session_err)}", exc_info=True)
|
|
||||||
|
|
||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
@@ -189,32 +210,8 @@ async def login_oauth(
|
|||||||
# Generate tokens
|
# Generate tokens
|
||||||
tokens = AuthService.create_tokens(user)
|
tokens = AuthService.create_tokens(user)
|
||||||
|
|
||||||
# Extract device information and create session record
|
# Create session record (best-effort, doesn't fail login)
|
||||||
# Session creation is best-effort - we don't fail login if it fails
|
await _create_login_session(db, request, user, tokens, login_type="oauth")
|
||||||
try:
|
|
||||||
device_info = extract_device_info(request)
|
|
||||||
|
|
||||||
# Decode refresh token to get JTI and expiration
|
|
||||||
refresh_payload = decode_token(tokens.refresh_token, verify_type="refresh")
|
|
||||||
|
|
||||||
session_data = SessionCreate(
|
|
||||||
user_id=user.id,
|
|
||||||
refresh_token_jti=refresh_payload.jti,
|
|
||||||
device_name=device_info.device_name or "API Client",
|
|
||||||
device_id=device_info.device_id,
|
|
||||||
ip_address=device_info.ip_address,
|
|
||||||
user_agent=device_info.user_agent,
|
|
||||||
last_used_at=datetime.now(timezone.utc),
|
|
||||||
expires_at=datetime.fromtimestamp(refresh_payload.exp, tz=timezone.utc),
|
|
||||||
location_city=device_info.location_city,
|
|
||||||
location_country=device_info.location_country,
|
|
||||||
)
|
|
||||||
|
|
||||||
await session_crud.create_session(db, obj_in=session_data)
|
|
||||||
|
|
||||||
logger.info(f"OAuth login successful: {user.email} from {device_info.device_name}")
|
|
||||||
except Exception as session_err:
|
|
||||||
logger.error(f"Failed to create session for {user.email}: {str(session_err)}", exc_info=True)
|
|
||||||
|
|
||||||
# Return full token response with user data
|
# Return full token response with user data
|
||||||
return tokens
|
return tokens
|
||||||
|
|||||||
@@ -77,7 +77,7 @@ def create_async_production_engine() -> AsyncEngine:
|
|||||||
if "postgresql" in async_url:
|
if "postgresql" in async_url:
|
||||||
engine_config["connect_args"] = {
|
engine_config["connect_args"] = {
|
||||||
"server_settings": {
|
"server_settings": {
|
||||||
"application_name": "eventspace",
|
"application_name": settings.PROJECT_NAME,
|
||||||
"timezone": "UTC",
|
"timezone": "UTC",
|
||||||
},
|
},
|
||||||
# asyncpg-specific settings
|
# asyncpg-specific settings
|
||||||
|
|||||||
@@ -113,6 +113,34 @@ app.add_middleware(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Add request size limit middleware
|
||||||
|
@app.middleware("http")
|
||||||
|
async def limit_request_size(request: Request, call_next):
|
||||||
|
"""
|
||||||
|
Limit request body size to prevent DoS attacks via large payloads.
|
||||||
|
|
||||||
|
Max size: 10MB for file uploads and large payloads.
|
||||||
|
"""
|
||||||
|
MAX_REQUEST_SIZE = 10 * 1024 * 1024 # 10MB in bytes
|
||||||
|
|
||||||
|
content_length = request.headers.get("content-length")
|
||||||
|
if content_length and int(content_length) > MAX_REQUEST_SIZE:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
|
||||||
|
content={
|
||||||
|
"success": False,
|
||||||
|
"errors": [{
|
||||||
|
"code": "REQUEST_TOO_LARGE",
|
||||||
|
"message": f"Request body too large. Maximum size is {MAX_REQUEST_SIZE // (1024 * 1024)}MB",
|
||||||
|
"field": None
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await call_next(request)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
# Add security headers middleware
|
# Add security headers middleware
|
||||||
@app.middleware("http")
|
@app.middleware("http")
|
||||||
async def add_security_headers(request: Request, call_next):
|
async def add_security_headers(request: Request, call_next):
|
||||||
@@ -286,48 +314,3 @@ async def health_check() -> JSONResponse:
|
|||||||
|
|
||||||
|
|
||||||
app.include_router(api_router, prefix=settings.API_V1_STR)
|
app.include_router(api_router, prefix=settings.API_V1_STR)
|
||||||
|
|
||||||
|
|
||||||
@app.on_event("startup")
|
|
||||||
async def startup_event():
|
|
||||||
"""
|
|
||||||
Application startup event.
|
|
||||||
|
|
||||||
Sets up background jobs and scheduled tasks.
|
|
||||||
"""
|
|
||||||
import os
|
|
||||||
|
|
||||||
# Skip scheduler in test environment
|
|
||||||
if os.getenv("IS_TEST", "False") == "True":
|
|
||||||
logger.info("Test environment detected - skipping scheduler")
|
|
||||||
return
|
|
||||||
|
|
||||||
from app.services.session_cleanup import cleanup_expired_sessions
|
|
||||||
|
|
||||||
# Schedule session cleanup job
|
|
||||||
# Runs daily at 2:00 AM server time
|
|
||||||
scheduler.add_job(
|
|
||||||
cleanup_expired_sessions,
|
|
||||||
'cron',
|
|
||||||
hour=2,
|
|
||||||
minute=0,
|
|
||||||
id='cleanup_expired_sessions',
|
|
||||||
replace_existing=True
|
|
||||||
)
|
|
||||||
|
|
||||||
scheduler.start()
|
|
||||||
logger.info("Scheduled jobs started: session cleanup (daily at 2 AM)")
|
|
||||||
|
|
||||||
|
|
||||||
@app.on_event("shutdown")
|
|
||||||
async def shutdown_event():
|
|
||||||
"""
|
|
||||||
Application shutdown event.
|
|
||||||
|
|
||||||
Cleans up resources and stops background jobs.
|
|
||||||
"""
|
|
||||||
import os
|
|
||||||
|
|
||||||
if os.getenv("IS_TEST", "False") != "True":
|
|
||||||
scheduler.shutdown()
|
|
||||||
logger.info("Scheduled jobs stopped")
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from app.core.auth import (
|
|||||||
TokenExpiredError,
|
TokenExpiredError,
|
||||||
TokenInvalidError
|
TokenInvalidError
|
||||||
)
|
)
|
||||||
|
from app.core.config import settings
|
||||||
from app.core.exceptions import AuthenticationError
|
from app.core.exceptions import AuthenticationError
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.users import Token, UserCreate, UserResponse
|
from app.schemas.users import Token, UserCreate, UserResponse
|
||||||
@@ -140,7 +141,7 @@ class AuthService:
|
|||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
refresh_token=refresh_token,
|
refresh_token=refresh_token,
|
||||||
user=user_response,
|
user=user_response,
|
||||||
expires_in=900 # 15 minutes in seconds (matching ACCESS_TOKEN_EXPIRE_MINUTES)
|
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60 # Convert minutes to seconds
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user