refactor(backend): enforce route→service→repo layered architecture

- introduce custom repository exception hierarchy (DuplicateEntryError,
  IntegrityConstraintError, InvalidInputError) replacing raw ValueError
- eliminate all direct repository imports and raw SQL from route layer
- add UserService, SessionService, OrganizationService to service layer
- add get_stats/get_org_distribution service methods replacing admin inline SQL
- fix timing side-channel in authenticate_user via dummy bcrypt check
- replace SHA-256 client secret fallback with explicit InvalidClientError
- replace assert with InvalidGrantError in authorization code exchange
- replace N+1 token revocation loops with bulk UPDATE statements
- rename oauth account token fields (drop misleading 'encrypted' suffix)
- add Alembic migration 0003 for token field column rename
- add 45 new service/repository tests; 975 passing, 94% coverage
This commit is contained in:
2026-02-27 09:32:57 +01:00
parent 0646c96b19
commit 98b455fdc3
62 changed files with 2933 additions and 1728 deletions

View File

@@ -10,6 +10,7 @@ from app.core.auth import (
get_password_hash,
verify_password,
)
from app.core.exceptions import DuplicateError
from app.models.user import User
from app.schemas.users import Token, UserCreate
from app.services.auth_service import AuthenticationError, AuthService
@@ -152,9 +153,9 @@ class TestAuthServiceUserCreation:
last_name="User",
)
# Should raise AuthenticationError
# Should raise DuplicateError for duplicate email
async with AsyncTestingSessionLocal() as session:
with pytest.raises(AuthenticationError):
with pytest.raises(DuplicateError):
await AuthService.create_user(db=session, user_data=user_data)

View File

@@ -269,18 +269,18 @@ class TestClientValidation:
async def test_validate_client_legacy_sha256_hash(
self, db, confidential_client_legacy_hash
):
"""Test validating a client with legacy SHA-256 hash (backward compatibility)."""
"""Test that legacy SHA-256 hash is rejected with clear error message."""
client, secret = confidential_client_legacy_hash
validated = await service.validate_client(db, client.client_id, secret)
assert validated.client_id == client.client_id
with pytest.raises(service.InvalidClientError, match="deprecated hash format"):
await service.validate_client(db, client.client_id, secret)
@pytest.mark.asyncio
async def test_validate_client_legacy_sha256_wrong_secret(
self, db, confidential_client_legacy_hash
):
"""Test legacy SHA-256 client rejects wrong secret."""
"""Test that legacy SHA-256 client with wrong secret is rejected."""
client, _ = confidential_client_legacy_hash
with pytest.raises(service.InvalidClientError, match="Invalid client secret"):
with pytest.raises(service.InvalidClientError, match="deprecated hash format"):
await service.validate_client(db, client.client_id, "wrong_secret")
def test_validate_redirect_uri_success(self, public_client):

View File

@@ -11,7 +11,8 @@ from uuid import uuid4
import pytest
from app.core.exceptions import AuthenticationError
from app.crud.oauth import oauth_account, oauth_state
from app.repositories.oauth_account import oauth_account_repo as oauth_account
from app.repositories.oauth_state import oauth_state_repo as oauth_state
from app.schemas.oauth import OAuthAccountCreate, OAuthStateCreate
from app.services.oauth_service import OAUTH_PROVIDERS, OAuthService

View File

