refactor(backend): enforce route→service→repo layered architecture
- introduce custom repository exception hierarchy (DuplicateEntryError, IntegrityConstraintError, InvalidInputError) replacing raw ValueError - eliminate all direct repository imports and raw SQL from route layer - add UserService, SessionService, OrganizationService to service layer - add get_stats/get_org_distribution service methods replacing admin inline SQL - fix timing side-channel in authenticate_user via dummy bcrypt check - replace SHA-256 client secret fallback with explicit InvalidClientError - replace assert with InvalidGrantError in authorization code exchange - replace N+1 token revocation loops with bulk UPDATE statements - rename oauth account token fields (drop misleading 'encrypted' suffix) - add Alembic migration 0003 for token field column rename - add 45 new service/repository tests; 975 passing, 94% coverage
This commit is contained in:
292
backend/tests/services/test_session_service.py
Normal file
292
backend/tests/services/test_session_service.py
Normal file
@@ -0,0 +1,292 @@
|
||||
# tests/services/test_session_service.py
|
||||
"""Tests for the SessionService class."""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from app.schemas.sessions import SessionCreate
|
||||
from app.services.session_service import SessionService, session_service
|
||||
|
||||
|
||||
def _make_session_create(user_id, jti=None) -> SessionCreate:
|
||||
"""Helper to build a SessionCreate with sensible defaults."""
|
||||
now = datetime.now(UTC)
|
||||
return SessionCreate(
|
||||
user_id=user_id,
|
||||
refresh_token_jti=jti or str(uuid.uuid4()),
|
||||
ip_address="127.0.0.1",
|
||||
user_agent="pytest/test",
|
||||
device_name="Test Device",
|
||||
device_id="test-device-id",
|
||||
last_used_at=now,
|
||||
expires_at=now + timedelta(days=7),
|
||||
location_city="TestCity",
|
||||
location_country="TestCountry",
|
||||
)
|
||||
|
||||
|
||||
class TestCreateSession:
|
||||
"""Tests for SessionService.create_session method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session(self, async_test_db, async_test_user):
|
||||
"""Test creating a session returns a UserSession with correct fields."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
obj_in = _make_session_create(async_test_user.id)
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_service.create_session(session, obj_in=obj_in)
|
||||
assert result is not None
|
||||
assert result.user_id == async_test_user.id
|
||||
assert result.refresh_token_jti == obj_in.refresh_token_jti
|
||||
assert result.is_active is True
|
||||
assert result.ip_address == "127.0.0.1"
|
||||
|
||||
|
||||
class TestGetSession:
|
||||
"""Tests for SessionService.get_session method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_session_found(self, async_test_db, async_test_user):
|
||||
"""Test getting a session by ID returns the session."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
obj_in = _make_session_create(async_test_user.id)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
created = await session_service.create_session(session, obj_in=obj_in)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_service.get_session(session, str(created.id))
|
||||
assert result is not None
|
||||
assert result.id == created.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_session_not_found(self, async_test_db):
|
||||
"""Test getting a non-existent session returns None."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_service.get_session(session, str(uuid.uuid4()))
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGetUserSessions:
|
||||
"""Tests for SessionService.get_user_sessions method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_sessions_active_only(self, async_test_db, async_test_user):
|
||||
"""Test getting active sessions for a user returns only active sessions."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
obj_in = _make_session_create(async_test_user.id)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
await session_service.create_session(session, obj_in=obj_in)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
sessions = await session_service.get_user_sessions(
|
||||
session, user_id=str(async_test_user.id), active_only=True
|
||||
)
|
||||
assert isinstance(sessions, list)
|
||||
assert len(sessions) >= 1
|
||||
for s in sessions:
|
||||
assert s.is_active is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_sessions_all(self, async_test_db, async_test_user):
|
||||
"""Test getting all sessions (active and inactive) for a user."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
obj_in = _make_session_create(async_test_user.id)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
created = await session_service.create_session(session, obj_in=obj_in)
|
||||
await session_service.deactivate(session, session_id=str(created.id))
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
sessions = await session_service.get_user_sessions(
|
||||
session, user_id=str(async_test_user.id), active_only=False
|
||||
)
|
||||
assert isinstance(sessions, list)
|
||||
assert len(sessions) >= 1
|
||||
|
||||
|
||||
class TestGetActiveByJti:
|
||||
"""Tests for SessionService.get_active_by_jti method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_by_jti_found(self, async_test_db, async_test_user):
|
||||
"""Test getting an active session by JTI returns the session."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
jti = str(uuid.uuid4())
|
||||
obj_in = _make_session_create(async_test_user.id, jti=jti)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
await session_service.create_session(session, obj_in=obj_in)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_service.get_active_by_jti(session, jti=jti)
|
||||
assert result is not None
|
||||
assert result.refresh_token_jti == jti
|
||||
assert result.is_active is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_by_jti_not_found(self, async_test_db):
|
||||
"""Test getting an active session by non-existent JTI returns None."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_service.get_active_by_jti(
|
||||
session, jti=str(uuid.uuid4())
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGetByJti:
|
||||
"""Tests for SessionService.get_by_jti method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_jti_active(self, async_test_db, async_test_user):
|
||||
"""Test getting a session (active or inactive) by JTI."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
jti = str(uuid.uuid4())
|
||||
obj_in = _make_session_create(async_test_user.id, jti=jti)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
await session_service.create_session(session, obj_in=obj_in)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_service.get_by_jti(session, jti=jti)
|
||||
assert result is not None
|
||||
assert result.refresh_token_jti == jti
|
||||
|
||||
|
||||
class TestDeactivate:
|
||||
"""Tests for SessionService.deactivate method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deactivate_session(self, async_test_db, async_test_user):
|
||||
"""Test deactivating a session sets is_active to False."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
obj_in = _make_session_create(async_test_user.id)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
created = await session_service.create_session(session, obj_in=obj_in)
|
||||
session_id = str(created.id)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
deactivated = await session_service.deactivate(
|
||||
session, session_id=session_id
|
||||
)
|
||||
assert deactivated is not None
|
||||
assert deactivated.is_active is False
|
||||
|
||||
|
||||
class TestDeactivateAllUserSessions:
|
||||
"""Tests for SessionService.deactivate_all_user_sessions method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deactivate_all_user_sessions(self, async_test_db, async_test_user):
|
||||
"""Test deactivating all sessions for a user returns count deactivated."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
await session_service.create_session(
|
||||
session, obj_in=_make_session_create(async_test_user.id)
|
||||
)
|
||||
await session_service.create_session(
|
||||
session, obj_in=_make_session_create(async_test_user.id)
|
||||
)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await session_service.deactivate_all_user_sessions(
|
||||
session, user_id=str(async_test_user.id)
|
||||
)
|
||||
assert count >= 2
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
active_sessions = await session_service.get_user_sessions(
|
||||
session, user_id=str(async_test_user.id), active_only=True
|
||||
)
|
||||
assert len(active_sessions) == 0
|
||||
|
||||
|
||||
class TestUpdateRefreshToken:
|
||||
"""Tests for SessionService.update_refresh_token method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_refresh_token(self, async_test_db, async_test_user):
|
||||
"""Test rotating a session's refresh token updates JTI and expiry."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
obj_in = _make_session_create(async_test_user.id)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
created = await session_service.create_session(session, obj_in=obj_in)
|
||||
session_id = str(created.id)
|
||||
|
||||
new_jti = str(uuid.uuid4())
|
||||
new_expires_at = datetime.now(UTC) + timedelta(days=14)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_service.get_session(session, session_id)
|
||||
updated = await session_service.update_refresh_token(
|
||||
session,
|
||||
session=result,
|
||||
new_jti=new_jti,
|
||||
new_expires_at=new_expires_at,
|
||||
)
|
||||
assert updated.refresh_token_jti == new_jti
|
||||
|
||||
|
||||
class TestCleanupExpiredForUser:
|
||||
"""Tests for SessionService.cleanup_expired_for_user method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_for_user(self, async_test_db, async_test_user):
|
||||
"""Test cleaning up expired inactive sessions returns count removed."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
now = datetime.now(UTC)
|
||||
# Create a session that is already expired
|
||||
obj_in = SessionCreate(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=str(uuid.uuid4()),
|
||||
ip_address="127.0.0.1",
|
||||
user_agent="pytest/test",
|
||||
last_used_at=now - timedelta(days=8),
|
||||
expires_at=now - timedelta(days=1),
|
||||
)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
created = await session_service.create_session(session, obj_in=obj_in)
|
||||
session_id = str(created.id)
|
||||
|
||||
# Deactivate it so it qualifies for cleanup (requires is_active=False AND expired)
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
await session_service.deactivate(session, session_id=session_id)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await session_service.cleanup_expired_for_user(
|
||||
session, user_id=str(async_test_user.id)
|
||||
)
|
||||
assert isinstance(count, int)
|
||||
assert count >= 1
|
||||
|
||||
|
||||
class TestGetAllSessions:
|
||||
"""Tests for SessionService.get_all_sessions method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_sessions(self, async_test_db, async_test_user):
|
||||
"""Test getting all sessions with pagination returns tuple of list and count."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
obj_in = _make_session_create(async_test_user.id)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
await session_service.create_session(session, obj_in=obj_in)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
sessions, count = await session_service.get_all_sessions(
|
||||
session, skip=0, limit=10, active_only=True, with_user=False
|
||||
)
|
||||
assert isinstance(sessions, list)
|
||||
assert isinstance(count, int)
|
||||
assert count >= 1
|
||||
assert len(sessions) >= 1
|
||||
Reference in New Issue
Block a user