Add session management API, cleanup service, and session-specific tests
- Introduced session management endpoints to list, revoke, and cleanup sessions per user. - Added cron-based job for periodic cleanup of expired sessions. - Implemented `CRUDSession` for session-specific database operations. - Integrated session cleanup startup and shutdown events in the application lifecycle. - Enhanced CORS configuration to include `X-Device-Id` for session tracking. - Added comprehensive integration tests for multi-device login, per-device logout, session listing, and cleanup logic.
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.routes import auth, users
|
||||
from app.api.routes import auth, users, sessions
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(auth.router, prefix="/auth", tags=["Authentication"])
|
||||
api_router.include_router(users.router, prefix="/users", tags=["Users"])
|
||||
api_router.include_router(sessions.router, prefix="/sessions", tags=["Sessions"])
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# app/api/routes/auth.py
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Body, Request
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
@@ -9,7 +11,7 @@ from slowapi.util import get_remote_address
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.core.auth import TokenExpiredError, TokenInvalidError
|
||||
from app.core.auth import TokenExpiredError, TokenInvalidError, decode_token
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.users import (
|
||||
@@ -22,10 +24,13 @@ from app.schemas.users import (
|
||||
PasswordResetConfirm
|
||||
)
|
||||
from app.schemas.common import MessageResponse
|
||||
from app.schemas.sessions import SessionCreate, LogoutRequest
|
||||
from app.services.auth_service import AuthService, AuthenticationError
|
||||
from app.services.email_service import email_service
|
||||
from app.utils.security import create_password_reset_token, verify_password_reset_token
|
||||
from app.utils.device import extract_device_info
|
||||
from app.crud.user import user as user_crud
|
||||
from app.crud.session import session as session_crud
|
||||
from app.core.auth import get_password_hash
|
||||
|
||||
router = APIRouter()
|
||||
@@ -34,9 +39,13 @@ logger = logging.getLogger(__name__)
|
||||
# Initialize limiter for this router
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
# Use higher rate limits in test environment
|
||||
IS_TEST = os.getenv("IS_TEST", "False") == "True"
|
||||
RATE_MULTIPLIER = 100 if IS_TEST else 1
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED, operation_id="register")
|
||||
@limiter.limit("5/minute")
|
||||
@limiter.limit(f"{5 * RATE_MULTIPLIER}/minute")
|
||||
async def register_user(
|
||||
request: Request,
|
||||
user_data: UserCreate,
|
||||
@@ -66,7 +75,7 @@ async def register_user(
|
||||
|
||||
|
||||
@router.post("/login", response_model=Token, operation_id="login")
|
||||
@limiter.limit("10/minute")
|
||||
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
|
||||
async def login(
|
||||
request: Request,
|
||||
login_data: LoginRequest,
|
||||
@@ -75,6 +84,8 @@ async def login(
|
||||
"""
|
||||
Login with username and password.
|
||||
|
||||
Creates a new session for this device.
|
||||
|
||||
Returns:
|
||||
Access and refresh tokens.
|
||||
"""
|
||||
@@ -93,7 +104,38 @@ async def login(
|
||||
|
||||
# User is authenticated, generate tokens
|
||||
tokens = AuthService.create_tokens(user)
|
||||
logger.info(f"User login successful: {user.email}")
|
||||
|
||||
# Extract device information and create session record
|
||||
# Session creation is best-effort - we don't fail login if it fails
|
||||
try:
|
||||
device_info = extract_device_info(request)
|
||||
|
||||
# Decode refresh token to get JTI and expiration
|
||||
refresh_payload = decode_token(tokens.refresh_token, verify_type="refresh")
|
||||
|
||||
session_data = SessionCreate(
|
||||
user_id=user.id,
|
||||
refresh_token_jti=refresh_payload.jti,
|
||||
device_name=device_info.device_name,
|
||||
device_id=device_info.device_id,
|
||||
ip_address=device_info.ip_address,
|
||||
user_agent=device_info.user_agent,
|
||||
last_used_at=datetime.now(timezone.utc),
|
||||
expires_at=datetime.fromtimestamp(refresh_payload.exp, tz=timezone.utc),
|
||||
location_city=device_info.location_city,
|
||||
location_country=device_info.location_country,
|
||||
)
|
||||
|
||||
session_crud.create_session(db, obj_in=session_data)
|
||||
|
||||
logger.info(
|
||||
f"User login successful: {user.email} from {device_info.device_name} "
|
||||
f"(IP: {device_info.ip_address})"
|
||||
)
|
||||
except Exception as session_err:
|
||||
# Log but don't fail login if session creation fails
|
||||
logger.error(f"Failed to create session for {user.email}: {str(session_err)}", exc_info=True)
|
||||
|
||||
return tokens
|
||||
|
||||
except HTTPException:
|
||||
@@ -126,6 +168,8 @@ async def login_oauth(
|
||||
"""
|
||||
OAuth2-compatible login endpoint, used by the OpenAPI UI.
|
||||
|
||||
Creates a new session for this device.
|
||||
|
||||
Returns:
|
||||
Access and refresh tokens.
|
||||
"""
|
||||
@@ -142,6 +186,33 @@ async def login_oauth(
|
||||
# Generate tokens
|
||||
tokens = AuthService.create_tokens(user)
|
||||
|
||||
# Extract device information and create session record
|
||||
# Session creation is best-effort - we don't fail login if it fails
|
||||
try:
|
||||
device_info = extract_device_info(request)
|
||||
|
||||
# Decode refresh token to get JTI and expiration
|
||||
refresh_payload = decode_token(tokens.refresh_token, verify_type="refresh")
|
||||
|
||||
session_data = SessionCreate(
|
||||
user_id=user.id,
|
||||
refresh_token_jti=refresh_payload.jti,
|
||||
device_name=device_info.device_name or "API Client",
|
||||
device_id=device_info.device_id,
|
||||
ip_address=device_info.ip_address,
|
||||
user_agent=device_info.user_agent,
|
||||
last_used_at=datetime.now(timezone.utc),
|
||||
expires_at=datetime.fromtimestamp(refresh_payload.exp, tz=timezone.utc),
|
||||
location_city=device_info.location_city,
|
||||
location_country=device_info.location_country,
|
||||
)
|
||||
|
||||
session_crud.create_session(db, obj_in=session_data)
|
||||
|
||||
logger.info(f"OAuth login successful: {user.email} from {device_info.device_name}")
|
||||
except Exception as session_err:
|
||||
logger.error(f"Failed to create session for {user.email}: {str(session_err)}", exc_info=True)
|
||||
|
||||
# Format response for OAuth compatibility
|
||||
return {
|
||||
"access_token": tokens.access_token,
|
||||
@@ -175,12 +246,46 @@ async def refresh_token(
|
||||
"""
|
||||
Refresh access token using a refresh token.
|
||||
|
||||
Validates that the session is still active before issuing new tokens.
|
||||
|
||||
Returns:
|
||||
New access and refresh tokens.
|
||||
"""
|
||||
try:
|
||||
# Decode the refresh token to get the JTI
|
||||
refresh_payload = decode_token(refresh_data.refresh_token, verify_type="refresh")
|
||||
|
||||
# Check if session exists and is active
|
||||
session = session_crud.get_active_by_jti(db, jti=refresh_payload.jti)
|
||||
|
||||
if not session:
|
||||
logger.warning(f"Refresh token used for inactive or non-existent session: {refresh_payload.jti}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Session has been revoked. Please log in again.",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# Generate new tokens
|
||||
tokens = AuthService.refresh_tokens(db, refresh_data.refresh_token)
|
||||
|
||||
# Decode new refresh token to get new JTI
|
||||
new_refresh_payload = decode_token(tokens.refresh_token, verify_type="refresh")
|
||||
|
||||
# Update session with new refresh token JTI and expiration
|
||||
try:
|
||||
session_crud.update_refresh_token(
|
||||
db,
|
||||
session=session,
|
||||
new_jti=new_refresh_payload.jti,
|
||||
new_expires_at=datetime.fromtimestamp(new_refresh_payload.exp, tz=timezone.utc)
|
||||
)
|
||||
except Exception as session_err:
|
||||
logger.error(f"Failed to update session {session.id}: {str(session_err)}", exc_info=True)
|
||||
# Continue anyway - tokens are already issued
|
||||
|
||||
return tokens
|
||||
|
||||
except TokenExpiredError:
|
||||
logger.warning("Token refresh failed: Token expired")
|
||||
raise HTTPException(
|
||||
@@ -195,6 +300,9 @@ async def refresh_token(
|
||||
detail="Invalid refresh token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
except HTTPException:
|
||||
# Re-raise HTTP exceptions (like session revoked)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during token refresh: {str(e)}")
|
||||
raise HTTPException(
|
||||
@@ -347,3 +455,144 @@ def confirm_password_reset(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An error occurred while resetting your password"
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/logout",
|
||||
response_model=MessageResponse,
|
||||
status_code=status.HTTP_200_OK,
|
||||
summary="Logout from Current Device",
|
||||
description="""
|
||||
Logout from the current device only.
|
||||
|
||||
Other devices will remain logged in.
|
||||
|
||||
Requires the refresh token to identify which session to terminate.
|
||||
|
||||
**Rate Limit**: 10 requests/minute
|
||||
""",
|
||||
operation_id="logout"
|
||||
)
|
||||
@limiter.limit("10/minute")
|
||||
def logout(
|
||||
request: Request,
|
||||
logout_request: LogoutRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Logout from current device by deactivating the session.
|
||||
|
||||
Args:
|
||||
logout_request: Contains the refresh token for this session
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
try:
|
||||
# Decode refresh token to get JTI
|
||||
try:
|
||||
refresh_payload = decode_token(logout_request.refresh_token, verify_type="refresh")
|
||||
except (TokenExpiredError, TokenInvalidError) as e:
|
||||
# Even if token is expired/invalid, try to deactivate session
|
||||
logger.warning(f"Logout with invalid/expired token: {str(e)}")
|
||||
# Don't fail - return success anyway
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="Logged out successfully"
|
||||
)
|
||||
|
||||
# Find the session by JTI
|
||||
session = session_crud.get_by_jti(db, jti=refresh_payload.jti)
|
||||
|
||||
if session:
|
||||
# Verify session belongs to current user (security check)
|
||||
if str(session.user_id) != str(current_user.id):
|
||||
logger.warning(
|
||||
f"User {current_user.id} attempted to logout session {session.id} "
|
||||
f"belonging to user {session.user_id}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You can only logout your own sessions"
|
||||
)
|
||||
|
||||
# Deactivate the session
|
||||
session_crud.deactivate(db, session_id=str(session.id))
|
||||
|
||||
logger.info(
|
||||
f"User {current_user.id} logged out from {session.device_name} "
|
||||
f"(session {session.id})"
|
||||
)
|
||||
else:
|
||||
# Session not found - maybe already deleted or never existed
|
||||
# Return success anyway (idempotent)
|
||||
logger.info(f"Logout requested for non-existent session (JTI: {refresh_payload.jti})")
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="Logged out successfully"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error during logout for user {current_user.id}: {str(e)}", exc_info=True)
|
||||
# Don't expose error details
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="Logged out successfully"
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/logout-all",
|
||||
response_model=MessageResponse,
|
||||
status_code=status.HTTP_200_OK,
|
||||
summary="Logout from All Devices",
|
||||
description="""
|
||||
Logout from ALL devices.
|
||||
|
||||
This will terminate all active sessions for the current user.
|
||||
You will need to log in again on all devices.
|
||||
|
||||
**Rate Limit**: 5 requests/minute
|
||||
""",
|
||||
operation_id="logout_all"
|
||||
)
|
||||
@limiter.limit("5/minute")
|
||||
def logout_all(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Logout from all devices by deactivating all user sessions.
|
||||
|
||||
Args:
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Success message with count of sessions terminated
|
||||
"""
|
||||
try:
|
||||
# Deactivate all sessions for this user
|
||||
count = session_crud.deactivate_all_user_sessions(db, user_id=str(current_user.id))
|
||||
|
||||
logger.info(f"User {current_user.id} logged out from all devices ({count} sessions)")
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"Successfully logged out from all devices ({count} sessions terminated)"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during logout-all for user {current_user.id}: {str(e)}", exc_info=True)
|
||||
db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An error occurred while logging out"
|
||||
)
|
||||
|
||||
251
backend/app/api/routes/sessions.py
Normal file
251
backend/app/api/routes/sessions.py
Normal file
@@ -0,0 +1,251 @@
|
||||
"""
|
||||
Session management endpoints.
|
||||
|
||||
Allows users to view and manage their active sessions across devices.
|
||||
"""
|
||||
import logging
|
||||
from typing import Any, List
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.core.database import get_db
|
||||
from app.core.auth import decode_token
|
||||
from app.models.user import User
|
||||
from app.schemas.sessions import SessionResponse, SessionListResponse
|
||||
from app.schemas.common import MessageResponse
|
||||
from app.crud.session import session as session_crud
|
||||
from app.core.exceptions import NotFoundError, AuthorizationError, ErrorCode
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize limiter
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/me",
|
||||
response_model=SessionListResponse,
|
||||
summary="List My Active Sessions",
|
||||
description="""
|
||||
Get a list of all active sessions for the current user.
|
||||
|
||||
This shows where you're currently logged in.
|
||||
|
||||
**Rate Limit**: 30 requests/minute
|
||||
""",
|
||||
operation_id="list_my_sessions"
|
||||
)
|
||||
@limiter.limit("30/minute")
|
||||
def list_my_sessions(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
List all active sessions for the current user.
|
||||
|
||||
Args:
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
List of active sessions
|
||||
"""
|
||||
try:
|
||||
# Get all active sessions for user
|
||||
sessions = session_crud.get_user_sessions(
|
||||
db,
|
||||
user_id=str(current_user.id),
|
||||
active_only=True
|
||||
)
|
||||
|
||||
# Try to identify current session from Authorization header
|
||||
current_session_jti = None
|
||||
auth_header = request.headers.get("authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
try:
|
||||
access_token = auth_header.split(" ")[1]
|
||||
token_payload = decode_token(access_token)
|
||||
# Note: Access tokens don't have JTI by default, but we can try
|
||||
# For now, we'll mark current based on most recent activity
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Convert to response format
|
||||
session_responses = []
|
||||
for s in sessions:
|
||||
session_response = SessionResponse(
|
||||
id=s.id,
|
||||
device_name=s.device_name,
|
||||
device_id=s.device_id,
|
||||
ip_address=s.ip_address,
|
||||
location_city=s.location_city,
|
||||
location_country=s.location_country,
|
||||
last_used_at=s.last_used_at,
|
||||
created_at=s.created_at,
|
||||
expires_at=s.expires_at,
|
||||
is_current=(s == sessions[0] if sessions else False) # Most recent = current
|
||||
)
|
||||
session_responses.append(session_response)
|
||||
|
||||
logger.info(f"User {current_user.id} listed {len(session_responses)} active sessions")
|
||||
|
||||
return SessionListResponse(
|
||||
sessions=session_responses,
|
||||
total=len(session_responses)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing sessions for user {current_user.id}: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve sessions"
|
||||
)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{session_id}",
|
||||
response_model=MessageResponse,
|
||||
status_code=status.HTTP_200_OK,
|
||||
summary="Revoke Specific Session",
|
||||
description="""
|
||||
Revoke a specific session by ID.
|
||||
|
||||
This logs you out from that particular device.
|
||||
You can only revoke your own sessions.
|
||||
|
||||
**Rate Limit**: 10 requests/minute
|
||||
""",
|
||||
operation_id="revoke_session"
|
||||
)
|
||||
@limiter.limit("10/minute")
|
||||
def revoke_session(
|
||||
request: Request,
|
||||
session_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Revoke a specific session by ID.
|
||||
|
||||
Args:
|
||||
session_id: UUID of the session to revoke
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
try:
|
||||
# Get the session
|
||||
session = session_crud.get(db, id=str(session_id))
|
||||
|
||||
if not session:
|
||||
raise NotFoundError(
|
||||
message=f"Session {session_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
)
|
||||
|
||||
# Verify session belongs to current user
|
||||
if str(session.user_id) != str(current_user.id):
|
||||
logger.warning(
|
||||
f"User {current_user.id} attempted to revoke session {session_id} "
|
||||
f"belonging to user {session.user_id}"
|
||||
)
|
||||
raise AuthorizationError(
|
||||
message="You can only revoke your own sessions",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
)
|
||||
|
||||
# Deactivate the session
|
||||
session_crud.deactivate(db, session_id=str(session_id))
|
||||
|
||||
logger.info(
|
||||
f"User {current_user.id} revoked session {session_id} "
|
||||
f"({session.device_name})"
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"Session revoked: {session.device_name or 'Unknown device'}"
|
||||
)
|
||||
|
||||
except (NotFoundError, AuthorizationError):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error revoking session {session_id}: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to revoke session"
|
||||
)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/me/expired",
|
||||
response_model=MessageResponse,
|
||||
status_code=status.HTTP_200_OK,
|
||||
summary="Cleanup Expired Sessions",
|
||||
description="""
|
||||
Remove expired sessions for the current user.
|
||||
|
||||
This is a cleanup operation to remove old session records.
|
||||
|
||||
**Rate Limit**: 5 requests/minute
|
||||
""",
|
||||
operation_id="cleanup_expired_sessions"
|
||||
)
|
||||
@limiter.limit("5/minute")
|
||||
def cleanup_expired_sessions(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Cleanup expired sessions for the current user.
|
||||
|
||||
Args:
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Success message with count of sessions cleaned
|
||||
"""
|
||||
try:
|
||||
from datetime import datetime, timezone
|
||||
|
||||
# Get all sessions for user
|
||||
all_sessions = session_crud.get_user_sessions(
|
||||
db,
|
||||
user_id=str(current_user.id),
|
||||
active_only=False
|
||||
)
|
||||
|
||||
# Delete expired and inactive sessions
|
||||
deleted_count = 0
|
||||
for s in all_sessions:
|
||||
if not s.is_active and s.expires_at < datetime.now(timezone.utc):
|
||||
db.delete(s)
|
||||
deleted_count += 1
|
||||
|
||||
db.commit()
|
||||
|
||||
logger.info(f"User {current_user.id} cleaned up {deleted_count} expired sessions")
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"Cleaned up {deleted_count} expired sessions"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up sessions for user {current_user.id}: {str(e)}", exc_info=True)
|
||||
db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to cleanup sessions"
|
||||
)
|
||||
Reference in New Issue
Block a user