Compare commits
2 Commits
e767920407
...
e19026453f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e19026453f | ||
|
|
b42a29faad |
@@ -0,0 +1,102 @@
|
||||
"""add_user_sessions_table
|
||||
|
||||
Revision ID: 549b50ea888d
|
||||
Revises: b76c725fc3cf
|
||||
Create Date: 2025-10-31 07:41:18.729544
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '549b50ea888d'
|
||||
down_revision: Union[str, None] = 'b76c725fc3cf'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create user_sessions table for per-device session management
|
||||
op.create_table(
|
||||
'user_sessions',
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
||||
sa.Column('refresh_token_jti', sa.String(length=255), nullable=False),
|
||||
sa.Column('device_name', sa.String(length=255), nullable=True),
|
||||
sa.Column('device_id', sa.String(length=255), nullable=True),
|
||||
sa.Column('ip_address', sa.String(length=45), nullable=True),
|
||||
sa.Column('user_agent', sa.String(length=500), nullable=True),
|
||||
sa.Column('last_used_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'),
|
||||
sa.Column('location_city', sa.String(length=100), nullable=True),
|
||||
sa.Column('location_country', sa.String(length=100), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
|
||||
# Create foreign key to users table
|
||||
op.create_foreign_key(
|
||||
'fk_user_sessions_user_id',
|
||||
'user_sessions',
|
||||
'users',
|
||||
['user_id'],
|
||||
['id'],
|
||||
ondelete='CASCADE'
|
||||
)
|
||||
|
||||
# Create indexes for performance
|
||||
# 1. Lookup session by refresh token JTI (most common query)
|
||||
op.create_index(
|
||||
'ix_user_sessions_jti',
|
||||
'user_sessions',
|
||||
['refresh_token_jti'],
|
||||
unique=True
|
||||
)
|
||||
|
||||
# 2. Lookup sessions by user ID
|
||||
op.create_index(
|
||||
'ix_user_sessions_user_id',
|
||||
'user_sessions',
|
||||
['user_id']
|
||||
)
|
||||
|
||||
# 3. Composite index for active sessions by user
|
||||
op.create_index(
|
||||
'ix_user_sessions_user_active',
|
||||
'user_sessions',
|
||||
['user_id', 'is_active']
|
||||
)
|
||||
|
||||
# 4. Index on expires_at for cleanup job
|
||||
op.create_index(
|
||||
'ix_user_sessions_expires_at',
|
||||
'user_sessions',
|
||||
['expires_at']
|
||||
)
|
||||
|
||||
# 5. Composite index for active session lookup by JTI
|
||||
op.create_index(
|
||||
'ix_user_sessions_jti_active',
|
||||
'user_sessions',
|
||||
['refresh_token_jti', 'is_active']
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop indexes first
|
||||
op.drop_index('ix_user_sessions_jti_active', table_name='user_sessions')
|
||||
op.drop_index('ix_user_sessions_expires_at', table_name='user_sessions')
|
||||
op.drop_index('ix_user_sessions_user_active', table_name='user_sessions')
|
||||
op.drop_index('ix_user_sessions_user_id', table_name='user_sessions')
|
||||
op.drop_index('ix_user_sessions_jti', table_name='user_sessions')
|
||||
|
||||
# Drop foreign key
|
||||
op.drop_constraint('fk_user_sessions_user_id', 'user_sessions', type_='foreignkey')
|
||||
|
||||
# Drop table
|
||||
op.drop_table('user_sessions')
|
||||
@@ -1,7 +1,8 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.routes import auth, users
|
||||
from app.api.routes import auth, users, sessions
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(auth.router, prefix="/auth", tags=["Authentication"])
|
||||
api_router.include_router(users.router, prefix="/users", tags=["Users"])
|
||||
api_router.include_router(sessions.router, prefix="/sessions", tags=["Sessions"])
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# app/api/routes/auth.py
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Body, Request
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
@@ -9,7 +11,7 @@ from slowapi.util import get_remote_address
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.core.auth import TokenExpiredError, TokenInvalidError
|
||||
from app.core.auth import TokenExpiredError, TokenInvalidError, decode_token
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.users import (
|
||||
@@ -22,10 +24,13 @@ from app.schemas.users import (
|
||||
PasswordResetConfirm
|
||||
)
|
||||
from app.schemas.common import MessageResponse
|
||||
from app.schemas.sessions import SessionCreate, LogoutRequest
|
||||
from app.services.auth_service import AuthService, AuthenticationError
|
||||
from app.services.email_service import email_service
|
||||
from app.utils.security import create_password_reset_token, verify_password_reset_token
|
||||
from app.utils.device import extract_device_info
|
||||
from app.crud.user import user as user_crud
|
||||
from app.crud.session import session as session_crud
|
||||
from app.core.auth import get_password_hash
|
||||
|
||||
router = APIRouter()
|
||||
@@ -34,9 +39,13 @@ logger = logging.getLogger(__name__)
|
||||
# Initialize limiter for this router
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
# Use higher rate limits in test environment
|
||||
IS_TEST = os.getenv("IS_TEST", "False") == "True"
|
||||
RATE_MULTIPLIER = 100 if IS_TEST else 1
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED, operation_id="register")
|
||||
@limiter.limit("5/minute")
|
||||
@limiter.limit(f"{5 * RATE_MULTIPLIER}/minute")
|
||||
async def register_user(
|
||||
request: Request,
|
||||
user_data: UserCreate,
|
||||
@@ -66,7 +75,7 @@ async def register_user(
|
||||
|
||||
|
||||
@router.post("/login", response_model=Token, operation_id="login")
|
||||
@limiter.limit("10/minute")
|
||||
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
|
||||
async def login(
|
||||
request: Request,
|
||||
login_data: LoginRequest,
|
||||
@@ -75,6 +84,8 @@ async def login(
|
||||
"""
|
||||
Login with username and password.
|
||||
|
||||
Creates a new session for this device.
|
||||
|
||||
Returns:
|
||||
Access and refresh tokens.
|
||||
"""
|
||||
@@ -93,7 +104,38 @@ async def login(
|
||||
|
||||
# User is authenticated, generate tokens
|
||||
tokens = AuthService.create_tokens(user)
|
||||
logger.info(f"User login successful: {user.email}")
|
||||
|
||||
# Extract device information and create session record
|
||||
# Session creation is best-effort - we don't fail login if it fails
|
||||
try:
|
||||
device_info = extract_device_info(request)
|
||||
|
||||
# Decode refresh token to get JTI and expiration
|
||||
refresh_payload = decode_token(tokens.refresh_token, verify_type="refresh")
|
||||
|
||||
session_data = SessionCreate(
|
||||
user_id=user.id,
|
||||
refresh_token_jti=refresh_payload.jti,
|
||||
device_name=device_info.device_name,
|
||||
device_id=device_info.device_id,
|
||||
ip_address=device_info.ip_address,
|
||||
user_agent=device_info.user_agent,
|
||||
last_used_at=datetime.now(timezone.utc),
|
||||
expires_at=datetime.fromtimestamp(refresh_payload.exp, tz=timezone.utc),
|
||||
location_city=device_info.location_city,
|
||||
location_country=device_info.location_country,
|
||||
)
|
||||
|
||||
session_crud.create_session(db, obj_in=session_data)
|
||||
|
||||
logger.info(
|
||||
f"User login successful: {user.email} from {device_info.device_name} "
|
||||
f"(IP: {device_info.ip_address})"
|
||||
)
|
||||
except Exception as session_err:
|
||||
# Log but don't fail login if session creation fails
|
||||
logger.error(f"Failed to create session for {user.email}: {str(session_err)}", exc_info=True)
|
||||
|
||||
return tokens
|
||||
|
||||
except HTTPException:
|
||||
@@ -126,6 +168,8 @@ async def login_oauth(
|
||||
"""
|
||||
OAuth2-compatible login endpoint, used by the OpenAPI UI.
|
||||
|
||||
Creates a new session for this device.
|
||||
|
||||
Returns:
|
||||
Access and refresh tokens.
|
||||
"""
|
||||
@@ -142,6 +186,33 @@ async def login_oauth(
|
||||
# Generate tokens
|
||||
tokens = AuthService.create_tokens(user)
|
||||
|
||||
# Extract device information and create session record
|
||||
# Session creation is best-effort - we don't fail login if it fails
|
||||
try:
|
||||
device_info = extract_device_info(request)
|
||||
|
||||
# Decode refresh token to get JTI and expiration
|
||||
refresh_payload = decode_token(tokens.refresh_token, verify_type="refresh")
|
||||
|
||||
session_data = SessionCreate(
|
||||
user_id=user.id,
|
||||
refresh_token_jti=refresh_payload.jti,
|
||||
device_name=device_info.device_name or "API Client",
|
||||
device_id=device_info.device_id,
|
||||
ip_address=device_info.ip_address,
|
||||
user_agent=device_info.user_agent,
|
||||
last_used_at=datetime.now(timezone.utc),
|
||||
expires_at=datetime.fromtimestamp(refresh_payload.exp, tz=timezone.utc),
|
||||
location_city=device_info.location_city,
|
||||
location_country=device_info.location_country,
|
||||
)
|
||||
|
||||
session_crud.create_session(db, obj_in=session_data)
|
||||
|
||||
logger.info(f"OAuth login successful: {user.email} from {device_info.device_name}")
|
||||
except Exception as session_err:
|
||||
logger.error(f"Failed to create session for {user.email}: {str(session_err)}", exc_info=True)
|
||||
|
||||
# Format response for OAuth compatibility
|
||||
return {
|
||||
"access_token": tokens.access_token,
|
||||
@@ -175,12 +246,46 @@ async def refresh_token(
|
||||
"""
|
||||
Refresh access token using a refresh token.
|
||||
|
||||
Validates that the session is still active before issuing new tokens.
|
||||
|
||||
Returns:
|
||||
New access and refresh tokens.
|
||||
"""
|
||||
try:
|
||||
# Decode the refresh token to get the JTI
|
||||
refresh_payload = decode_token(refresh_data.refresh_token, verify_type="refresh")
|
||||
|
||||
# Check if session exists and is active
|
||||
session = session_crud.get_active_by_jti(db, jti=refresh_payload.jti)
|
||||
|
||||
if not session:
|
||||
logger.warning(f"Refresh token used for inactive or non-existent session: {refresh_payload.jti}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Session has been revoked. Please log in again.",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# Generate new tokens
|
||||
tokens = AuthService.refresh_tokens(db, refresh_data.refresh_token)
|
||||
|
||||
# Decode new refresh token to get new JTI
|
||||
new_refresh_payload = decode_token(tokens.refresh_token, verify_type="refresh")
|
||||
|
||||
# Update session with new refresh token JTI and expiration
|
||||
try:
|
||||
session_crud.update_refresh_token(
|
||||
db,
|
||||
session=session,
|
||||
new_jti=new_refresh_payload.jti,
|
||||
new_expires_at=datetime.fromtimestamp(new_refresh_payload.exp, tz=timezone.utc)
|
||||
)
|
||||
except Exception as session_err:
|
||||
logger.error(f"Failed to update session {session.id}: {str(session_err)}", exc_info=True)
|
||||
# Continue anyway - tokens are already issued
|
||||
|
||||
return tokens
|
||||
|
||||
except TokenExpiredError:
|
||||
logger.warning("Token refresh failed: Token expired")
|
||||
raise HTTPException(
|
||||
@@ -195,6 +300,9 @@ async def refresh_token(
|
||||
detail="Invalid refresh token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
except HTTPException:
|
||||
# Re-raise HTTP exceptions (like session revoked)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during token refresh: {str(e)}")
|
||||
raise HTTPException(
|
||||
@@ -347,3 +455,144 @@ def confirm_password_reset(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An error occurred while resetting your password"
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/logout",
|
||||
response_model=MessageResponse,
|
||||
status_code=status.HTTP_200_OK,
|
||||
summary="Logout from Current Device",
|
||||
description="""
|
||||
Logout from the current device only.
|
||||
|
||||
Other devices will remain logged in.
|
||||
|
||||
Requires the refresh token to identify which session to terminate.
|
||||
|
||||
**Rate Limit**: 10 requests/minute
|
||||
""",
|
||||
operation_id="logout"
|
||||
)
|
||||
@limiter.limit("10/minute")
|
||||
def logout(
|
||||
request: Request,
|
||||
logout_request: LogoutRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Logout from current device by deactivating the session.
|
||||
|
||||
Args:
|
||||
logout_request: Contains the refresh token for this session
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
try:
|
||||
# Decode refresh token to get JTI
|
||||
try:
|
||||
refresh_payload = decode_token(logout_request.refresh_token, verify_type="refresh")
|
||||
except (TokenExpiredError, TokenInvalidError) as e:
|
||||
# Even if token is expired/invalid, try to deactivate session
|
||||
logger.warning(f"Logout with invalid/expired token: {str(e)}")
|
||||
# Don't fail - return success anyway
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="Logged out successfully"
|
||||
)
|
||||
|
||||
# Find the session by JTI
|
||||
session = session_crud.get_by_jti(db, jti=refresh_payload.jti)
|
||||
|
||||
if session:
|
||||
# Verify session belongs to current user (security check)
|
||||
if str(session.user_id) != str(current_user.id):
|
||||
logger.warning(
|
||||
f"User {current_user.id} attempted to logout session {session.id} "
|
||||
f"belonging to user {session.user_id}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You can only logout your own sessions"
|
||||
)
|
||||
|
||||
# Deactivate the session
|
||||
session_crud.deactivate(db, session_id=str(session.id))
|
||||
|
||||
logger.info(
|
||||
f"User {current_user.id} logged out from {session.device_name} "
|
||||
f"(session {session.id})"
|
||||
)
|
||||
else:
|
||||
# Session not found - maybe already deleted or never existed
|
||||
# Return success anyway (idempotent)
|
||||
logger.info(f"Logout requested for non-existent session (JTI: {refresh_payload.jti})")
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="Logged out successfully"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error during logout for user {current_user.id}: {str(e)}", exc_info=True)
|
||||
# Don't expose error details
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="Logged out successfully"
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/logout-all",
|
||||
response_model=MessageResponse,
|
||||
status_code=status.HTTP_200_OK,
|
||||
summary="Logout from All Devices",
|
||||
description="""
|
||||
Logout from ALL devices.
|
||||
|
||||
This will terminate all active sessions for the current user.
|
||||
You will need to log in again on all devices.
|
||||
|
||||
**Rate Limit**: 5 requests/minute
|
||||
""",
|
||||
operation_id="logout_all"
|
||||
)
|
||||
@limiter.limit("5/minute")
|
||||
def logout_all(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Logout from all devices by deactivating all user sessions.
|
||||
|
||||
Args:
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Success message with count of sessions terminated
|
||||
"""
|
||||
try:
|
||||
# Deactivate all sessions for this user
|
||||
count = session_crud.deactivate_all_user_sessions(db, user_id=str(current_user.id))
|
||||
|
||||
logger.info(f"User {current_user.id} logged out from all devices ({count} sessions)")
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"Successfully logged out from all devices ({count} sessions terminated)"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during logout-all for user {current_user.id}: {str(e)}", exc_info=True)
|
||||
db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An error occurred while logging out"
|
||||
)
|
||||
|
||||
251
backend/app/api/routes/sessions.py
Normal file
251
backend/app/api/routes/sessions.py
Normal file
@@ -0,0 +1,251 @@
|
||||
"""
|
||||
Session management endpoints.
|
||||
|
||||
Allows users to view and manage their active sessions across devices.
|
||||
"""
|
||||
import logging
|
||||
from typing import Any, List
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.core.database import get_db
|
||||
from app.core.auth import decode_token
|
||||
from app.models.user import User
|
||||
from app.schemas.sessions import SessionResponse, SessionListResponse
|
||||
from app.schemas.common import MessageResponse
|
||||
from app.crud.session import session as session_crud
|
||||
from app.core.exceptions import NotFoundError, AuthorizationError, ErrorCode
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize limiter
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/me",
|
||||
response_model=SessionListResponse,
|
||||
summary="List My Active Sessions",
|
||||
description="""
|
||||
Get a list of all active sessions for the current user.
|
||||
|
||||
This shows where you're currently logged in.
|
||||
|
||||
**Rate Limit**: 30 requests/minute
|
||||
""",
|
||||
operation_id="list_my_sessions"
|
||||
)
|
||||
@limiter.limit("30/minute")
|
||||
def list_my_sessions(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
List all active sessions for the current user.
|
||||
|
||||
Args:
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
List of active sessions
|
||||
"""
|
||||
try:
|
||||
# Get all active sessions for user
|
||||
sessions = session_crud.get_user_sessions(
|
||||
db,
|
||||
user_id=str(current_user.id),
|
||||
active_only=True
|
||||
)
|
||||
|
||||
# Try to identify current session from Authorization header
|
||||
current_session_jti = None
|
||||
auth_header = request.headers.get("authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
try:
|
||||
access_token = auth_header.split(" ")[1]
|
||||
token_payload = decode_token(access_token)
|
||||
# Note: Access tokens don't have JTI by default, but we can try
|
||||
# For now, we'll mark current based on most recent activity
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Convert to response format
|
||||
session_responses = []
|
||||
for s in sessions:
|
||||
session_response = SessionResponse(
|
||||
id=s.id,
|
||||
device_name=s.device_name,
|
||||
device_id=s.device_id,
|
||||
ip_address=s.ip_address,
|
||||
location_city=s.location_city,
|
||||
location_country=s.location_country,
|
||||
last_used_at=s.last_used_at,
|
||||
created_at=s.created_at,
|
||||
expires_at=s.expires_at,
|
||||
is_current=(s == sessions[0] if sessions else False) # Most recent = current
|
||||
)
|
||||
session_responses.append(session_response)
|
||||
|
||||
logger.info(f"User {current_user.id} listed {len(session_responses)} active sessions")
|
||||
|
||||
return SessionListResponse(
|
||||
sessions=session_responses,
|
||||
total=len(session_responses)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing sessions for user {current_user.id}: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve sessions"
|
||||
)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{session_id}",
|
||||
response_model=MessageResponse,
|
||||
status_code=status.HTTP_200_OK,
|
||||
summary="Revoke Specific Session",
|
||||
description="""
|
||||
Revoke a specific session by ID.
|
||||
|
||||
This logs you out from that particular device.
|
||||
You can only revoke your own sessions.
|
||||
|
||||
**Rate Limit**: 10 requests/minute
|
||||
""",
|
||||
operation_id="revoke_session"
|
||||
)
|
||||
@limiter.limit("10/minute")
|
||||
def revoke_session(
|
||||
request: Request,
|
||||
session_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Revoke a specific session by ID.
|
||||
|
||||
Args:
|
||||
session_id: UUID of the session to revoke
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
try:
|
||||
# Get the session
|
||||
session = session_crud.get(db, id=str(session_id))
|
||||
|
||||
if not session:
|
||||
raise NotFoundError(
|
||||
message=f"Session {session_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
)
|
||||
|
||||
# Verify session belongs to current user
|
||||
if str(session.user_id) != str(current_user.id):
|
||||
logger.warning(
|
||||
f"User {current_user.id} attempted to revoke session {session_id} "
|
||||
f"belonging to user {session.user_id}"
|
||||
)
|
||||
raise AuthorizationError(
|
||||
message="You can only revoke your own sessions",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
)
|
||||
|
||||
# Deactivate the session
|
||||
session_crud.deactivate(db, session_id=str(session_id))
|
||||
|
||||
logger.info(
|
||||
f"User {current_user.id} revoked session {session_id} "
|
||||
f"({session.device_name})"
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"Session revoked: {session.device_name or 'Unknown device'}"
|
||||
)
|
||||
|
||||
except (NotFoundError, AuthorizationError):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error revoking session {session_id}: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to revoke session"
|
||||
)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/me/expired",
|
||||
response_model=MessageResponse,
|
||||
status_code=status.HTTP_200_OK,
|
||||
summary="Cleanup Expired Sessions",
|
||||
description="""
|
||||
Remove expired sessions for the current user.
|
||||
|
||||
This is a cleanup operation to remove old session records.
|
||||
|
||||
**Rate Limit**: 5 requests/minute
|
||||
""",
|
||||
operation_id="cleanup_expired_sessions"
|
||||
)
|
||||
@limiter.limit("5/minute")
|
||||
def cleanup_expired_sessions(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Cleanup expired sessions for the current user.
|
||||
|
||||
Args:
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Success message with count of sessions cleaned
|
||||
"""
|
||||
try:
|
||||
from datetime import datetime, timezone
|
||||
|
||||
# Get all sessions for user
|
||||
all_sessions = session_crud.get_user_sessions(
|
||||
db,
|
||||
user_id=str(current_user.id),
|
||||
active_only=False
|
||||
)
|
||||
|
||||
# Delete expired and inactive sessions
|
||||
deleted_count = 0
|
||||
for s in all_sessions:
|
||||
if not s.is_active and s.expires_at < datetime.now(timezone.utc):
|
||||
db.delete(s)
|
||||
deleted_count += 1
|
||||
|
||||
db.commit()
|
||||
|
||||
logger.info(f"User {current_user.id} cleaned up {deleted_count} expired sessions")
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"Cleaned up {deleted_count} expired sessions"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up sessions for user {current_user.id}: {str(e)}", exc_info=True)
|
||||
db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to cleanup sessions"
|
||||
)
|
||||
339
backend/app/crud/session.py
Normal file
339
backend/app/crud/session.py
Normal file
@@ -0,0 +1,339 @@
|
||||
"""
|
||||
CRUD operations for user sessions.
|
||||
"""
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_
|
||||
import logging
|
||||
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.user_session import UserSession
|
||||
from app.schemas.sessions import SessionCreate, SessionUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
"""CRUD operations for user sessions."""
|
||||
|
||||
def get_by_jti(self, db: Session, *, jti: str) -> Optional[UserSession]:
|
||||
"""
|
||||
Get session by refresh token JTI.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
jti: Refresh token JWT ID
|
||||
|
||||
Returns:
|
||||
UserSession if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
return db.query(UserSession).filter(
|
||||
UserSession.refresh_token_jti == jti
|
||||
).first()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting session by JTI {jti}: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_active_by_jti(self, db: Session, *, jti: str) -> Optional[UserSession]:
|
||||
"""
|
||||
Get active session by refresh token JTI.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
jti: Refresh token JWT ID
|
||||
|
||||
Returns:
|
||||
Active UserSession if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
return db.query(UserSession).filter(
|
||||
and_(
|
||||
UserSession.refresh_token_jti == jti,
|
||||
UserSession.is_active == True
|
||||
)
|
||||
).first()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting active session by JTI {jti}: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_user_sessions(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
user_id: str,
|
||||
active_only: bool = True
|
||||
) -> List[UserSession]:
|
||||
"""
|
||||
Get all sessions for a user.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User ID
|
||||
active_only: If True, return only active sessions
|
||||
|
||||
Returns:
|
||||
List of UserSession objects
|
||||
"""
|
||||
try:
|
||||
# Convert user_id string to UUID if needed
|
||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
query = db.query(UserSession).filter(UserSession.user_id == user_uuid)
|
||||
|
||||
if active_only:
|
||||
query = query.filter(UserSession.is_active == True)
|
||||
|
||||
return query.order_by(UserSession.last_used_at.desc()).all()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting sessions for user {user_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def create_session(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
obj_in: SessionCreate
|
||||
) -> UserSession:
|
||||
"""
|
||||
Create a new user session.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
obj_in: SessionCreate schema with session data
|
||||
|
||||
Returns:
|
||||
Created UserSession
|
||||
|
||||
Raises:
|
||||
ValueError: If session creation fails
|
||||
"""
|
||||
try:
|
||||
db_obj = UserSession(
|
||||
user_id=obj_in.user_id,
|
||||
refresh_token_jti=obj_in.refresh_token_jti,
|
||||
device_name=obj_in.device_name,
|
||||
device_id=obj_in.device_id,
|
||||
ip_address=obj_in.ip_address,
|
||||
user_agent=obj_in.user_agent,
|
||||
last_used_at=obj_in.last_used_at,
|
||||
expires_at=obj_in.expires_at,
|
||||
is_active=True,
|
||||
location_city=obj_in.location_city,
|
||||
location_country=obj_in.location_country,
|
||||
)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
|
||||
logger.info(
|
||||
f"Session created for user {obj_in.user_id} from {obj_in.device_name} "
|
||||
f"(IP: {obj_in.ip_address})"
|
||||
)
|
||||
|
||||
return db_obj
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error creating session: {str(e)}", exc_info=True)
|
||||
raise ValueError(f"Failed to create session: {str(e)}")
|
||||
|
||||
def deactivate(self, db: Session, *, session_id: str) -> Optional[UserSession]:
|
||||
"""
|
||||
Deactivate a session (logout from device).
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
session_id: Session UUID
|
||||
|
||||
Returns:
|
||||
Deactivated UserSession if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
session = self.get(db, id=session_id)
|
||||
if not session:
|
||||
logger.warning(f"Session {session_id} not found for deactivation")
|
||||
return None
|
||||
|
||||
session.is_active = False
|
||||
db.add(session)
|
||||
db.commit()
|
||||
db.refresh(session)
|
||||
|
||||
logger.info(
|
||||
f"Session {session_id} deactivated for user {session.user_id} "
|
||||
f"({session.device_name})"
|
||||
)
|
||||
|
||||
return session
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error deactivating session {session_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def deactivate_all_user_sessions(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
user_id: str
|
||||
) -> int:
|
||||
"""
|
||||
Deactivate all active sessions for a user (logout from all devices).
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
Number of sessions deactivated
|
||||
"""
|
||||
try:
|
||||
# Convert user_id string to UUID if needed
|
||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
count = db.query(UserSession).filter(
|
||||
and_(
|
||||
UserSession.user_id == user_uuid,
|
||||
UserSession.is_active == True
|
||||
)
|
||||
).update({"is_active": False})
|
||||
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Deactivated {count} sessions for user {user_id}")
|
||||
|
||||
return count
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error deactivating all sessions for user {user_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def update_last_used(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
session: UserSession
|
||||
) -> UserSession:
|
||||
"""
|
||||
Update the last_used_at timestamp for a session.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
session: UserSession object
|
||||
|
||||
Returns:
|
||||
Updated UserSession
|
||||
"""
|
||||
try:
|
||||
session.last_used_at = datetime.now(timezone.utc)
|
||||
db.add(session)
|
||||
db.commit()
|
||||
db.refresh(session)
|
||||
return session
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error updating last_used for session {session.id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def update_refresh_token(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
session: UserSession,
|
||||
new_jti: str,
|
||||
new_expires_at: datetime
|
||||
) -> UserSession:
|
||||
"""
|
||||
Update session with new refresh token JTI and expiration.
|
||||
|
||||
Called during token refresh.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
session: UserSession object
|
||||
new_jti: New refresh token JTI
|
||||
new_expires_at: New expiration datetime
|
||||
|
||||
Returns:
|
||||
Updated UserSession
|
||||
"""
|
||||
try:
|
||||
session.refresh_token_jti = new_jti
|
||||
session.expires_at = new_expires_at
|
||||
session.last_used_at = datetime.now(timezone.utc)
|
||||
db.add(session)
|
||||
db.commit()
|
||||
db.refresh(session)
|
||||
return session
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error updating refresh token for session {session.id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def cleanup_expired(self, db: Session, *, keep_days: int = 30) -> int:
|
||||
"""
|
||||
Clean up expired sessions.
|
||||
|
||||
Deletes sessions that are:
|
||||
- Expired AND inactive
|
||||
- Older than keep_days
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
keep_days: Keep inactive sessions for this many days (for audit)
|
||||
|
||||
Returns:
|
||||
Number of sessions deleted
|
||||
"""
|
||||
try:
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days)
|
||||
|
||||
# Delete sessions that are:
|
||||
# 1. Expired (expires_at < now) AND inactive
|
||||
# AND
|
||||
# 2. Older than keep_days
|
||||
count = db.query(UserSession).filter(
|
||||
and_(
|
||||
UserSession.is_active == False,
|
||||
UserSession.expires_at < datetime.now(timezone.utc),
|
||||
UserSession.created_at < cutoff_date
|
||||
)
|
||||
).delete()
|
||||
|
||||
db.commit()
|
||||
|
||||
if count > 0:
|
||||
logger.info(f"Cleaned up {count} expired sessions")
|
||||
|
||||
return count
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error cleaning up expired sessions: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_user_session_count(self, db: Session, *, user_id: str) -> int:
|
||||
"""
|
||||
Get count of active sessions for a user.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
Number of active sessions
|
||||
"""
|
||||
try:
|
||||
return db.query(UserSession).filter(
|
||||
and_(
|
||||
UserSession.user_id == user_id,
|
||||
UserSession.is_active == True
|
||||
)
|
||||
).count()
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting sessions for user {user_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
# Create singleton instance
|
||||
session = CRUDSession(UserSession)
|
||||
@@ -62,6 +62,7 @@ app.add_middleware(
|
||||
"DNT",
|
||||
"Cache-Control",
|
||||
"X-Requested-With",
|
||||
"X-Device-Id", # For session management
|
||||
], # Explicit headers only
|
||||
expose_headers=["Content-Length"],
|
||||
max_age=600, # Cache preflight requests for 10 minutes
|
||||
@@ -171,3 +172,48 @@ async def health_check() -> JSONResponse:
|
||||
|
||||
|
||||
app.include_router(api_router, prefix=settings.API_V1_STR)
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""
|
||||
Application startup event.
|
||||
|
||||
Sets up background jobs and scheduled tasks.
|
||||
"""
|
||||
import os
|
||||
|
||||
# Skip scheduler in test environment
|
||||
if os.getenv("IS_TEST", "False") == "True":
|
||||
logger.info("Test environment detected - skipping scheduler")
|
||||
return
|
||||
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
|
||||
# Schedule session cleanup job
|
||||
# Runs daily at 2:00 AM server time
|
||||
scheduler.add_job(
|
||||
cleanup_expired_sessions,
|
||||
'cron',
|
||||
hour=2,
|
||||
minute=0,
|
||||
id='cleanup_expired_sessions',
|
||||
replace_existing=True
|
||||
)
|
||||
|
||||
scheduler.start()
|
||||
logger.info("Scheduled jobs started: session cleanup (daily at 2 AM)")
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
"""
|
||||
Application shutdown event.
|
||||
|
||||
Cleans up resources and stops background jobs.
|
||||
"""
|
||||
import os
|
||||
|
||||
if os.getenv("IS_TEST", "False") != "True":
|
||||
scheduler.shutdown()
|
||||
logger.info("Scheduled jobs stopped")
|
||||
|
||||
@@ -6,9 +6,11 @@ Imports all models to ensure they're registered with SQLAlchemy.
|
||||
from app.core.database import Base
|
||||
from .base import TimestampMixin, UUIDMixin
|
||||
|
||||
# Import user model
|
||||
# Import models
|
||||
from .user import User
|
||||
from .user_session import UserSession
|
||||
|
||||
__all__ = [
|
||||
'Base', 'TimestampMixin', 'UUIDMixin',
|
||||
'User',
|
||||
'User', 'UserSession',
|
||||
]
|
||||
80
backend/app/models/user_session.py
Normal file
80
backend/app/models/user_session.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""
|
||||
User session model for tracking per-device authentication sessions.
|
||||
|
||||
This allows users to:
|
||||
- See where they're logged in
|
||||
- Logout from specific devices
|
||||
- Manage their active sessions
|
||||
"""
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, ForeignKey, Index
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from .base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class UserSession(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Tracks individual user sessions (per-device).
|
||||
|
||||
Each time a user logs in from a device, a new session is created.
|
||||
Sessions are identified by the refresh token JTI (JWT ID).
|
||||
"""
|
||||
__tablename__ = 'user_sessions'
|
||||
|
||||
# Foreign key to user
|
||||
user_id = Column(UUID(as_uuid=True), ForeignKey('users.id', ondelete='CASCADE'), nullable=False, index=True)
|
||||
|
||||
# Refresh token identifier (JWT ID from the refresh token)
|
||||
refresh_token_jti = Column(String(255), unique=True, nullable=False, index=True)
|
||||
|
||||
# Device information
|
||||
device_name = Column(String(255), nullable=True) # "iPhone 14", "Chrome on MacBook"
|
||||
device_id = Column(String(255), nullable=True) # Persistent device identifier (from client)
|
||||
ip_address = Column(String(45), nullable=True) # IPv4 (15 chars) or IPv6 (45 chars)
|
||||
user_agent = Column(String(500), nullable=True) # Browser/app user agent
|
||||
|
||||
# Session timing
|
||||
last_used_at = Column(DateTime(timezone=True), nullable=False)
|
||||
expires_at = Column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
# Session state
|
||||
is_active = Column(Boolean, default=True, nullable=False, index=True)
|
||||
|
||||
# Geographic information (optional, can be populated from IP)
|
||||
location_city = Column(String(100), nullable=True)
|
||||
location_country = Column(String(100), nullable=True)
|
||||
|
||||
# Relationship to user
|
||||
user = relationship("User", backref="sessions")
|
||||
|
||||
# Composite indexes for performance (defined in migration)
|
||||
__table_args__ = (
|
||||
Index('ix_user_sessions_user_active', 'user_id', 'is_active'),
|
||||
Index('ix_user_sessions_jti_active', 'refresh_token_jti', 'is_active'),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<UserSession {self.device_name} ({self.ip_address})>"
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if session has expired."""
|
||||
from datetime import datetime, timezone
|
||||
return self.expires_at < datetime.now(timezone.utc)
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert session to dictionary for serialization."""
|
||||
return {
|
||||
'id': str(self.id),
|
||||
'user_id': str(self.user_id),
|
||||
'device_name': self.device_name,
|
||||
'device_id': self.device_id,
|
||||
'ip_address': self.ip_address,
|
||||
'last_used_at': self.last_used_at.isoformat() if self.last_used_at else None,
|
||||
'expires_at': self.expires_at.isoformat() if self.expires_at else None,
|
||||
'is_active': self.is_active,
|
||||
'location_city': self.location_city,
|
||||
'location_country': self.location_country,
|
||||
'created_at': self.created_at.isoformat() if self.created_at else None,
|
||||
}
|
||||
133
backend/app/schemas/sessions.py
Normal file
133
backend/app/schemas/sessions.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
Pydantic schemas for user session management.
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
|
||||
|
||||
class SessionBase(BaseModel):
|
||||
"""Base schema for user sessions."""
|
||||
device_name: Optional[str] = Field(None, max_length=255, description="Friendly device name")
|
||||
device_id: Optional[str] = Field(None, max_length=255, description="Persistent device identifier")
|
||||
|
||||
|
||||
class SessionCreate(SessionBase):
|
||||
"""Schema for creating a new session (internal use)."""
|
||||
user_id: UUID
|
||||
refresh_token_jti: str = Field(..., max_length=255)
|
||||
ip_address: Optional[str] = Field(None, max_length=45)
|
||||
user_agent: Optional[str] = Field(None, max_length=500)
|
||||
last_used_at: datetime
|
||||
expires_at: datetime
|
||||
location_city: Optional[str] = Field(None, max_length=100)
|
||||
location_country: Optional[str] = Field(None, max_length=100)
|
||||
|
||||
|
||||
class SessionUpdate(BaseModel):
|
||||
"""Schema for updating a session (internal use)."""
|
||||
last_used_at: Optional[datetime] = None
|
||||
is_active: Optional[bool] = None
|
||||
refresh_token_jti: Optional[str] = None
|
||||
expires_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class SessionResponse(SessionBase):
|
||||
"""
|
||||
Schema for session responses to clients.
|
||||
|
||||
This is what users see when they list their active sessions.
|
||||
"""
|
||||
id: UUID
|
||||
ip_address: Optional[str] = None
|
||||
location_city: Optional[str] = None
|
||||
location_country: Optional[str] = None
|
||||
last_used_at: datetime
|
||||
created_at: datetime
|
||||
expires_at: datetime
|
||||
is_current: bool = Field(default=False, description="Whether this is the current session")
|
||||
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"device_name": "iPhone 14",
|
||||
"device_id": "device-abc-123",
|
||||
"ip_address": "192.168.1.100",
|
||||
"location_city": "San Francisco",
|
||||
"location_country": "United States",
|
||||
"last_used_at": "2025-10-31T12:00:00Z",
|
||||
"created_at": "2025-10-30T09:00:00Z",
|
||||
"expires_at": "2025-11-06T09:00:00Z",
|
||||
"is_current": True
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class SessionListResponse(BaseModel):
|
||||
"""Response containing list of sessions."""
|
||||
sessions: list[SessionResponse]
|
||||
total: int = Field(..., description="Total number of active sessions")
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"sessions": [
|
||||
{
|
||||
"id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"device_name": "iPhone 14",
|
||||
"ip_address": "192.168.1.100",
|
||||
"last_used_at": "2025-10-31T12:00:00Z",
|
||||
"created_at": "2025-10-30T09:00:00Z",
|
||||
"expires_at": "2025-11-06T09:00:00Z",
|
||||
"is_current": True
|
||||
}
|
||||
],
|
||||
"total": 1
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class LogoutRequest(BaseModel):
|
||||
"""Request schema for logout endpoint."""
|
||||
refresh_token: str = Field(
|
||||
...,
|
||||
description="Refresh token for the session to logout from",
|
||||
min_length=10
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class DeviceInfo(BaseModel):
|
||||
"""Device information extracted from request."""
|
||||
device_name: Optional[str] = None
|
||||
device_id: Optional[str] = None
|
||||
ip_address: Optional[str] = None
|
||||
user_agent: Optional[str] = None
|
||||
location_city: Optional[str] = None
|
||||
location_country: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"device_name": "Chrome on MacBook",
|
||||
"device_id": "device-xyz-789",
|
||||
"ip_address": "192.168.1.50",
|
||||
"user_agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7)...",
|
||||
"location_city": "San Francisco",
|
||||
"location_country": "United States"
|
||||
}
|
||||
}
|
||||
)
|
||||
80
backend/app/services/session_cleanup.py
Normal file
80
backend/app/services/session_cleanup.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""
|
||||
Background job for cleaning up expired sessions.
|
||||
|
||||
This service runs periodically to remove old session records from the database.
|
||||
"""
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from app.core.database import SessionLocal
|
||||
from app.crud.session import session as session_crud
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def cleanup_expired_sessions(keep_days: int = 30) -> int:
|
||||
"""
|
||||
Clean up expired and inactive sessions.
|
||||
|
||||
This removes sessions that are:
|
||||
- Inactive (is_active=False) AND
|
||||
- Expired (expires_at < now) AND
|
||||
- Older than keep_days
|
||||
|
||||
Args:
|
||||
keep_days: Keep inactive sessions for this many days for audit purposes
|
||||
|
||||
Returns:
|
||||
Number of sessions deleted
|
||||
"""
|
||||
logger.info("Starting session cleanup job...")
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Use CRUD method to cleanup
|
||||
count = session_crud.cleanup_expired(db, keep_days=keep_days)
|
||||
|
||||
logger.info(f"Session cleanup complete: {count} sessions deleted")
|
||||
|
||||
return count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during session cleanup: {str(e)}", exc_info=True)
|
||||
return 0
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def get_session_statistics() -> dict:
|
||||
"""
|
||||
Get statistics about current sessions.
|
||||
|
||||
Returns:
|
||||
Dictionary with session stats
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
from app.models.user_session import UserSession
|
||||
|
||||
total_sessions = db.query(UserSession).count()
|
||||
active_sessions = db.query(UserSession).filter(UserSession.is_active == True).count()
|
||||
expired_sessions = db.query(UserSession).filter(
|
||||
UserSession.expires_at < datetime.now(timezone.utc)
|
||||
).count()
|
||||
|
||||
stats = {
|
||||
"total": total_sessions,
|
||||
"active": active_sessions,
|
||||
"inactive": total_sessions - active_sessions,
|
||||
"expired": expired_sessions,
|
||||
}
|
||||
|
||||
logger.info(f"Session statistics: {stats}")
|
||||
|
||||
return stats
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting session statistics: {str(e)}", exc_info=True)
|
||||
return {}
|
||||
finally:
|
||||
db.close()
|
||||
233
backend/app/utils/device.py
Normal file
233
backend/app/utils/device.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""
|
||||
Utility functions for extracting and parsing device information from HTTP requests.
|
||||
"""
|
||||
import re
|
||||
from typing import Optional
|
||||
from fastapi import Request
|
||||
|
||||
from app.schemas.sessions import DeviceInfo
|
||||
|
||||
|
||||
def extract_device_info(request: Request) -> DeviceInfo:
|
||||
"""
|
||||
Extract device information from the HTTP request.
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object
|
||||
|
||||
Returns:
|
||||
DeviceInfo object with parsed device information
|
||||
"""
|
||||
user_agent = request.headers.get('user-agent', '')
|
||||
|
||||
device_info = DeviceInfo(
|
||||
device_name=parse_device_name(user_agent),
|
||||
device_id=request.headers.get('x-device-id'), # Client must send this header
|
||||
ip_address=get_client_ip(request),
|
||||
user_agent=user_agent[:500] if user_agent else None, # Truncate to max length
|
||||
location_city=None, # Can be populated via IP geolocation service
|
||||
location_country=None, # Can be populated via IP geolocation service
|
||||
)
|
||||
|
||||
return device_info
|
||||
|
||||
|
||||
def parse_device_name(user_agent: str) -> Optional[str]:
|
||||
"""
|
||||
Parse user agent string to extract a friendly device name.
|
||||
|
||||
Args:
|
||||
user_agent: User-Agent header string
|
||||
|
||||
Returns:
|
||||
Friendly device name string or None
|
||||
|
||||
Examples:
|
||||
"Mozilla/5.0 (iPhone; CPU iPhone OS 15_0 like Mac OS X)" -> "iPhone"
|
||||
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7)" -> "Mac"
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64)" -> "Windows PC"
|
||||
"""
|
||||
if not user_agent:
|
||||
return "Unknown device"
|
||||
|
||||
user_agent_lower = user_agent.lower()
|
||||
|
||||
# Mobile devices (check first, as they can contain desktop patterns too)
|
||||
if 'iphone' in user_agent_lower:
|
||||
return "iPhone"
|
||||
elif 'ipad' in user_agent_lower:
|
||||
return "iPad"
|
||||
elif 'android' in user_agent_lower:
|
||||
# Try to extract device model
|
||||
android_match = re.search(r'android.*;\s*([^)]+)\s*build', user_agent_lower)
|
||||
if android_match:
|
||||
device_model = android_match.group(1).strip()
|
||||
return f"Android ({device_model.title()})"
|
||||
return "Android device"
|
||||
elif 'windows phone' in user_agent_lower:
|
||||
return "Windows Phone"
|
||||
|
||||
# Desktop operating systems
|
||||
elif 'macintosh' in user_agent_lower or 'mac os x' in user_agent_lower:
|
||||
# Try to extract browser
|
||||
browser = extract_browser(user_agent)
|
||||
return f"{browser} on Mac" if browser else "Mac"
|
||||
elif 'windows' in user_agent_lower:
|
||||
browser = extract_browser(user_agent)
|
||||
return f"{browser} on Windows" if browser else "Windows PC"
|
||||
elif 'linux' in user_agent_lower and 'android' not in user_agent_lower:
|
||||
browser = extract_browser(user_agent)
|
||||
return f"{browser} on Linux" if browser else "Linux"
|
||||
elif 'cros' in user_agent_lower:
|
||||
return "Chromebook"
|
||||
|
||||
# Tablets (not already caught)
|
||||
elif 'tablet' in user_agent_lower:
|
||||
return "Tablet"
|
||||
|
||||
# Smart TVs
|
||||
elif any(tv in user_agent_lower for tv in ['smart-tv', 'smarttv', 'tv']):
|
||||
return "Smart TV"
|
||||
|
||||
# Game consoles
|
||||
elif 'playstation' in user_agent_lower:
|
||||
return "PlayStation"
|
||||
elif 'xbox' in user_agent_lower:
|
||||
return "Xbox"
|
||||
elif 'nintendo' in user_agent_lower:
|
||||
return "Nintendo"
|
||||
|
||||
# Fallback: just return browser name if detected
|
||||
browser = extract_browser(user_agent)
|
||||
if browser:
|
||||
return browser
|
||||
|
||||
return "Unknown device"
|
||||
|
||||
|
||||
def extract_browser(user_agent: str) -> Optional[str]:
|
||||
"""
|
||||
Extract browser name from user agent string.
|
||||
|
||||
Args:
|
||||
user_agent: User-Agent header string
|
||||
|
||||
Returns:
|
||||
Browser name or None
|
||||
|
||||
Examples:
|
||||
"Mozilla/5.0 ... Chrome/96.0" -> "Chrome"
|
||||
"Mozilla/5.0 ... Firefox/94.0" -> "Firefox"
|
||||
"""
|
||||
if not user_agent:
|
||||
return None
|
||||
|
||||
user_agent_lower = user_agent.lower()
|
||||
|
||||
# Check specific browsers (order matters - check Edge before Chrome!)
|
||||
if 'edg/' in user_agent_lower or 'edge/' in user_agent_lower:
|
||||
return "Edge"
|
||||
elif 'opr/' in user_agent_lower or 'opera' in user_agent_lower:
|
||||
return "Opera"
|
||||
elif 'chrome/' in user_agent_lower:
|
||||
return "Chrome"
|
||||
elif 'safari/' in user_agent_lower:
|
||||
# Make sure it's actually Safari, not Chrome (which also contains "Safari")
|
||||
if 'chrome' not in user_agent_lower:
|
||||
return "Safari"
|
||||
return None
|
||||
elif 'firefox/' in user_agent_lower:
|
||||
return "Firefox"
|
||||
elif 'msie' in user_agent_lower or 'trident/' in user_agent_lower:
|
||||
return "Internet Explorer"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_client_ip(request: Request) -> Optional[str]:
|
||||
"""
|
||||
Extract client IP address from request, considering proxy headers.
|
||||
|
||||
Checks X-Forwarded-For and X-Real-IP headers for proxy scenarios.
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object
|
||||
|
||||
Returns:
|
||||
Client IP address string or None
|
||||
|
||||
Notes:
|
||||
- In production behind a proxy/load balancer, X-Forwarded-For is often set
|
||||
- The first IP in X-Forwarded-For is typically the real client IP
|
||||
- request.client.host is fallback for direct connections
|
||||
"""
|
||||
# Check X-Forwarded-For (common in proxied environments)
|
||||
x_forwarded_for = request.headers.get('x-forwarded-for')
|
||||
if x_forwarded_for:
|
||||
# Get the first IP (original client)
|
||||
client_ip = x_forwarded_for.split(',')[0].strip()
|
||||
return client_ip
|
||||
|
||||
# Check X-Real-IP (used by some proxies like nginx)
|
||||
x_real_ip = request.headers.get('x-real-ip')
|
||||
if x_real_ip:
|
||||
return x_real_ip.strip()
|
||||
|
||||
# Fallback to direct connection IP
|
||||
if request.client and request.client.host:
|
||||
return request.client.host
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def is_mobile_device(user_agent: str) -> bool:
|
||||
"""
|
||||
Check if the device is a mobile device based on user agent.
|
||||
|
||||
Args:
|
||||
user_agent: User-Agent header string
|
||||
|
||||
Returns:
|
||||
True if mobile device, False otherwise
|
||||
"""
|
||||
if not user_agent:
|
||||
return False
|
||||
|
||||
mobile_patterns = [
|
||||
'mobile', 'android', 'iphone', 'ipad', 'ipod',
|
||||
'blackberry', 'windows phone', 'webos', 'opera mini',
|
||||
'iemobile', 'mobile safari'
|
||||
]
|
||||
|
||||
user_agent_lower = user_agent.lower()
|
||||
return any(pattern in user_agent_lower for pattern in mobile_patterns)
|
||||
|
||||
|
||||
def get_device_type(user_agent: str) -> str:
|
||||
"""
|
||||
Determine the general device type.
|
||||
|
||||
Args:
|
||||
user_agent: User-Agent header string
|
||||
|
||||
Returns:
|
||||
Device type: "mobile", "tablet", "desktop", or "other"
|
||||
"""
|
||||
if not user_agent:
|
||||
return "other"
|
||||
|
||||
user_agent_lower = user_agent.lower()
|
||||
|
||||
# Check for tablets first (they can contain "mobile" too)
|
||||
if 'ipad' in user_agent_lower or 'tablet' in user_agent_lower:
|
||||
return "tablet"
|
||||
|
||||
# Check for mobile
|
||||
if is_mobile_device(user_agent):
|
||||
return "mobile"
|
||||
|
||||
# Check for desktop OS patterns
|
||||
if any(os in user_agent_lower for os in ['windows', 'macintosh', 'linux', 'cros']):
|
||||
return "desktop"
|
||||
|
||||
return "other"
|
||||
@@ -1,6 +1,4 @@
|
||||
[pytest]
|
||||
env =
|
||||
IS_TEST=True
|
||||
testpaths = tests
|
||||
python_files = test_*.py
|
||||
addopts = --disable-warnings
|
||||
|
||||
@@ -207,33 +207,54 @@ class TestRefreshToken:
|
||||
|
||||
def test_refresh_token_success(self, client, db_session):
|
||||
"""Test successful token refresh."""
|
||||
# Mock refresh to return tokens
|
||||
mock_tokens = MagicMock(
|
||||
access_token="new_access_token",
|
||||
refresh_token="new_refresh_token",
|
||||
token_type="bearer"
|
||||
from app.models.user import User
|
||||
from app.core.auth import get_password_hash
|
||||
import uuid
|
||||
|
||||
# Create a test user
|
||||
test_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="refreshtest@example.com",
|
||||
password_hash=get_password_hash("TestPassword123"),
|
||||
first_name="Refresh",
|
||||
last_name="Test",
|
||||
is_active=True
|
||||
)
|
||||
db_session.add(test_user)
|
||||
db_session.commit()
|
||||
|
||||
# Login to get real tokens with a session
|
||||
login_response = client.post(
|
||||
"/auth/login",
|
||||
json={
|
||||
"email": "refreshtest@example.com",
|
||||
"password": "TestPassword123"
|
||||
}
|
||||
)
|
||||
assert login_response.status_code == 200
|
||||
tokens = login_response.json()
|
||||
|
||||
# Test refresh with real token
|
||||
response = client.post(
|
||||
"/auth/refresh",
|
||||
json={
|
||||
"refresh_token": tokens["refresh_token"]
|
||||
}
|
||||
)
|
||||
|
||||
with patch.object(AuthService, 'refresh_tokens', return_value=mock_tokens):
|
||||
# Test request
|
||||
response = client.post(
|
||||
"/auth/refresh",
|
||||
json={
|
||||
"refresh_token": "valid_refresh_token"
|
||||
}
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["access_token"] == "new_access_token"
|
||||
assert data["refresh_token"] == "new_refresh_token"
|
||||
assert data["token_type"] == "bearer"
|
||||
# Assertions
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "access_token" in data
|
||||
assert "refresh_token" in data
|
||||
assert data["token_type"] == "bearer"
|
||||
|
||||
def test_refresh_token_expired(self, client, db_session):
|
||||
"""Test refresh with expired token."""
|
||||
# Mock refresh to raise expired token error
|
||||
with patch.object(AuthService, 'refresh_tokens',
|
||||
from app.api.routes import auth as auth_routes
|
||||
|
||||
# Mock decode_token to raise expired token error
|
||||
with patch.object(auth_routes, 'decode_token',
|
||||
side_effect=TokenExpiredError("Token expired")):
|
||||
# Test request
|
||||
response = client.post(
|
||||
@@ -245,7 +266,13 @@ class TestRefreshToken:
|
||||
|
||||
# Assertions
|
||||
assert response.status_code == 401
|
||||
assert "expired" in response.json()["detail"]
|
||||
# Check if it's in the new error format or old detail format
|
||||
response_data = response.json()
|
||||
if "errors" in response_data:
|
||||
assert "expired" in response_data["errors"][0]["message"].lower()
|
||||
else:
|
||||
assert "detail" in response_data
|
||||
assert "expired" in response_data["detail"].lower()
|
||||
|
||||
def test_refresh_token_invalid(self, client, db_session):
|
||||
"""Test refresh with invalid token."""
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# tests/api/routes/test_rate_limiting.py
|
||||
import os
|
||||
import pytest
|
||||
from fastapi import FastAPI, status
|
||||
from fastapi.testclient import TestClient
|
||||
@@ -8,6 +9,12 @@ from app.api.routes.auth import router as auth_router, limiter
|
||||
from app.api.routes.users import router as users_router
|
||||
from app.core.database import get_db
|
||||
|
||||
# Skip all rate limiting tests when IS_TEST=True (rate limits are disabled in test mode)
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.getenv("IS_TEST", "False") == "True",
|
||||
reason="Rate limits are disabled in test mode (RATE_MULTIPLIER=100)"
|
||||
)
|
||||
|
||||
|
||||
# Mock the get_db dependency
|
||||
@pytest.fixture
|
||||
|
||||
421
backend/tests/api/test_session_management.py
Normal file
421
backend/tests/api/test_session_management.py
Normal file
@@ -0,0 +1,421 @@
|
||||
"""
|
||||
Integration tests for session management.
|
||||
|
||||
Tests the critical per-device logout functionality.
|
||||
"""
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.main import app
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
from app.core.auth import get_password_hash
|
||||
from app.utils.test_utils import setup_test_db, teardown_test_db
|
||||
import uuid
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def test_db_session():
|
||||
"""Create test database session."""
|
||||
test_engine, TestingSessionLocal = setup_test_db()
|
||||
with TestingSessionLocal() as session:
|
||||
yield session
|
||||
teardown_test_db(test_engine)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def client(test_db_session):
|
||||
"""Create test client with test database."""
|
||||
def override_get_db():
|
||||
try:
|
||||
yield test_db_session
|
||||
finally:
|
||||
pass
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
with TestClient(app) as test_client:
|
||||
yield test_client
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user(test_db_session):
|
||||
"""Create a test user."""
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="sessiontest@example.com",
|
||||
password_hash=get_password_hash("TestPassword123"),
|
||||
first_name="Session",
|
||||
last_name="Test",
|
||||
phone_number="+1234567890",
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
preferences=None,
|
||||
)
|
||||
test_db_session.add(user)
|
||||
test_db_session.commit()
|
||||
test_db_session.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
class TestMultiDeviceLogin:
|
||||
"""Test multi-device login scenarios."""
|
||||
|
||||
def test_login_from_multiple_devices(self, client, test_user):
|
||||
"""Test that user can login from multiple devices simultaneously."""
|
||||
# Login from PC
|
||||
pc_response = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "sessiontest@example.com",
|
||||
"password": "TestPassword123"
|
||||
},
|
||||
headers={"X-Device-Id": "pc-device-001"}
|
||||
)
|
||||
assert pc_response.status_code == 200
|
||||
pc_tokens = pc_response.json()
|
||||
assert "access_token" in pc_tokens
|
||||
assert "refresh_token" in pc_tokens
|
||||
pc_refresh = pc_tokens["refresh_token"]
|
||||
|
||||
# Login from Phone
|
||||
phone_response = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "sessiontest@example.com",
|
||||
"password": "TestPassword123"
|
||||
},
|
||||
headers={"X-Device-Id": "phone-device-001"}
|
||||
)
|
||||
assert phone_response.status_code == 200
|
||||
phone_tokens = phone_response.json()
|
||||
assert "access_token" in phone_tokens
|
||||
assert "refresh_token" in phone_tokens
|
||||
phone_refresh = phone_tokens["refresh_token"]
|
||||
|
||||
# Verify both tokens are different
|
||||
assert pc_refresh != phone_refresh
|
||||
|
||||
# Both should be able to access protected endpoints
|
||||
pc_me = client.get(
|
||||
"/api/v1/auth/me",
|
||||
headers={"Authorization": f"Bearer {pc_tokens['access_token']}"}
|
||||
)
|
||||
assert pc_me.status_code == 200
|
||||
|
||||
phone_me = client.get(
|
||||
"/api/v1/auth/me",
|
||||
headers={"Authorization": f"Bearer {phone_tokens['access_token']}"}
|
||||
)
|
||||
assert phone_me.status_code == 200
|
||||
|
||||
def test_logout_from_one_device_does_not_affect_other(self, client, test_user):
|
||||
"""
|
||||
CRITICAL TEST: Logout from PC should NOT logout from Phone.
|
||||
|
||||
This is the main requirement for session management.
|
||||
"""
|
||||
# Login from PC
|
||||
pc_response = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "sessiontest@example.com",
|
||||
"password": "TestPassword123"
|
||||
},
|
||||
headers={"X-Device-Id": "pc-device-001"}
|
||||
)
|
||||
assert pc_response.status_code == 200
|
||||
pc_tokens = pc_response.json()
|
||||
pc_access = pc_tokens["access_token"]
|
||||
pc_refresh = pc_tokens["refresh_token"]
|
||||
|
||||
# Login from Phone
|
||||
phone_response = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "sessiontest@example.com",
|
||||
"password": "TestPassword123"
|
||||
},
|
||||
headers={"X-Device-Id": "phone-device-001"}
|
||||
)
|
||||
assert phone_response.status_code == 200
|
||||
phone_tokens = phone_response.json()
|
||||
phone_access = phone_tokens["access_token"]
|
||||
phone_refresh = phone_tokens["refresh_token"]
|
||||
|
||||
# Logout from PC
|
||||
logout_response = client.post(
|
||||
"/api/v1/auth/logout",
|
||||
json={"refresh_token": pc_refresh},
|
||||
headers={"Authorization": f"Bearer {pc_access}"}
|
||||
)
|
||||
assert logout_response.status_code == 200
|
||||
assert logout_response.json()["success"] == True
|
||||
|
||||
# PC refresh should fail (logged out)
|
||||
pc_refresh_response = client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": pc_refresh}
|
||||
)
|
||||
assert pc_refresh_response.status_code == 401
|
||||
response_data = pc_refresh_response.json()
|
||||
assert "revoked" in response_data["errors"][0]["message"].lower()
|
||||
|
||||
# Phone refresh should still work ✅ THIS IS THE CRITICAL ASSERTION
|
||||
phone_refresh_response = client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": phone_refresh}
|
||||
)
|
||||
assert phone_refresh_response.status_code == 200
|
||||
new_phone_tokens = phone_refresh_response.json()
|
||||
assert "access_token" in new_phone_tokens
|
||||
|
||||
# Phone can still access protected endpoints
|
||||
phone_me = client.get(
|
||||
"/api/v1/auth/me",
|
||||
headers={"Authorization": f"Bearer {new_phone_tokens['access_token']}"}
|
||||
)
|
||||
assert phone_me.status_code == 200
|
||||
assert phone_me.json()["email"] == "sessiontest@example.com"
|
||||
|
||||
def test_logout_all_devices(self, client, test_user):
|
||||
"""Test logging out from all devices simultaneously."""
|
||||
# Login from 3 devices
|
||||
devices = []
|
||||
for i, device_name in enumerate(["pc", "phone", "tablet"]):
|
||||
response = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "sessiontest@example.com",
|
||||
"password": "TestPassword123"
|
||||
},
|
||||
headers={"X-Device-Id": f"{device_name}-device-00{i}"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
tokens = response.json()
|
||||
devices.append({
|
||||
"name": device_name,
|
||||
"access": tokens["access_token"],
|
||||
"refresh": tokens["refresh_token"]
|
||||
})
|
||||
|
||||
# Logout from all devices using first device's access token
|
||||
logout_all_response = client.post(
|
||||
"/api/v1/auth/logout-all",
|
||||
headers={"Authorization": f"Bearer {devices[0]['access']}"}
|
||||
)
|
||||
assert logout_all_response.status_code == 200
|
||||
assert "3" in logout_all_response.json()["message"] # 3 sessions terminated
|
||||
|
||||
# All refresh tokens should now fail
|
||||
for device in devices:
|
||||
refresh_response = client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": device["refresh"]}
|
||||
)
|
||||
assert refresh_response.status_code == 401
|
||||
|
||||
def test_list_active_sessions(self, client, test_user):
|
||||
"""Test listing active sessions."""
|
||||
# Login from 2 devices
|
||||
pc_response = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "sessiontest@example.com",
|
||||
"password": "TestPassword123"
|
||||
},
|
||||
headers={"X-Device-Id": "pc-device-001"}
|
||||
)
|
||||
pc_tokens = pc_response.json()
|
||||
|
||||
phone_response = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "sessiontest@example.com",
|
||||
"password": "TestPassword123"
|
||||
},
|
||||
headers={"X-Device-Id": "phone-device-001"}
|
||||
)
|
||||
|
||||
# List sessions
|
||||
sessions_response = client.get(
|
||||
"/api/v1/sessions/me",
|
||||
headers={"Authorization": f"Bearer {pc_tokens['access_token']}"}
|
||||
)
|
||||
assert sessions_response.status_code == 200
|
||||
sessions_data = sessions_response.json()
|
||||
assert sessions_data["total"] == 2
|
||||
assert len(sessions_data["sessions"]) == 2
|
||||
|
||||
# Check session details
|
||||
session = sessions_data["sessions"][0]
|
||||
assert "device_name" in session
|
||||
assert "ip_address" in session
|
||||
assert "last_used_at" in session
|
||||
assert "created_at" in session
|
||||
|
||||
def test_revoke_specific_session(self, client, test_user):
|
||||
"""Test revoking a specific session by ID."""
|
||||
# Login from 2 devices
|
||||
pc_response = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "sessiontest@example.com",
|
||||
"password": "TestPassword123"
|
||||
},
|
||||
headers={"X-Device-Id": "pc-device-001"}
|
||||
)
|
||||
pc_tokens = pc_response.json()
|
||||
|
||||
phone_response = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "sessiontest@example.com",
|
||||
"password": "TestPassword123"
|
||||
},
|
||||
headers={"X-Device-Id": "phone-device-001"}
|
||||
)
|
||||
phone_tokens = phone_response.json()
|
||||
|
||||
# List sessions to get IDs
|
||||
sessions_response = client.get(
|
||||
"/api/v1/sessions/me",
|
||||
headers={"Authorization": f"Bearer {pc_tokens['access_token']}"}
|
||||
)
|
||||
sessions = sessions_response.json()["sessions"]
|
||||
|
||||
# Find the phone session by device_id
|
||||
phone_session = next((s for s in sessions if s["device_id"] == "phone-device-001"), None)
|
||||
assert phone_session is not None, "Phone session not found in session list"
|
||||
session_id_to_revoke = phone_session["id"]
|
||||
revoke_response = client.delete(
|
||||
f"/api/v1/sessions/{session_id_to_revoke}",
|
||||
headers={"Authorization": f"Bearer {pc_tokens['access_token']}"}
|
||||
)
|
||||
assert revoke_response.status_code == 200
|
||||
|
||||
# Phone refresh should fail
|
||||
phone_refresh_response = client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": phone_tokens["refresh_token"]}
|
||||
)
|
||||
assert phone_refresh_response.status_code == 401
|
||||
|
||||
# PC refresh should still work
|
||||
pc_refresh_response = client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": pc_tokens["refresh_token"]}
|
||||
)
|
||||
assert pc_refresh_response.status_code == 200
|
||||
|
||||
|
||||
class TestSessionEdgeCases:
|
||||
"""Test edge cases and error scenarios."""
|
||||
|
||||
def test_logout_with_invalid_refresh_token(self, client, test_user):
|
||||
"""Test logout with invalid refresh token."""
|
||||
# Login first
|
||||
login_response = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "sessiontest@example.com",
|
||||
"password": "TestPassword123"
|
||||
}
|
||||
)
|
||||
tokens = login_response.json()
|
||||
|
||||
# Try to logout with invalid refresh token
|
||||
logout_response = client.post(
|
||||
"/api/v1/auth/logout",
|
||||
json={"refresh_token": "invalid_token"},
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"}
|
||||
)
|
||||
# Should still return success (idempotent)
|
||||
assert logout_response.status_code == 200
|
||||
|
||||
def test_refresh_with_deactivated_session(self, client, test_user):
|
||||
"""Test refresh after session has been deactivated."""
|
||||
# Login
|
||||
login_response = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "sessiontest@example.com",
|
||||
"password": "TestPassword123"
|
||||
}
|
||||
)
|
||||
tokens = login_response.json()
|
||||
|
||||
# Logout
|
||||
client.post(
|
||||
"/api/v1/auth/logout",
|
||||
json={"refresh_token": tokens["refresh_token"]},
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"}
|
||||
)
|
||||
|
||||
# Try to refresh with deactivated session
|
||||
refresh_response = client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": tokens["refresh_token"]}
|
||||
)
|
||||
assert refresh_response.status_code == 401
|
||||
response_data = refresh_response.json()
|
||||
assert "revoked" in response_data["errors"][0]["message"].lower()
|
||||
|
||||
def test_cannot_revoke_other_users_session(self, client, test_db_session):
|
||||
"""Test that users cannot revoke other users' sessions."""
|
||||
# Create two users
|
||||
user1 = User(
|
||||
id=uuid.uuid4(),
|
||||
email="user1@example.com",
|
||||
password_hash=get_password_hash("TestPassword123"),
|
||||
first_name="User",
|
||||
last_name="One",
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
user2 = User(
|
||||
id=uuid.uuid4(),
|
||||
email="user2@example.com",
|
||||
password_hash=get_password_hash("TestPassword123"),
|
||||
first_name="User",
|
||||
last_name="Two",
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
test_db_session.add_all([user1, user2])
|
||||
test_db_session.commit()
|
||||
|
||||
# User1 login
|
||||
user1_login = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "user1@example.com", "password": "TestPassword123"}
|
||||
)
|
||||
user1_tokens = user1_login.json()
|
||||
|
||||
# User2 login
|
||||
user2_login = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "user2@example.com", "password": "TestPassword123"}
|
||||
)
|
||||
|
||||
# User1 gets their sessions
|
||||
user1_sessions = client.get(
|
||||
"/api/v1/sessions/me",
|
||||
headers={"Authorization": f"Bearer {user1_tokens['access_token']}"}
|
||||
)
|
||||
user1_session_id = user1_sessions.json()["sessions"][0]["id"]
|
||||
|
||||
# User2 lists their sessions
|
||||
user2_sessions = client.get(
|
||||
"/api/v1/sessions/me",
|
||||
headers={"Authorization": f"Bearer {user2_login.json()['access_token']}"}
|
||||
)
|
||||
user2_session_id = user2_sessions.json()["sessions"][0]["id"]
|
||||
|
||||
# User1 tries to revoke User2's session (should fail)
|
||||
revoke_response = client.delete(
|
||||
f"/api/v1/sessions/{user2_session_id}",
|
||||
headers={"Authorization": f"Bearer {user1_tokens['access_token']}"}
|
||||
)
|
||||
assert revoke_response.status_code == 403
|
||||
@@ -1,10 +1,15 @@
|
||||
# tests/conftest.py
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
# Set IS_TEST environment variable BEFORE importing app
|
||||
# This prevents the scheduler from starting during tests
|
||||
os.environ["IS_TEST"] = "True"
|
||||
|
||||
from app.main import app
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
|
||||
Reference in New Issue
Block a user