Files
pragma-stack/backend/tests/services/test_session_service.py
Felipe Cardoso 98b455fdc3 refactor(backend): enforce route→service→repo layered architecture
- introduce custom repository exception hierarchy (DuplicateEntryError,
  IntegrityConstraintError, InvalidInputError) replacing raw ValueError
- eliminate all direct repository imports and raw SQL from route layer
- add UserService, SessionService, OrganizationService to service layer
- add get_stats/get_org_distribution service methods replacing admin inline SQL
- fix timing side-channel in authenticate_user via dummy bcrypt check
- replace SHA-256 client secret fallback with explicit InvalidClientError
- replace assert with InvalidGrantError in authorization code exchange
- replace N+1 token revocation loops with bulk UPDATE statements
- rename oauth account token fields (drop misleading 'encrypted' suffix)
- add Alembic migration 0003 for token field column rename
- add 45 new service/repository tests; 975 passing, 94% coverage
2026-02-27 09:32:57 +01:00

293 lines
12 KiB
Python

# tests/services/test_session_service.py
"""Tests for the SessionService class."""
import uuid
from datetime import UTC, datetime, timedelta
import pytest
import pytest_asyncio
from app.schemas.sessions import SessionCreate
from app.services.session_service import SessionService, session_service
def _make_session_create(user_id, jti=None) -> SessionCreate:
    """Build a ``SessionCreate`` payload filled with fixed test values.

    Args:
        user_id: Identifier of the user the session belongs to.
        jti: Refresh-token JTI to use. When omitted (or falsy) a random
            UUID4 string is generated instead.

    Returns:
        A ``SessionCreate`` whose ``expires_at`` is seven days after its
        ``last_used_at``.
    """
    issued_at = datetime.now(UTC)
    token_jti = jti if jti else str(uuid.uuid4())
    fields = {
        "user_id": user_id,
        "refresh_token_jti": token_jti,
        "ip_address": "127.0.0.1",
        "user_agent": "pytest/test",
        "device_name": "Test Device",
        "device_id": "test-device-id",
        "last_used_at": issued_at,
        "expires_at": issued_at + timedelta(days=7),
        "location_city": "TestCity",
        "location_country": "TestCountry",
    }
    return SessionCreate(**fields)
class TestCreateSession:
    """Exercise ``SessionService.create_session``."""

    @pytest.mark.asyncio
    async def test_create_session(self, async_test_db, async_test_user):
        """A freshly created session echoes the submitted fields and is active."""
        _, session_factory = async_test_db
        payload = _make_session_create(async_test_user.id)

        async with session_factory() as db:
            stored = await session_service.create_session(db, obj_in=payload)

            assert result_is_present(stored) if False else stored is not None
            assert stored.user_id == async_test_user.id
            assert stored.refresh_token_jti == payload.refresh_token_jti
            assert stored.is_active is True
            assert stored.ip_address == "127.0.0.1"
class TestGetSession:
    """Exercise ``SessionService.get_session``."""

    @pytest.mark.asyncio
    async def test_get_session_found(self, async_test_db, async_test_user):
        """Looking up an existing session by its ID yields that session."""
        _, session_factory = async_test_db
        payload = _make_session_create(async_test_user.id)

        # Create in one DB session, read back in another.
        async with session_factory() as db:
            created = await session_service.create_session(db, obj_in=payload)

        async with session_factory() as db:
            fetched = await session_service.get_session(db, str(created.id))

            assert fetched is not None
            assert fetched.id == created.id

    @pytest.mark.asyncio
    async def test_get_session_not_found(self, async_test_db):
        """Looking up a random, unused ID yields ``None``."""
        _, session_factory = async_test_db
        unknown_id = str(uuid.uuid4())

        async with session_factory() as db:
            fetched = await session_service.get_session(db, unknown_id)

            assert fetched is None