@@ -0,0 +1,447 @@
# tests/services/test_organization_service.py
"""Tests for the OrganizationService class."""
import uuid
import pytest
import pytest_asyncio
from app.core.exceptions import NotFoundError
from app.models.user_organization import OrganizationRole
from app.schemas.organizations import OrganizationCreate, OrganizationUpdate
from app.services.organization_service import OrganizationService, organization_service
def _make_org_create(name=None, slug=None) -> OrganizationCreate:
"""Helper to create an OrganizationCreate schema with unique defaults."""
unique = uuid.uuid4().hex[:8]
return OrganizationCreate(
name=name or f"Test Org {unique}",
slug=slug or f"test-org-{unique}",
description="A test organization",
is_active=True,
settings={},
)
class TestGetOrganization:
"""Tests for OrganizationService.get_organization method."""
@pytest.mark.asyncio
async def test_get_organization_found(self, async_test_db, async_test_user):
"""Test getting an existing organization by ID returns the org."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
created = await organization_service.create_organization(
session, obj_in=_make_org_create()
)
async with AsyncTestingSessionLocal() as session:
result = await organization_service.get_organization(
session, str(created.id)
)
assert result is not None
assert result.id == created.id
assert result.slug == created.slug
@pytest.mark.asyncio
async def test_get_organization_not_found(self, async_test_db):
"""Test getting a non-existent organization raises NotFoundError."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
with pytest.raises(NotFoundError):
await organization_service.get_organization(
session, str(uuid.uuid4())
)
class TestCreateOrganization:
"""Tests for OrganizationService.create_organization method."""
@pytest.mark.asyncio
async def test_create_organization(self, async_test_db, async_test_user):
"""Test creating a new organization returns the created org with correct fields."""
_test_engine, AsyncTestingSessionLocal = async_test_db
obj_in = _make_org_create()
async with AsyncTestingSessionLocal() as session:
result = await organization_service.create_organization(
session, obj_in=obj_in
)
assert result is not None
assert result.name == obj_in.name
assert result.slug == obj_in.slug
assert result.description == obj_in.description
assert result.is_active is True
class TestUpdateOrganization:
"""Tests for OrganizationService.update_organization method."""
@pytest.mark.asyncio
async def test_update_organization(self, async_test_db, async_test_user):
"""Test updating an organization name."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
created = await organization_service.create_organization(
session, obj_in=_make_org_create()
)
async with AsyncTestingSessionLocal() as session:
org = await organization_service.get_organization(session, str(created.id))
updated = await organization_service.update_organization(
session,
org=org,
obj_in=OrganizationUpdate(name="Updated Org Name"),
)
assert updated.name == "Updated Org Name"
assert updated.id == created.id
@pytest.mark.asyncio
async def test_update_organization_with_dict(self, async_test_db, async_test_user):
"""Test updating an organization using a dict."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
created = await organization_service.create_organization(
session, obj_in=_make_org_create()
)
async with AsyncTestingSessionLocal() as session:
org = await organization_service.get_organization(session, str(created.id))
updated = await organization_service.update_organization(
session,
org=org,
obj_in={"description": "Updated description"},
)
assert updated.description == "Updated description"
class TestRemoveOrganization:
"""Tests for OrganizationService.remove_organization method."""
@pytest.mark.asyncio
async def test_remove_organization(self, async_test_db, async_test_user):
"""Test permanently deleting an organization."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
created = await organization_service.create_organization(
session, obj_in=_make_org_create()
)
org_id = str(created.id)
async with AsyncTestingSessionLocal() as session:
await organization_service.remove_organization(session, org_id)
async with AsyncTestingSessionLocal() as session:
with pytest.raises(NotFoundError):
await organization_service.get_organization(session, org_id)
class TestGetMemberCount:
"""Tests for OrganizationService.get_member_count method."""
@pytest.mark.asyncio
async def test_get_member_count_empty(self, async_test_db, async_test_user):
"""Test member count for org with no members is zero."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
created = await organization_service.create_organization(
session, obj_in=_make_org_create()
)
async with AsyncTestingSessionLocal() as session:
count = await organization_service.get_member_count(
session, organization_id=created.id
)
assert count == 0
@pytest.mark.asyncio
async def test_get_member_count_with_member(self, async_test_db, async_test_user):
"""Test member count increases after adding a member."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
created = await organization_service.create_organization(
session, obj_in=_make_org_create()
)
async with AsyncTestingSessionLocal() as session:
await organization_service.add_member(
session,
organization_id=created.id,
user_id=async_test_user.id,
)
async with AsyncTestingSessionLocal() as session:
count = await organization_service.get_member_count(
session, organization_id=created.id
)
assert count == 1
class TestGetMultiWithMemberCounts:
"""Tests for OrganizationService.get_multi_with_member_counts method."""
@pytest.mark.asyncio
async def test_get_multi_with_member_counts(self, async_test_db, async_test_user):
"""Test listing organizations with member counts returns tuple."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
await organization_service.create_organization(
session, obj_in=_make_org_create()
)
async with AsyncTestingSessionLocal() as session:
orgs, count = await organization_service.get_multi_with_member_counts(
session, skip=0, limit=10
)
assert isinstance(orgs, list)
assert isinstance(count, int)
assert count >= 1
@pytest.mark.asyncio
async def test_get_multi_with_member_counts_search(
self, async_test_db, async_test_user
):
"""Test listing organizations with a search filter."""
_test_engine, AsyncTestingSessionLocal = async_test_db
unique = uuid.uuid4().hex[:8]
org_name = f"Searchable Org {unique}"
async with AsyncTestingSessionLocal() as session:
await organization_service.create_organization(
session,
obj_in=OrganizationCreate(
name=org_name,
slug=f"searchable-org-{unique}",
is_active=True,
settings={},
),
)
async with AsyncTestingSessionLocal() as session:
orgs, count = await organization_service.get_multi_with_member_counts(
session, skip=0, limit=10, search=f"Searchable Org {unique}"
)
assert count >= 1
# Each element is a dict with key "organization" (an Organization obj) and "member_count"
names = [o["organization"].name for o in orgs]
assert org_name in names
class TestGetUserOrganizationsWithDetails:
"""Tests for OrganizationService.get_user_organizations_with_details method."""
@pytest.mark.asyncio
async def test_get_user_organizations_with_details(
self, async_test_db, async_test_user
):
"""Test getting organizations for a user returns list of dicts."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
created = await organization_service.create_organization(
session, obj_in=_make_org_create()
)
await organization_service.add_member(
session,
organization_id=created.id,
user_id=async_test_user.id,
)
async with AsyncTestingSessionLocal() as session:
orgs = await organization_service.get_user_organizations_with_details(
session, user_id=async_test_user.id
)
assert isinstance(orgs, list)
assert len(orgs) >= 1
class TestGetOrganizationMembers:
"""Tests for OrganizationService.get_organization_members method."""
@pytest.mark.asyncio
async def test_get_organization_members(self, async_test_db, async_test_user):
"""Test getting organization members returns paginated results."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
created = await organization_service.create_organization(
session, obj_in=_make_org_create()
)
await organization_service.add_member(
session,
organization_id=created.id,
user_id=async_test_user.id,
)
async with AsyncTestingSessionLocal() as session:
members, count = await organization_service.get_organization_members(
session, organization_id=created.id, skip=0, limit=10
)
assert isinstance(members, list)
assert isinstance(count, int)
assert count >= 1
class TestAddMember:
"""Tests for OrganizationService.add_member method."""
@pytest.mark.asyncio
async def test_add_member_default_role(self, async_test_db, async_test_user):
"""Test adding a user to an org with default MEMBER role."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
created = await organization_service.create_organization(
session, obj_in=_make_org_create()
)
async with AsyncTestingSessionLocal() as session:
membership = await organization_service.add_member(
session,
organization_id=created.id,
user_id=async_test_user.id,
)
assert membership is not None
assert membership.user_id == async_test_user.id
assert membership.organization_id == created.id
assert membership.role == OrganizationRole.MEMBER
@pytest.mark.asyncio
async def test_add_member_admin_role(self, async_test_db, async_test_user):
"""Test adding a user to an org with ADMIN role."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
created = await organization_service.create_organization(
session, obj_in=_make_org_create()
)
async with AsyncTestingSessionLocal() as session:
membership = await organization_service.add_member(
session,
organization_id=created.id,
user_id=async_test_user.id,
role=OrganizationRole.ADMIN,
)
assert membership.role == OrganizationRole.ADMIN
class TestRemoveMember:
"""Tests for OrganizationService.remove_member method."""
@pytest.mark.asyncio
async def test_remove_member(self, async_test_db, async_test_user):
"""Test removing a member from an org returns True."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
created = await organization_service.create_organization(
session, obj_in=_make_org_create()
)
await organization_service.add_member(
session,
organization_id=created.id,
user_id=async_test_user.id,
)
async with AsyncTestingSessionLocal() as session:
removed = await organization_service.remove_member(
session,
organization_id=created.id,
user_id=async_test_user.id,
)
assert removed is True
@pytest.mark.asyncio
async def test_remove_member_not_found(self, async_test_db, async_test_user):
"""Test removing a non-member returns False."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
created = await organization_service.create_organization(
session, obj_in=_make_org_create()
)
async with AsyncTestingSessionLocal() as session:
removed = await organization_service.remove_member(
session,
organization_id=created.id,
user_id=async_test_user.id,
)
assert removed is False
class TestGetUserRoleInOrg:
"""Tests for OrganizationService.get_user_role_in_org method."""
@pytest.mark.asyncio
async def test_get_user_role_in_org(self, async_test_db, async_test_user):
"""Test getting a user's role in an org they belong to."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
created = await organization_service.create_organization(
session, obj_in=_make_org_create()
)
await organization_service.add_member(
session,
organization_id=created.id,
user_id=async_test_user.id,
role=OrganizationRole.MEMBER,
)
async with AsyncTestingSessionLocal() as session:
role = await organization_service.get_user_role_in_org(
session,
user_id=async_test_user.id,
organization_id=created.id,
)
assert role == OrganizationRole.MEMBER
@pytest.mark.asyncio
async def test_get_user_role_in_org_not_member(
self, async_test_db, async_test_user
):
"""Test getting role for a user not in the org returns None."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
created = await organization_service.create_organization(
session, obj_in=_make_org_create()
)
async with AsyncTestingSessionLocal() as session:
role = await organization_service.get_user_role_in_org(
session,
user_id=async_test_user.id,
organization_id=created.id,
)
assert role is None
class TestGetOrgDistribution:
"""Tests for OrganizationService.get_org_distribution method."""
@pytest.mark.asyncio
async def test_get_org_distribution_empty(self, async_test_db):
"""Test org distribution with no memberships returns empty list."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await organization_service.get_org_distribution(session, limit=6)
assert isinstance(result, list)
@pytest.mark.asyncio
async def test_get_org_distribution_with_members(
self, async_test_db, async_test_user
):
"""Test org distribution returns org name and member count."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
created = await organization_service.create_organization(
session, obj_in=_make_org_create()
)
await organization_service.add_member(
session,
organization_id=created.id,
user_id=async_test_user.id,
)
async with AsyncTestingSessionLocal() as session:
result = await organization_service.get_org_distribution(session, limit=6)
assert isinstance(result, list)
assert len(result) >= 1
entry = result[0]
assert "name" in entry
assert "value" in entry
assert entry["value"] >= 1

