Add init_db script for async database initialization and extensive tests for session management
- Added `init_db.py` to handle async database initialization with the creation of the first superuser if configured. - Introduced comprehensive tests for session management APIs, including session listing, revocation, and cleanup. - Enhanced CRUD session logic with UUID utilities and improved error handling.
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns.
|
||||
"""
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
92
backend/app/init_db.py
Normal file
92
backend/app/init_db.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# app/init_db.py
|
||||
"""
|
||||
Async database initialization script.
|
||||
|
||||
Creates the first superuser if configured and doesn't already exist.
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.database import SessionLocal, engine
|
||||
from app.crud.user import user as user_crud
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def init_db() -> Optional[User]:
|
||||
"""
|
||||
Initialize database with first superuser if settings are configured and user doesn't exist.
|
||||
|
||||
Returns:
|
||||
The created or existing superuser, or None if creation fails
|
||||
"""
|
||||
# Use default values if not set in environment variables
|
||||
superuser_email = settings.FIRST_SUPERUSER_EMAIL or "admin@example.com"
|
||||
superuser_password = settings.FIRST_SUPERUSER_PASSWORD or "admin123"
|
||||
|
||||
if not settings.FIRST_SUPERUSER_EMAIL or not settings.FIRST_SUPERUSER_PASSWORD:
|
||||
logger.warning(
|
||||
"First superuser credentials not configured in settings. "
|
||||
f"Using defaults: {superuser_email}"
|
||||
)
|
||||
|
||||
async with SessionLocal() as session:
|
||||
try:
|
||||
# Check if superuser already exists
|
||||
existing_user = await user_crud.get_by_email(session, email=superuser_email)
|
||||
|
||||
if existing_user:
|
||||
logger.info(f"Superuser already exists: {existing_user.email}")
|
||||
return existing_user
|
||||
|
||||
# Create superuser if doesn't exist
|
||||
user_in = UserCreate(
|
||||
email=superuser_email,
|
||||
password=superuser_password,
|
||||
first_name="Admin",
|
||||
last_name="User",
|
||||
is_superuser=True
|
||||
)
|
||||
|
||||
user = await user_crud.create(session, obj_in=user_in)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
logger.info(f"Created first superuser: {user.email}")
|
||||
return user
|
||||
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error(f"Error initializing database: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main entry point for database initialization."""
|
||||
# Configure logging to show info logs
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
|
||||
try:
|
||||
user = await init_db()
|
||||
if user:
|
||||
print(f"✓ Database initialized successfully")
|
||||
print(f"✓ Superuser: {user.email}")
|
||||
else:
|
||||
print("✗ Failed to initialize database")
|
||||
except Exception as e:
|
||||
print(f"✗ Error initializing database: {e}")
|
||||
raise
|
||||
finally:
|
||||
# Close the engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
367
backend/tests/api/test_sessions.py
Normal file
367
backend/tests/api/test_sessions.py
Normal file
@@ -0,0 +1,367 @@
|
||||
# tests/api/test_sessions.py
|
||||
"""
|
||||
Comprehensive tests for session management API endpoints.
|
||||
"""
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import uuid4
|
||||
from unittest.mock import patch
|
||||
|
||||
from fastapi import status
|
||||
|
||||
from app.models.user_session import UserSession
|
||||
from app.schemas.users import UserCreate
|
||||
|
||||
|
||||
# Disable rate limiting for tests
|
||||
@pytest.fixture(autouse=True)
|
||||
def disable_rate_limit():
|
||||
"""Disable rate limiting for all tests in this module."""
|
||||
with patch('app.api.routes.sessions.limiter.enabled', False):
|
||||
yield
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def user_token(client, async_test_user):
|
||||
"""Create and return an access token for async_test_user."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
return response.json()["access_token"]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def async_test_user2(async_test_db):
|
||||
"""Create a second test user."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
from app.crud.user import user as user_crud
|
||||
from app.schemas.users import UserCreate
|
||||
|
||||
user_data = UserCreate(
|
||||
email="testuser2@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Test",
|
||||
last_name="User2"
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
class TestListMySessions:
|
||||
"""Tests for GET /api/v1/sessions/me endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_my_sessions_success(self, client, async_test_user, async_test_db, user_token):
|
||||
"""Test successfully listing user's active sessions."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create some sessions for the user
|
||||
async with SessionLocal() as session:
|
||||
# Active session 1
|
||||
s1 = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=str(uuid4()),
|
||||
device_name="iPhone 13",
|
||||
ip_address="192.168.1.100",
|
||||
user_agent="Mozilla/5.0 (iPhone)",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
)
|
||||
# Active session 2
|
||||
s2 = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=str(uuid4()),
|
||||
device_name="MacBook Pro",
|
||||
ip_address="192.168.1.101",
|
||||
user_agent="Mozilla/5.0 (Macintosh)",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
)
|
||||
# Inactive session (should not appear)
|
||||
s3 = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=str(uuid4()),
|
||||
device_name="Old Device",
|
||||
ip_address="192.168.1.102",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(days=1)
|
||||
)
|
||||
session.add_all([s1, s2, s3])
|
||||
await session.commit()
|
||||
|
||||
# Make request
|
||||
response = await client.get(
|
||||
"/api/v1/sessions/me",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert "sessions" in data
|
||||
assert "total" in data
|
||||
# Note: Login creates a session, so we have 3 total (login + 2 created)
|
||||
assert data["total"] == 3
|
||||
assert len(data["sessions"]) == 3
|
||||
|
||||
# Check session data
|
||||
device_names = {s["device_name"] for s in data["sessions"]}
|
||||
assert "iPhone 13" in device_names
|
||||
assert "MacBook Pro" in device_names
|
||||
assert "Old Device" not in device_names
|
||||
|
||||
# First session should be marked as current
|
||||
assert data["sessions"][0]["is_current"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_my_sessions_with_login_session(self, client, async_test_user, user_token):
|
||||
"""Test listing sessions shows the login session."""
|
||||
response = await client.get(
|
||||
"/api/v1/sessions/me",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
# Login creates a session, so we should have at least 1
|
||||
assert data["total"] >= 1
|
||||
assert len(data["sessions"]) >= 1
|
||||
assert data["sessions"][0]["is_current"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_my_sessions_unauthorized(self, client):
|
||||
"""Test listing sessions without authentication."""
|
||||
response = await client.get("/api/v1/sessions/me")
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
class TestRevokeSession:
|
||||
"""Tests for DELETE /api/v1/sessions/{session_id} endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_revoke_session_success(self, client, async_test_user, async_test_db, user_token):
|
||||
"""Test successfully revoking a session."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create a session to revoke
|
||||
async with SessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=str(uuid4()),
|
||||
device_name="iPad",
|
||||
ip_address="192.168.1.103",
|
||||
user_agent="Mozilla/5.0 (iPad)",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
await session.refresh(user_session)
|
||||
session_id = user_session.id
|
||||
|
||||
# Revoke the session
|
||||
response = await client.delete(
|
||||
f"/api/v1/sessions/{session_id}",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert data["success"] is True
|
||||
assert "iPad" in data["message"]
|
||||
|
||||
# Verify session is deactivated
|
||||
async with SessionLocal() as session:
|
||||
from app.crud.session import session as session_crud
|
||||
revoked_session = await session_crud.get(session, id=str(session_id))
|
||||
assert revoked_session.is_active is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_revoke_session_not_found(self, client, user_token):
|
||||
"""Test revoking a non-existent session."""
|
||||
fake_id = uuid4()
|
||||
response = await client.delete(
|
||||
f"/api/v1/sessions/{fake_id}",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
data = response.json()
|
||||
assert data["success"] is False
|
||||
assert "errors" in data
|
||||
assert data["errors"][0]["code"] == "SYS_002" # NOT_FOUND error code
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_revoke_session_unauthorized(self, client, async_test_db):
|
||||
"""Test revoking a session without authentication."""
|
||||
session_id = uuid4()
|
||||
response = await client.delete(f"/api/v1/sessions/{session_id}")
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_revoke_session_belonging_to_other_user(
|
||||
self, client, async_test_user, async_test_user2, async_test_db, user_token
|
||||
):
|
||||
"""Test that users cannot revoke other users' sessions."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create a session for user2
|
||||
async with SessionLocal() as session:
|
||||
other_user_session = UserSession(
|
||||
user_id=async_test_user2.id, # Different user
|
||||
refresh_token_jti=str(uuid4()),
|
||||
device_name="Other User Device",
|
||||
ip_address="192.168.1.200",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
)
|
||||
session.add(other_user_session)
|
||||
await session.commit()
|
||||
await session.refresh(other_user_session)
|
||||
session_id = other_user_session.id
|
||||
|
||||
# Try to revoke it as user1
|
||||
response = await client.delete(
|
||||
f"/api/v1/sessions/{session_id}",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
data = response.json()
|
||||
assert data["success"] is False
|
||||
assert "errors" in data
|
||||
assert data["errors"][0]["code"] == "AUTH_004" # INSUFFICIENT_PERMISSIONS
|
||||
assert "your own sessions" in data["errors"][0]["message"].lower()
|
||||
|
||||
|
||||
class TestCleanupExpiredSessions:
|
||||
"""Tests for DELETE /api/v1/sessions/me/expired endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_sessions_success(
|
||||
self, client, async_test_user, async_test_db, user_token
|
||||
):
|
||||
"""Test successfully cleaning up expired sessions."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create expired and active sessions using CRUD to avoid greenlet issues
|
||||
from app.crud.session import session as session_crud
|
||||
from app.schemas.sessions import SessionCreate
|
||||
|
||||
async with SessionLocal() as db:
|
||||
# Expired session 1 (inactive and expired)
|
||||
e1_data = SessionCreate(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=str(uuid4()),
|
||||
device_name="Expired 1",
|
||||
ip_address="192.168.1.201",
|
||||
user_agent="Mozilla/5.0",
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(days=2)
|
||||
)
|
||||
e1 = await session_crud.create_session(db, obj_in=e1_data)
|
||||
e1.is_active = False
|
||||
db.add(e1)
|
||||
|
||||
# Expired session 2 (inactive and expired)
|
||||
e2_data = SessionCreate(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=str(uuid4()),
|
||||
device_name="Expired 2",
|
||||
ip_address="192.168.1.202",
|
||||
user_agent="Mozilla/5.0",
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(hours=2)
|
||||
)
|
||||
e2 = await session_crud.create_session(db, obj_in=e2_data)
|
||||
e2.is_active = False
|
||||
db.add(e2)
|
||||
|
||||
# Active session (should not be deleted)
|
||||
a1_data = SessionCreate(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=str(uuid4()),
|
||||
device_name="Active",
|
||||
ip_address="192.168.1.203",
|
||||
user_agent="Mozilla/5.0",
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
)
|
||||
await session_crud.create_session(db, obj_in=a1_data)
|
||||
await db.commit()
|
||||
|
||||
# Cleanup expired sessions
|
||||
response = await client.delete(
|
||||
"/api/v1/sessions/me/expired",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert data["success"] is True
|
||||
# Should have cleaned up 2 expired sessions
|
||||
assert "2" in data["message"] or data["message"].startswith("Cleaned up 2")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_sessions_none_expired(
|
||||
self, client, async_test_user, async_test_db, user_token
|
||||
):
|
||||
"""Test cleanup when no sessions are expired."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create only active sessions using CRUD
|
||||
from app.crud.session import session as session_crud
|
||||
from app.schemas.sessions import SessionCreate
|
||||
|
||||
async with SessionLocal() as db:
|
||||
a1_data = SessionCreate(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=str(uuid4()),
|
||||
device_name="Active Device",
|
||||
ip_address="192.168.1.210",
|
||||
user_agent="Mozilla/5.0",
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
)
|
||||
await session_crud.create_session(db, obj_in=a1_data)
|
||||
await db.commit()
|
||||
|
||||
response = await client.delete(
|
||||
"/api/v1/sessions/me/expired",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert data["success"] is True
|
||||
assert "0" in data["message"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_sessions_unauthorized(self, client):
|
||||
"""Test cleanup without authentication."""
|
||||
response = await client.delete("/api/v1/sessions/me/expired")
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
Reference in New Issue
Block a user