class TestGetUserSessions:
    """Tests for SessionService.get_user_sessions method."""

    @pytest.mark.asyncio
    async def test_get_user_sessions_active_only(self, async_test_db, async_test_user):
        """Test that active_only=True returns active sessions and omits inactive ones.

        One active and one deactivated session are created for the same user so
        the filter is actually exercised, not just the happy path.
        """
        _test_engine, AsyncTestingSessionLocal = async_test_db
        async with AsyncTestingSessionLocal() as session:
            active = await session_service.create_session(
                session, obj_in=_make_session_create(async_test_user.id)
            )
            # Capture IDs right away so later commits cannot invalidate them.
            active_id = active.id
            inactive = await session_service.create_session(
                session, obj_in=_make_session_create(async_test_user.id)
            )
            inactive_id = inactive.id
            await session_service.deactivate(session, session_id=str(inactive_id))

        async with AsyncTestingSessionLocal() as session:
            sessions = await session_service.get_user_sessions(
                session, user_id=str(async_test_user.id), active_only=True
            )
            assert isinstance(sessions, list)
            assert len(sessions) >= 1
            for s in sessions:
                assert s.is_active is True
            returned_ids = {s.id for s in sessions}
            assert active_id in returned_ids
            assert inactive_id not in returned_ids

    @pytest.mark.asyncio
    async def test_get_user_sessions_all(self, async_test_db, async_test_user):
        """Test that active_only=False also returns deactivated sessions.

        The only session created here is deactivated, so it must appear in the
        result with is_active False; a bare length check could not detect an
        implementation that silently ignores active_only=False.
        """
        _test_engine, AsyncTestingSessionLocal = async_test_db
        obj_in = _make_session_create(async_test_user.id)
        async with AsyncTestingSessionLocal() as session:
            created = await session_service.create_session(session, obj_in=obj_in)
            created_id = created.id
            await session_service.deactivate(session, session_id=str(created_id))

        async with AsyncTestingSessionLocal() as session:
            sessions = await session_service.get_user_sessions(
                session, user_id=str(async_test_user.id), active_only=False
            )
            assert isinstance(sessions, list)
            assert len(sessions) >= 1
            by_id = {s.id: s for s in sessions}
            assert created_id in by_id
            assert by_id[created_id].is_active is False
class TestGetActiveByJti:
    """Exercise ``SessionService.get_active_by_jti``."""

    @pytest.mark.asyncio
    async def test_get_active_by_jti_found(self, async_test_db, async_test_user):
        """An active session can be retrieved through its refresh-token JTI."""
        _, session_factory = async_test_db
        token_jti = str(uuid.uuid4())
        payload = _make_session_create(async_test_user.id, jti=token_jti)

        async with session_factory() as db:
            await session_service.create_session(db, obj_in=payload)

        async with session_factory() as db:
            found = await session_service.get_active_by_jti(db, jti=token_jti)

            assert found is not None
            assert found.refresh_token_jti == token_jti
            assert found.is_active is True

    @pytest.mark.asyncio
    async def test_get_active_by_jti_not_found(self, async_test_db):
        """A JTI that was never stored yields ``None``."""
        _, session_factory = async_test_db
        unknown_jti = str(uuid.uuid4())

        async with session_factory() as db:
            found = await session_service.get_active_by_jti(db, jti=unknown_jti)

            assert found is None
class TestGetByJti:
    """Exercise ``SessionService.get_by_jti``."""

    @pytest.mark.asyncio
    async def test_get_by_jti_active(self, async_test_db, async_test_user):
        """A stored (here: still active) session is returned for its JTI."""
        _, session_factory = async_test_db
        token_jti = str(uuid.uuid4())
        payload = _make_session_create(async_test_user.id, jti=token_jti)

        async with session_factory() as db:
            await session_service.create_session(db, obj_in=payload)

        async with session_factory() as db:
            found = await session_service.get_by_jti(db, jti=token_jti)

            assert found is not None
            assert found.refresh_token_jti == token_jti
class TestDeactivate:
    """Exercise ``SessionService.deactivate``."""

    @pytest.mark.asyncio
    async def test_deactivate_session(self, async_test_db, async_test_user):
        """Deactivating a session returns it with ``is_active`` set to False."""
        _, session_factory = async_test_db
        payload = _make_session_create(async_test_user.id)

        async with session_factory() as db:
            created = await session_service.create_session(db, obj_in=payload)
            target_id = str(created.id)

        async with session_factory() as db:
            outcome = await session_service.deactivate(db, session_id=target_id)

            assert outcome is not None
            assert outcome.is_active is False