View File

@@ -0,0 +1,292 @@
# tests/services/test_session_service.py
"""Tests for the SessionService class."""
import uuid
from datetime import UTC, datetime, timedelta
import pytest
import pytest_asyncio
from app.schemas.sessions import SessionCreate
from app.services.session_service import SessionService, session_service
def _make_session_create(user_id, jti=None) -> SessionCreate:
"""Helper to build a SessionCreate with sensible defaults."""
now = datetime.now(UTC)
return SessionCreate(
user_id=user_id,
refresh_token_jti=jti or str(uuid.uuid4()),
ip_address="127.0.0.1",
user_agent="pytest/test",
device_name="Test Device",
device_id="test-device-id",
last_used_at=now,
expires_at=now + timedelta(days=7),
location_city="TestCity",
location_country="TestCountry",
)
class TestCreateSession:
"""Tests for SessionService.create_session method."""
@pytest.mark.asyncio
async def test_create_session(self, async_test_db, async_test_user):
"""Test creating a session returns a UserSession with correct fields."""
_test_engine, AsyncTestingSessionLocal = async_test_db
obj_in = _make_session_create(async_test_user.id)
async with AsyncTestingSessionLocal() as session:
result = await session_service.create_session(session, obj_in=obj_in)
assert result is not None
assert result.user_id == async_test_user.id
assert result.refresh_token_jti == obj_in.refresh_token_jti
assert result.is_active is True
assert result.ip_address == "127.0.0.1"
class TestGetSession:
"""Tests for SessionService.get_session method."""
@pytest.mark.asyncio
async def test_get_session_found(self, async_test_db, async_test_user):
"""Test getting a session by ID returns the session."""
_test_engine, AsyncTestingSessionLocal = async_test_db
obj_in = _make_session_create(async_test_user.id)
async with AsyncTestingSessionLocal() as session:
created = await session_service.create_session(session, obj_in=obj_in)
async with AsyncTestingSessionLocal() as session:
result = await session_service.get_session(session, str(created.id))
assert result is not None
assert result.id == created.id
@pytest.mark.asyncio
async def test_get_session_not_found(self, async_test_db):
"""Test getting a non-existent session returns None."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await session_service.get_session(session, str(uuid.uuid4()))
assert result is None
class TestGetUserSessions:
"""Tests for SessionService.get_user_sessions method."""
@pytest.mark.asyncio
async def test_get_user_sessions_active_only(self, async_test_db, async_test_user):
"""Test getting active sessions for a user returns only active sessions."""
_test_engine, AsyncTestingSessionLocal = async_test_db
obj_in = _make_session_create(async_test_user.id)
async with AsyncTestingSessionLocal() as session:
await session_service.create_session(session, obj_in=obj_in)
async with AsyncTestingSessionLocal() as session:
sessions = await session_service.get_user_sessions(
session, user_id=str(async_test_user.id), active_only=True
)
assert isinstance(sessions, list)
assert len(sessions) >= 1
for s in sessions:
assert s.is_active is True
@pytest.mark.asyncio
async def test_get_user_sessions_all(self, async_test_db, async_test_user):
"""Test getting all sessions (active and inactive) for a user."""
_test_engine, AsyncTestingSessionLocal = async_test_db
obj_in = _make_session_create(async_test_user.id)
async with AsyncTestingSessionLocal() as session:
created = await session_service.create_session(session, obj_in=obj_in)
await session_service.deactivate(session, session_id=str(created.id))
async with AsyncTestingSessionLocal() as session:
sessions = await session_service.get_user_sessions(
session, user_id=str(async_test_user.id), active_only=False
)
assert isinstance(sessions, list)
assert len(sessions) >= 1
class TestGetActiveByJti:
"""Tests for SessionService.get_active_by_jti method."""
@pytest.mark.asyncio
async def test_get_active_by_jti_found(self, async_test_db, async_test_user):
"""Test getting an active session by JTI returns the session."""
_test_engine, AsyncTestingSessionLocal = async_test_db
jti = str(uuid.uuid4())
obj_in = _make_session_create(async_test_user.id, jti=jti)
async with AsyncTestingSessionLocal() as session:
await session_service.create_session(session, obj_in=obj_in)
async with AsyncTestingSessionLocal() as session:
result = await session_service.get_active_by_jti(session, jti=jti)
assert result is not None
assert result.refresh_token_jti == jti
assert result.is_active is True
@pytest.mark.asyncio
async def test_get_active_by_jti_not_found(self, async_test_db):
"""Test getting an active session by non-existent JTI returns None."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await session_service.get_active_by_jti(
session, jti=str(uuid.uuid4())
)
assert result is None
class TestGetByJti:
"""Tests for SessionService.get_by_jti method."""
@pytest.mark.asyncio
async def test_get_by_jti_active(self, async_test_db, async_test_user):
"""Test getting a session (active or inactive) by JTI."""
_test_engine, AsyncTestingSessionLocal = async_test_db
jti = str(uuid.uuid4())
obj_in = _make_session_create(async_test_user.id, jti=jti)
async with AsyncTestingSessionLocal() as session:
await session_service.create_session(session, obj_in=obj_in)
async with AsyncTestingSessionLocal() as session:
result = await session_service.get_by_jti(session, jti=jti)
assert result is not None
assert result.refresh_token_jti == jti
class TestDeactivate:
"""Tests for SessionService.deactivate method."""
@pytest.mark.asyncio
async def test_deactivate_session(self, async_test_db, async_test_user):
"""Test deactivating a session sets is_active to False."""
_test_engine, AsyncTestingSessionLocal = async_test_db
obj_in = _make_session_create(async_test_user.id)
async with AsyncTestingSessionLocal() as session:
created = await session_service.create_session(session, obj_in=obj_in)
session_id = str(created.id)
async with AsyncTestingSessionLocal() as session:
deactivated = await session_service.deactivate(
session, session_id=session_id
)
assert deactivated is not None
assert deactivated.is_active is False
class TestDeactivateAllUserSessions:
"""Tests for SessionService.deactivate_all_user_sessions method."""
@pytest.mark.asyncio
async def test_deactivate_all_user_sessions(self, async_test_db, async_test_user):
"""Test deactivating all sessions for a user returns count deactivated."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
await session_service.create_session(
session, obj_in=_make_session_create(async_test_user.id)
)
await session_service.create_session(
session, obj_in=_make_session_create(async_test_user.id)
)
async with AsyncTestingSessionLocal() as session:
count = await session_service.deactivate_all_user_sessions(
session, user_id=str(async_test_user.id)
)
assert count >= 2
async with AsyncTestingSessionLocal() as session:
active_sessions = await session_service.get_user_sessions(
session, user_id=str(async_test_user.id), active_only=True
)
assert len(active_sessions) == 0
class TestUpdateRefreshToken:
"""Tests for SessionService.update_refresh_token method."""
@pytest.mark.asyncio
async def test_update_refresh_token(self, async_test_db, async_test_user):
"""Test rotating a session's refresh token updates JTI and expiry."""
_test_engine, AsyncTestingSessionLocal = async_test_db
obj_in = _make_session_create(async_test_user.id)
async with AsyncTestingSessionLocal() as session:
created = await session_service.create_session(session, obj_in=obj_in)
session_id = str(created.id)
new_jti = str(uuid.uuid4())
new_expires_at = datetime.now(UTC) + timedelta(days=14)
async with AsyncTestingSessionLocal() as session:
result = await session_service.get_session(session, session_id)
updated = await session_service.update_refresh_token(
session,
session=result,
new_jti=new_jti,
new_expires_at=new_expires_at,
)
assert updated.refresh_token_jti == new_jti
class TestCleanupExpiredForUser:
"""Tests for SessionService.cleanup_expired_for_user method."""
@pytest.mark.asyncio
async def test_cleanup_expired_for_user(self, async_test_db, async_test_user):
"""Test cleaning up expired inactive sessions returns count removed."""
_test_engine, AsyncTestingSessionLocal = async_test_db
now = datetime.now(UTC)
# Create a session that is already expired
obj_in = SessionCreate(
user_id=async_test_user.id,
refresh_token_jti=str(uuid.uuid4()),
ip_address="127.0.0.1",
user_agent="pytest/test",
last_used_at=now - timedelta(days=8),
expires_at=now - timedelta(days=1),
)
async with AsyncTestingSessionLocal() as session:
created = await session_service.create_session(session, obj_in=obj_in)
session_id = str(created.id)
# Deactivate it so it qualifies for cleanup (requires is_active=False AND expired)
async with AsyncTestingSessionLocal() as session:
await session_service.deactivate(session, session_id=session_id)
async with AsyncTestingSessionLocal() as session:
count = await session_service.cleanup_expired_for_user(
session, user_id=str(async_test_user.id)
)
assert isinstance(count, int)
assert count >= 1
class TestGetAllSessions:
"""Tests for SessionService.get_all_sessions method."""
@pytest.mark.asyncio
async def test_get_all_sessions(self, async_test_db, async_test_user):
"""Test getting all sessions with pagination returns tuple of list and count."""
_test_engine, AsyncTestingSessionLocal = async_test_db
obj_in = _make_session_create(async_test_user.id)
async with AsyncTestingSessionLocal() as session:
await session_service.create_session(session, obj_in=obj_in)
async with AsyncTestingSessionLocal() as session:
sessions, count = await session_service.get_all_sessions(
session, skip=0, limit=10, active_only=True, with_user=False
)
assert isinstance(sessions, list)
assert isinstance(count, int)
assert count >= 1
assert len(sessions) >= 1