class TestDeactivateAllUserSessions:
    """Exercise ``SessionService.deactivate_all_user_sessions``."""

    @pytest.mark.asyncio
    async def test_deactivate_all_user_sessions(self, async_test_db, async_test_user):
        """Bulk deactivation reports its count and leaves no active sessions."""
        _, session_factory = async_test_db
        owner_id = str(async_test_user.id)

        # Seed two sessions for the same user.
        async with session_factory() as db:
            first_payload = _make_session_create(async_test_user.id)
            await session_service.create_session(db, obj_in=first_payload)
            second_payload = _make_session_create(async_test_user.id)
            await session_service.create_session(db, obj_in=second_payload)

        async with session_factory() as db:
            deactivated = await session_service.deactivate_all_user_sessions(
                db, user_id=owner_id
            )
            assert deactivated >= 2

        async with session_factory() as db:
            remaining = await session_service.get_user_sessions(
                db, user_id=owner_id, active_only=True
            )
            assert len(remaining) == 0
class TestUpdateRefreshToken:
    """Tests for SessionService.update_refresh_token method."""

    @pytest.mark.asyncio
    async def test_update_refresh_token(self, async_test_db, async_test_user):
        """Test rotating a session's refresh token updates JTI and expiry."""
        _test_engine, AsyncTestingSessionLocal = async_test_db
        obj_in = _make_session_create(async_test_user.id)
        async with AsyncTestingSessionLocal() as session:
            created = await session_service.create_session(session, obj_in=obj_in)
            session_id = str(created.id)

        new_jti = str(uuid.uuid4())
        new_expires_at = datetime.now(UTC) + timedelta(days=14)
        async with AsyncTestingSessionLocal() as session:
            result = await session_service.get_session(session, session_id)
            # Fail here with a clear assertion instead of deep inside the
            # service call if the lookup unexpectedly returns nothing.
            assert result is not None
            updated = await session_service.update_refresh_token(
                session,
                session=result,
                new_jti=new_jti,
                new_expires_at=new_expires_at,
            )
            assert updated.refresh_token_jti == new_jti

            # Verify the expiry half of the rotation as well. The database
            # backend may hand back a naive datetime; treat that as UTC so the
            # comparison does not depend on the backend's timezone support.
            stored_expiry = updated.expires_at
            if stored_expiry.tzinfo is None:
                stored_expiry = stored_expiry.replace(tzinfo=UTC)
            assert abs(stored_expiry - new_expires_at) < timedelta(seconds=1)
class TestCleanupExpiredForUser:
    """Exercise ``SessionService.cleanup_expired_for_user``."""

    @pytest.mark.asyncio
    async def test_cleanup_expired_for_user(self, async_test_db, async_test_user):
        """Cleaning up an expired, deactivated session reports at least one removal."""
        _, session_factory = async_test_db
        reference_time = datetime.now(UTC)

        # Payload for a session whose expiry already lies one day in the past.
        expired_payload = SessionCreate(
            user_id=async_test_user.id,
            refresh_token_jti=str(uuid.uuid4()),
            ip_address="127.0.0.1",
            user_agent="pytest/test",
            last_used_at=reference_time - timedelta(days=8),
            expires_at=reference_time - timedelta(days=1),
        )

        async with session_factory() as db:
            created = await session_service.create_session(
                db, obj_in=expired_payload
            )
            expired_id = str(created.id)

        # Cleanup only considers sessions that are both inactive and expired,
        # so the session has to be deactivated first.
        async with session_factory() as db:
            await session_service.deactivate(db, session_id=expired_id)

        async with session_factory() as db:
            removed = await session_service.cleanup_expired_for_user(
                db, user_id=str(async_test_user.id)
            )
            assert isinstance(removed, int)
            assert removed >= 1
class TestGetAllSessions:
    """Exercise ``SessionService.get_all_sessions``."""

    @pytest.mark.asyncio
    async def test_get_all_sessions(self, async_test_db, async_test_user):
        """The paginated listing yields a ``(list, int)`` pair with the new session."""
        _, session_factory = async_test_db
        payload = _make_session_create(async_test_user.id)

        async with session_factory() as db:
            await session_service.create_session(db, obj_in=payload)

        async with session_factory() as db:
            page = await session_service.get_all_sessions(
                db, skip=0, limit=10, active_only=True, with_user=False
            )
            listed, total = page

            assert isinstance(listed, list)
            assert isinstance(total, int)
            assert total >= 1
            assert len(listed) >= 1