View File

@@ -0,0 +1,214 @@
# tests/services/test_user_service.py
"""Tests for the UserService class."""
import uuid
import pytest
import pytest_asyncio
from sqlalchemy import select
from app.core.exceptions import NotFoundError
from app.models.user import User
from app.schemas.users import UserCreate, UserUpdate
from app.services.user_service import UserService, user_service
class TestGetUser:
"""Tests for UserService.get_user method."""
@pytest.mark.asyncio
async def test_get_user_found(self, async_test_db, async_test_user):
"""Test getting an existing user by ID returns the user."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await user_service.get_user(session, str(async_test_user.id))
assert result is not None
assert result.id == async_test_user.id
assert result.email == async_test_user.email
@pytest.mark.asyncio
async def test_get_user_not_found(self, async_test_db):
"""Test getting a non-existent user raises NotFoundError."""
_test_engine, AsyncTestingSessionLocal = async_test_db
non_existent_id = str(uuid.uuid4())
async with AsyncTestingSessionLocal() as session:
with pytest.raises(NotFoundError):
await user_service.get_user(session, non_existent_id)
class TestGetByEmail:
"""Tests for UserService.get_by_email method."""
@pytest.mark.asyncio
async def test_get_by_email_found(self, async_test_db, async_test_user):
"""Test getting an existing user by email returns the user."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await user_service.get_by_email(session, async_test_user.email)
assert result is not None
assert result.id == async_test_user.id
assert result.email == async_test_user.email
@pytest.mark.asyncio
async def test_get_by_email_not_found(self, async_test_db):
"""Test getting a user by non-existent email returns None."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await user_service.get_by_email(session, "nonexistent@example.com")
assert result is None
class TestCreateUser:
"""Tests for UserService.create_user method."""
@pytest.mark.asyncio
async def test_create_user(self, async_test_db):
"""Test creating a new user with valid data."""
_test_engine, AsyncTestingSessionLocal = async_test_db
unique_email = f"test_{uuid.uuid4()}@example.com"
user_data = UserCreate(
email=unique_email,
password="TestPassword123!",
first_name="New",
last_name="User",
)
async with AsyncTestingSessionLocal() as session:
result = await user_service.create_user(session, user_data)
assert result is not None
assert result.email == unique_email
assert result.first_name == "New"
assert result.last_name == "User"
assert result.is_active is True
class TestUpdateUser:
"""Tests for UserService.update_user method."""
@pytest.mark.asyncio
async def test_update_user(self, async_test_db, async_test_user):
"""Test updating a user's first_name."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user = await user_service.get_user(session, str(async_test_user.id))
updated = await user_service.update_user(
session,
user=user,
obj_in=UserUpdate(first_name="Updated"),
)
assert updated.first_name == "Updated"
assert updated.id == async_test_user.id
class TestSoftDeleteUser:
"""Tests for UserService.soft_delete_user method."""
@pytest.mark.asyncio
async def test_soft_delete_user(self, async_test_db, async_test_user):
"""Test soft-deleting a user sets deleted_at."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
await user_service.soft_delete_user(session, str(async_test_user.id))
async with AsyncTestingSessionLocal() as session:
result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user = result.scalar_one_or_none()
assert user is not None
assert user.deleted_at is not None
class TestListUsers:
"""Tests for UserService.list_users method."""
@pytest.mark.asyncio
async def test_list_users(self, async_test_db, async_test_user):
"""Test listing users with pagination returns correct results."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
users, count = await user_service.list_users(session, skip=0, limit=10)
assert isinstance(users, list)
assert isinstance(count, int)
assert count >= 1
assert len(users) >= 1
@pytest.mark.asyncio
async def test_list_users_with_search(self, async_test_db, async_test_user):
"""Test listing users with email fragment search returns matching users."""
_test_engine, AsyncTestingSessionLocal = async_test_db
# Search by partial email fragment of the test user
email_fragment = async_test_user.email.split("@")[0]
async with AsyncTestingSessionLocal() as session:
users, count = await user_service.list_users(
session, skip=0, limit=10, search=email_fragment
)
assert isinstance(users, list)
assert count >= 1
emails = [u.email for u in users]
assert async_test_user.email in emails
class TestBulkUpdateStatus:
"""Tests for UserService.bulk_update_status method."""
@pytest.mark.asyncio
async def test_bulk_update_status(self, async_test_db, async_test_user):
"""Test bulk activating users returns correct count."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
count = await user_service.bulk_update_status(
session,
user_ids=[async_test_user.id],
is_active=True,
)
assert count >= 1
async with AsyncTestingSessionLocal() as session:
result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user = result.scalar_one_or_none()
assert user is not None
assert user.is_active is True
class TestBulkSoftDelete:
"""Tests for UserService.bulk_soft_delete method."""
@pytest.mark.asyncio
async def test_bulk_soft_delete(self, async_test_db, async_test_user):
"""Test bulk soft-deleting users returns correct count."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
count = await user_service.bulk_soft_delete(
session,
user_ids=[async_test_user.id],
)
assert count >= 1
async with AsyncTestingSessionLocal() as session:
result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user = result.scalar_one_or_none()
assert user is not None
assert user.deleted_at is not None
class TestGetStats:
"""Tests for UserService.get_stats method."""
@pytest.mark.asyncio
async def test_get_stats(self, async_test_db, async_test_user):
"""Test get_stats returns dict with expected keys and correct counts."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
stats = await user_service.get_stats(session)
assert "total_users" in stats
assert "active_count" in stats
assert "inactive_count" in stats
assert "all_users" in stats
assert stats["total_users"] >= 1
assert stats["active_count"] >= 1
assert isinstance(stats["all_users"], list)
assert len(stats["all_users"]) >= 1