forked from cardosofelipe/pragma-stack
refactor(backend): enforce route→service→repo layered architecture
- introduce custom repository exception hierarchy (DuplicateEntryError, IntegrityConstraintError, InvalidInputError) replacing raw ValueError - eliminate all direct repository imports and raw SQL from route layer - add UserService, SessionService, OrganizationService to service layer - add get_stats/get_org_distribution service methods replacing admin inline SQL - fix timing side-channel in authenticate_user via dummy bcrypt check - replace SHA-256 client secret fallback with explicit InvalidClientError - replace assert with InvalidGrantError in authorization code exchange - replace N+1 token revocation loops with bulk UPDATE statements - rename oauth account token fields (drop misleading 'encrypted' suffix) - add Alembic migration 0003 for token field column rename - add 45 new service/repository tests; 975 passing, 94% coverage
This commit is contained in:
@@ -147,7 +147,7 @@ class TestAdminCreateUser:
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert response.status_code == status.HTTP_409_CONFLICT
|
||||
|
||||
|
||||
class TestAdminGetUser:
|
||||
@@ -565,7 +565,7 @@ class TestAdminCreateOrganization:
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert response.status_code == status.HTTP_409_CONFLICT
|
||||
|
||||
|
||||
class TestAdminGetOrganization:
|
||||
|
||||
@@ -45,7 +45,7 @@ class TestAdminListUsersFilters:
|
||||
async def test_list_users_database_error_propagates(self, client, superuser_token):
|
||||
"""Test that database errors propagate correctly (covers line 118-120)."""
|
||||
with patch(
|
||||
"app.api.routes.admin.user_crud.get_multi_with_total",
|
||||
"app.api.routes.admin.user_service.list_users",
|
||||
side_effect=Exception("DB error"),
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
@@ -74,8 +74,8 @@ class TestAdminCreateUserErrors:
|
||||
},
|
||||
)
|
||||
|
||||
# Should get error for duplicate email
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
# Should get conflict for duplicate email
|
||||
assert response.status_code == status.HTTP_409_CONFLICT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_unexpected_error_propagates(
|
||||
@@ -83,7 +83,7 @@ class TestAdminCreateUserErrors:
|
||||
):
|
||||
"""Test unexpected errors during user creation (covers line 151-153)."""
|
||||
with patch(
|
||||
"app.api.routes.admin.user_crud.create",
|
||||
"app.api.routes.admin.user_service.create_user",
|
||||
side_effect=RuntimeError("Unexpected error"),
|
||||
):
|
||||
with pytest.raises(RuntimeError):
|
||||
@@ -135,7 +135,7 @@ class TestAdminUpdateUserErrors:
|
||||
):
|
||||
"""Test unexpected errors during user update (covers line 206-208)."""
|
||||
with patch(
|
||||
"app.api.routes.admin.user_crud.update",
|
||||
"app.api.routes.admin.user_service.update_user",
|
||||
side_effect=RuntimeError("Update failed"),
|
||||
):
|
||||
with pytest.raises(RuntimeError):
|
||||
@@ -166,7 +166,7 @@ class TestAdminDeleteUserErrors:
|
||||
):
|
||||
"""Test unexpected errors during user deletion (covers line 238-240)."""
|
||||
with patch(
|
||||
"app.api.routes.admin.user_crud.soft_delete",
|
||||
"app.api.routes.admin.user_service.soft_delete_user",
|
||||
side_effect=Exception("Delete failed"),
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
@@ -196,7 +196,7 @@ class TestAdminActivateUserErrors:
|
||||
):
|
||||
"""Test unexpected errors during user activation (covers line 282-284)."""
|
||||
with patch(
|
||||
"app.api.routes.admin.user_crud.update",
|
||||
"app.api.routes.admin.user_service.update_user",
|
||||
side_effect=Exception("Activation failed"),
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
@@ -238,7 +238,7 @@ class TestAdminDeactivateUserErrors:
|
||||
):
|
||||
"""Test unexpected errors during user deactivation (covers line 326-328)."""
|
||||
with patch(
|
||||
"app.api.routes.admin.user_crud.update",
|
||||
"app.api.routes.admin.user_service.update_user",
|
||||
side_effect=Exception("Deactivation failed"),
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
@@ -258,7 +258,7 @@ class TestAdminListOrganizationsErrors:
|
||||
async def test_list_organizations_database_error(self, client, superuser_token):
|
||||
"""Test list organizations with database error (covers line 427-456)."""
|
||||
with patch(
|
||||
"app.api.routes.admin.organization_crud.get_multi_with_member_counts",
|
||||
"app.api.routes.admin.organization_service.get_multi_with_member_counts",
|
||||
side_effect=Exception("DB error"),
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
@@ -299,14 +299,14 @@ class TestAdminCreateOrganizationErrors:
|
||||
},
|
||||
)
|
||||
|
||||
# Should get error for duplicate slug
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
# Should get conflict for duplicate slug
|
||||
assert response.status_code == status.HTTP_409_CONFLICT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_organization_unexpected_error(self, client, superuser_token):
|
||||
"""Test unexpected errors during organization creation (covers line 484-485)."""
|
||||
with patch(
|
||||
"app.api.routes.admin.organization_crud.create",
|
||||
"app.api.routes.admin.organization_service.create_organization",
|
||||
side_effect=RuntimeError("Creation failed"),
|
||||
):
|
||||
with pytest.raises(RuntimeError):
|
||||
@@ -367,7 +367,7 @@ class TestAdminUpdateOrganizationErrors:
|
||||
org_id = org.id
|
||||
|
||||
with patch(
|
||||
"app.api.routes.admin.organization_crud.update",
|
||||
"app.api.routes.admin.organization_service.update_organization",
|
||||
side_effect=Exception("Update failed"),
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
@@ -412,7 +412,7 @@ class TestAdminDeleteOrganizationErrors:
|
||||
org_id = org.id
|
||||
|
||||
with patch(
|
||||
"app.api.routes.admin.organization_crud.remove",
|
||||
"app.api.routes.admin.organization_service.remove_organization",
|
||||
side_effect=Exception("Delete failed"),
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
@@ -456,7 +456,7 @@ class TestAdminListOrganizationMembersErrors:
|
||||
org_id = org.id
|
||||
|
||||
with patch(
|
||||
"app.api.routes.admin.organization_crud.get_organization_members",
|
||||
"app.api.routes.admin.organization_service.get_organization_members",
|
||||
side_effect=Exception("DB error"),
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
@@ -531,7 +531,7 @@ class TestAdminAddOrganizationMemberErrors:
|
||||
org_id = org.id
|
||||
|
||||
with patch(
|
||||
"app.api.routes.admin.organization_crud.add_user",
|
||||
"app.api.routes.admin.organization_service.add_member",
|
||||
side_effect=Exception("Add failed"),
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
@@ -587,7 +587,7 @@ class TestAdminRemoveOrganizationMemberErrors:
|
||||
org_id = org.id
|
||||
|
||||
with patch(
|
||||
"app.api.routes.admin.organization_crud.remove_user",
|
||||
"app.api.routes.admin.organization_service.remove_member",
|
||||
side_effect=Exception("Remove failed"),
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
|
||||
@@ -19,7 +19,7 @@ class TestLoginSessionCreationFailure:
|
||||
"""Test that login succeeds even if session creation fails."""
|
||||
# Mock session creation to fail
|
||||
with patch(
|
||||
"app.api.routes.auth.session_crud.create_session",
|
||||
"app.api.routes.auth.session_service.create_session",
|
||||
side_effect=Exception("Session creation failed"),
|
||||
):
|
||||
response = await client.post(
|
||||
@@ -43,7 +43,7 @@ class TestOAuthLoginSessionCreationFailure:
|
||||
):
|
||||
"""Test OAuth login succeeds even if session creation fails."""
|
||||
with patch(
|
||||
"app.api.routes.auth.session_crud.create_session",
|
||||
"app.api.routes.auth.session_service.create_session",
|
||||
side_effect=Exception("Session failed"),
|
||||
):
|
||||
response = await client.post(
|
||||
@@ -76,7 +76,7 @@ class TestRefreshTokenSessionUpdateFailure:
|
||||
|
||||
# Mock session update to fail
|
||||
with patch(
|
||||
"app.api.routes.auth.session_crud.update_refresh_token",
|
||||
"app.api.routes.auth.session_service.update_refresh_token",
|
||||
side_effect=Exception("Update failed"),
|
||||
):
|
||||
response = await client.post(
|
||||
@@ -130,7 +130,7 @@ class TestLogoutWithNonExistentSession:
|
||||
tokens = response.json()
|
||||
|
||||
# Mock session lookup to return None
|
||||
with patch("app.api.routes.auth.session_crud.get_by_jti", return_value=None):
|
||||
with patch("app.api.routes.auth.session_service.get_by_jti", return_value=None):
|
||||
response = await client.post(
|
||||
"/api/v1/auth/logout",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
@@ -157,7 +157,7 @@ class TestLogoutUnexpectedError:
|
||||
|
||||
# Mock to raise unexpected error
|
||||
with patch(
|
||||
"app.api.routes.auth.session_crud.get_by_jti",
|
||||
"app.api.routes.auth.session_service.get_by_jti",
|
||||
side_effect=Exception("Unexpected error"),
|
||||
):
|
||||
response = await client.post(
|
||||
@@ -186,7 +186,7 @@ class TestLogoutAllUnexpectedError:
|
||||
|
||||
# Mock to raise database error
|
||||
with patch(
|
||||
"app.api.routes.auth.session_crud.deactivate_all_user_sessions",
|
||||
"app.api.routes.auth.session_service.deactivate_all_user_sessions",
|
||||
side_effect=Exception("DB error"),
|
||||
):
|
||||
response = await client.post(
|
||||
@@ -212,7 +212,7 @@ class TestPasswordResetConfirmSessionInvalidation:
|
||||
|
||||
# Mock session invalidation to fail
|
||||
with patch(
|
||||
"app.api.routes.auth.session_crud.deactivate_all_user_sessions",
|
||||
"app.api.routes.auth.session_service.deactivate_all_user_sessions",
|
||||
side_effect=Exception("Invalidation failed"),
|
||||
):
|
||||
response = await client.post(
|
||||
|
||||
@@ -334,7 +334,7 @@ class TestPasswordResetConfirm:
|
||||
token = create_password_reset_token(async_test_user.email)
|
||||
|
||||
# Mock the database commit to raise an exception
|
||||
with patch("app.api.routes.auth.user_crud.get_by_email") as mock_get:
|
||||
with patch("app.services.auth_service.user_repo.get_by_email") as mock_get:
|
||||
mock_get.side_effect = Exception("Database error")
|
||||
|
||||
response = await client.post(
|
||||
|
||||
@@ -12,7 +12,7 @@ These tests prevent real-world attack scenarios.
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
from app.crud.session import session as session_crud
|
||||
from app.repositories.session import session_repo as session_crud
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.crud.oauth import oauth_account
|
||||
from app.repositories.oauth_account import oauth_account_repo as oauth_account
|
||||
from app.schemas.oauth import OAuthAccountCreate
|
||||
|
||||
|
||||
@@ -349,7 +349,7 @@ class TestOAuthProviderEndpoints:
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create a test client
|
||||
from app.crud.oauth import oauth_client
|
||||
from app.repositories.oauth_client import oauth_client_repo as oauth_client
|
||||
from app.schemas.oauth import OAuthClientCreate
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -386,7 +386,7 @@ class TestOAuthProviderEndpoints:
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create a test client
|
||||
from app.crud.oauth import oauth_client
|
||||
from app.repositories.oauth_client import oauth_client_repo as oauth_client
|
||||
from app.schemas.oauth import OAuthClientCreate
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
|
||||
@@ -537,7 +537,7 @@ class TestOrganizationExceptionHandlers:
|
||||
):
|
||||
"""Test generic exception handler in get_my_organizations (covers lines 81-83)."""
|
||||
with patch(
|
||||
"app.crud.organization.organization.get_user_organizations_with_details",
|
||||
"app.api.routes.organizations.organization_service.get_user_organizations_with_details",
|
||||
side_effect=Exception("Database connection lost"),
|
||||
):
|
||||
# The exception handler logs and re-raises, so we expect the exception
|
||||
@@ -554,7 +554,7 @@ class TestOrganizationExceptionHandlers:
|
||||
):
|
||||
"""Test generic exception handler in get_organization (covers lines 124-128)."""
|
||||
with patch(
|
||||
"app.crud.organization.organization.get",
|
||||
"app.api.routes.organizations.organization_service.get_organization",
|
||||
side_effect=Exception("Database timeout"),
|
||||
):
|
||||
with pytest.raises(Exception, match="Database timeout"):
|
||||
@@ -569,7 +569,7 @@ class TestOrganizationExceptionHandlers:
|
||||
):
|
||||
"""Test generic exception handler in get_organization_members (covers lines 170-172)."""
|
||||
with patch(
|
||||
"app.crud.organization.organization.get_organization_members",
|
||||
"app.api.routes.organizations.organization_service.get_organization_members",
|
||||
side_effect=Exception("Connection pool exhausted"),
|
||||
):
|
||||
with pytest.raises(Exception, match="Connection pool exhausted"):
|
||||
@@ -591,11 +591,11 @@ class TestOrganizationExceptionHandlers:
|
||||
admin_token = login_response.json()["access_token"]
|
||||
|
||||
with patch(
|
||||
"app.crud.organization.organization.get",
|
||||
"app.api.routes.organizations.organization_service.get_organization",
|
||||
return_value=test_org_with_user_admin,
|
||||
):
|
||||
with patch(
|
||||
"app.crud.organization.organization.update",
|
||||
"app.api.routes.organizations.organization_service.update_organization",
|
||||
side_effect=Exception("Write lock timeout"),
|
||||
):
|
||||
with pytest.raises(Exception, match="Write lock timeout"):
|
||||
|
||||
@@ -11,7 +11,7 @@ These tests prevent unauthorized access and privilege escalation.
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
from app.crud.user import user as user_crud
|
||||
from app.repositories.user import user_repo as user_crud
|
||||
from app.models.organization import Organization
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ async def async_test_user2(async_test_db):
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
from app.crud.user import user as user_crud
|
||||
from app.repositories.user import user_repo as user_crud
|
||||
from app.schemas.users import UserCreate
|
||||
|
||||
user_data = UserCreate(
|
||||
@@ -191,7 +191,7 @@ class TestRevokeSession:
|
||||
|
||||
# Verify session is deactivated
|
||||
async with SessionLocal() as session:
|
||||
from app.crud.session import session as session_crud
|
||||
from app.repositories.session import session_repo as session_crud
|
||||
|
||||
revoked_session = await session_crud.get(session, id=str(session_id))
|
||||
assert revoked_session.is_active is False
|
||||
@@ -268,7 +268,7 @@ class TestCleanupExpiredSessions:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create expired and active sessions using CRUD to avoid greenlet issues
|
||||
from app.crud.session import session as session_crud
|
||||
from app.repositories.session import session_repo as session_crud
|
||||
from app.schemas.sessions import SessionCreate
|
||||
|
||||
async with SessionLocal() as db:
|
||||
@@ -334,7 +334,7 @@ class TestCleanupExpiredSessions:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create only active sessions using CRUD
|
||||
from app.crud.session import session as session_crud
|
||||
from app.repositories.session import session_repo as session_crud
|
||||
from app.schemas.sessions import SessionCreate
|
||||
|
||||
async with SessionLocal() as db:
|
||||
@@ -384,7 +384,7 @@ class TestSessionsAdditionalCases:
|
||||
|
||||
# Create multiple sessions
|
||||
async with SessionLocal() as session:
|
||||
from app.crud.session import session as session_crud
|
||||
from app.repositories.session import session_repo as session_crud
|
||||
from app.schemas.sessions import SessionCreate
|
||||
|
||||
for i in range(5):
|
||||
@@ -431,7 +431,7 @@ class TestSessionsAdditionalCases:
|
||||
"""Test cleanup with mix of active/inactive and expired/not-expired sessions."""
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
from app.crud.session import session as session_crud
|
||||
from app.repositories.session import session_repo as session_crud
|
||||
from app.schemas.sessions import SessionCreate
|
||||
|
||||
async with SessionLocal() as db:
|
||||
@@ -502,10 +502,10 @@ class TestSessionExceptionHandlers:
|
||||
"""Test list_sessions handles database errors (covers lines 104-106)."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.crud import session as session_module
|
||||
from app.repositories import session as session_module
|
||||
|
||||
with patch.object(
|
||||
session_module.session,
|
||||
session_module.session_repo,
|
||||
"get_user_sessions",
|
||||
side_effect=Exception("Database error"),
|
||||
):
|
||||
@@ -527,10 +527,10 @@ class TestSessionExceptionHandlers:
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from app.crud import session as session_module
|
||||
from app.repositories import session as session_module
|
||||
|
||||
# First create a session to revoke
|
||||
from app.crud.session import session as session_crud
|
||||
from app.repositories.session import session_repo as session_crud
|
||||
from app.schemas.sessions import SessionCreate
|
||||
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
@@ -550,7 +550,7 @@ class TestSessionExceptionHandlers:
|
||||
|
||||
# Mock the deactivate method to raise an exception
|
||||
with patch.object(
|
||||
session_module.session,
|
||||
session_module.session_repo,
|
||||
"deactivate",
|
||||
side_effect=Exception("Database connection lost"),
|
||||
):
|
||||
@@ -568,10 +568,10 @@ class TestSessionExceptionHandlers:
|
||||
"""Test cleanup_expired_sessions handles database errors (covers lines 233-236)."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.crud import session as session_module
|
||||
from app.repositories import session as session_module
|
||||
|
||||
with patch.object(
|
||||
session_module.session,
|
||||
session_module.session_repo,
|
||||
"cleanup_expired_for_user",
|
||||
side_effect=Exception("Cleanup failed"),
|
||||
):
|
||||
|
||||
@@ -99,7 +99,7 @@ class TestUpdateCurrentUser:
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch(
|
||||
"app.api.routes.users.user_crud.update", side_effect=Exception("DB error")
|
||||
"app.api.routes.users.user_service.update_user", side_effect=Exception("DB error")
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
await client.patch(
|
||||
@@ -134,7 +134,7 @@ class TestUpdateCurrentUser:
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch(
|
||||
"app.api.routes.users.user_crud.update",
|
||||
"app.api.routes.users.user_service.update_user",
|
||||
side_effect=ValueError("Invalid value"),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
@@ -224,7 +224,7 @@ class TestUpdateUserById:
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch(
|
||||
"app.api.routes.users.user_crud.update", side_effect=ValueError("Invalid")
|
||||
"app.api.routes.users.user_service.update_user", side_effect=ValueError("Invalid")
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
await client.patch(
|
||||
@@ -241,7 +241,7 @@ class TestUpdateUserById:
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch(
|
||||
"app.api.routes.users.user_crud.update", side_effect=Exception("Unexpected")
|
||||
"app.api.routes.users.user_service.update_user", side_effect=Exception("Unexpected")
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
await client.patch(
|
||||
@@ -354,7 +354,7 @@ class TestDeleteUserById:
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch(
|
||||
"app.api.routes.users.user_crud.soft_delete",
|
||||
"app.api.routes.users.user_service.soft_delete_user",
|
||||
side_effect=ValueError("Cannot delete"),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
@@ -371,7 +371,7 @@ class TestDeleteUserById:
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch(
|
||||
"app.api.routes.users.user_crud.soft_delete",
|
||||
"app.api.routes.users.user_service.soft_delete_user",
|
||||
side_effect=Exception("Unexpected"),
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
|
||||
@@ -46,7 +46,7 @@ async def login_user(client, email: str, password: str = "SecurePassword123!"):
|
||||
|
||||
async def create_superuser(e2e_db_session, email: str, password: str):
|
||||
"""Create a superuser directly in the database."""
|
||||
from app.crud.user import user as user_crud
|
||||
from app.repositories.user import user_repo as user_crud
|
||||
from app.schemas.users import UserCreate
|
||||
|
||||
user_in = UserCreate(
|
||||
|
||||
@@ -46,7 +46,7 @@ async def register_and_login(client, email: str, password: str = "SecurePassword
|
||||
|
||||
async def create_superuser_and_login(client, db_session):
|
||||
"""Helper to create a superuser directly in DB and login."""
|
||||
from app.crud.user import user as user_crud
|
||||
from app.repositories.user import user_repo as user_crud
|
||||
from app.schemas.users import UserCreate
|
||||
|
||||
email = f"admin-{uuid4().hex[:8]}@example.com"
|
||||
|
||||
@@ -11,7 +11,12 @@ import pytest
|
||||
from sqlalchemy.exc import DataError, IntegrityError, OperationalError
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.crud.user import user as user_crud
|
||||
from app.core.repository_exceptions import (
|
||||
DuplicateEntryError,
|
||||
IntegrityConstraintError,
|
||||
InvalidInputError,
|
||||
)
|
||||
from app.repositories.user import user_repo as user_crud
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
|
||||
|
||||
@@ -81,7 +86,7 @@ class TestCRUDBaseGetMulti:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="skip must be non-negative"):
|
||||
with pytest.raises(InvalidInputError, match="skip must be non-negative"):
|
||||
await user_crud.get_multi(session, skip=-1)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -90,7 +95,7 @@ class TestCRUDBaseGetMulti:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="limit must be non-negative"):
|
||||
with pytest.raises(InvalidInputError, match="limit must be non-negative"):
|
||||
await user_crud.get_multi(session, limit=-1)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -99,7 +104,7 @@ class TestCRUDBaseGetMulti:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="Maximum limit is 1000"):
|
||||
with pytest.raises(InvalidInputError, match="Maximum limit is 1000"):
|
||||
await user_crud.get_multi(session, limit=1001)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -140,7 +145,7 @@ class TestCRUDBaseCreate:
|
||||
last_name="Duplicate",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
with pytest.raises(DuplicateEntryError, match="already exists"):
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -165,7 +170,7 @@ class TestCRUDBaseCreate:
|
||||
last_name="User",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Database integrity error"):
|
||||
with pytest.raises(DuplicateEntryError, match="Database integrity error"):
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -244,7 +249,7 @@ class TestCRUDBaseUpdate:
|
||||
|
||||
# Create another user
|
||||
async with SessionLocal() as session:
|
||||
from app.crud.user import user as user_crud
|
||||
from app.repositories.user import user_repo as user_crud
|
||||
|
||||
user2_data = UserCreate(
|
||||
email="user2@example.com",
|
||||
@@ -268,7 +273,7 @@ class TestCRUDBaseUpdate:
|
||||
):
|
||||
update_data = UserUpdate(email=async_test_user.email)
|
||||
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
with pytest.raises(DuplicateEntryError, match="already exists"):
|
||||
await user_crud.update(
|
||||
session, db_obj=user2_obj, obj_in=update_data
|
||||
)
|
||||
@@ -302,7 +307,7 @@ class TestCRUDBaseUpdate:
|
||||
"statement", {}, Exception("constraint failed")
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError, match="Database integrity error"):
|
||||
with pytest.raises(IntegrityConstraintError, match="Database integrity error"):
|
||||
await user_crud.update(
|
||||
session, db_obj=user, obj_in={"first_name": "Test"}
|
||||
)
|
||||
@@ -322,7 +327,7 @@ class TestCRUDBaseUpdate:
|
||||
"statement", {}, Exception("connection error")
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError, match="Database operation failed"):
|
||||
with pytest.raises(IntegrityConstraintError, match="Database operation failed"):
|
||||
await user_crud.update(
|
||||
session, db_obj=user, obj_in={"first_name": "Test"}
|
||||
)
|
||||
@@ -403,7 +408,7 @@ class TestCRUDBaseRemove:
|
||||
),
|
||||
):
|
||||
with pytest.raises(
|
||||
ValueError, match="Cannot delete.*referenced by other records"
|
||||
IntegrityConstraintError, match="Cannot delete.*referenced by other records"
|
||||
):
|
||||
await user_crud.remove(session, id=str(async_test_user.id))
|
||||
|
||||
@@ -442,7 +447,7 @@ class TestCRUDBaseGetMultiWithTotal:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="skip must be non-negative"):
|
||||
with pytest.raises(InvalidInputError, match="skip must be non-negative"):
|
||||
await user_crud.get_multi_with_total(session, skip=-1)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -451,7 +456,7 @@ class TestCRUDBaseGetMultiWithTotal:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="limit must be non-negative"):
|
||||
with pytest.raises(InvalidInputError, match="limit must be non-negative"):
|
||||
await user_crud.get_multi_with_total(session, limit=-1)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -460,7 +465,7 @@ class TestCRUDBaseGetMultiWithTotal:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="Maximum limit is 1000"):
|
||||
with pytest.raises(InvalidInputError, match="Maximum limit is 1000"):
|
||||
await user_crud.get_multi_with_total(session, limit=1001)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -827,7 +832,7 @@ class TestCRUDBasePaginationValidation:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="skip must be non-negative"):
|
||||
with pytest.raises(InvalidInputError, match="skip must be non-negative"):
|
||||
await user_crud.get_multi_with_total(session, skip=-1, limit=10)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -836,7 +841,7 @@ class TestCRUDBasePaginationValidation:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="limit must be non-negative"):
|
||||
with pytest.raises(InvalidInputError, match="limit must be non-negative"):
|
||||
await user_crud.get_multi_with_total(session, skip=0, limit=-1)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -845,7 +850,7 @@ class TestCRUDBasePaginationValidation:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="Maximum limit is 1000"):
|
||||
with pytest.raises(InvalidInputError, match="Maximum limit is 1000"):
|
||||
await user_crud.get_multi_with_total(session, skip=0, limit=1001)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -899,7 +904,7 @@ class TestCRUDBaseModelsWithoutSoftDelete:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create an organization (which doesn't have deleted_at)
|
||||
from app.crud.organization import organization as org_crud
|
||||
from app.repositories.organization import organization_repo as org_crud
|
||||
from app.models.organization import Organization
|
||||
|
||||
async with SessionLocal() as session:
|
||||
@@ -910,7 +915,7 @@ class TestCRUDBaseModelsWithoutSoftDelete:
|
||||
|
||||
# Try to soft delete organization (should fail)
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="does not have a deleted_at column"):
|
||||
with pytest.raises(InvalidInputError, match="does not have a deleted_at column"):
|
||||
await org_crud.soft_delete(session, id=str(org_id))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -919,7 +924,7 @@ class TestCRUDBaseModelsWithoutSoftDelete:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create an organization (which doesn't have deleted_at)
|
||||
from app.crud.organization import organization as org_crud
|
||||
from app.repositories.organization import organization_repo as org_crud
|
||||
from app.models.organization import Organization
|
||||
|
||||
async with SessionLocal() as session:
|
||||
@@ -930,7 +935,7 @@ class TestCRUDBaseModelsWithoutSoftDelete:
|
||||
|
||||
# Try to restore organization (should fail)
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="does not have a deleted_at column"):
|
||||
with pytest.raises(InvalidInputError, match="does not have a deleted_at column"):
|
||||
await org_crud.restore(session, id=str(org_id))
|
||||
|
||||
|
||||
@@ -950,7 +955,7 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create a session for the user
|
||||
from app.crud.session import session as session_crud
|
||||
from app.repositories.session import session_repo as session_crud
|
||||
from app.models.user_session import UserSession
|
||||
|
||||
async with SessionLocal() as session:
|
||||
@@ -989,7 +994,7 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create multiple sessions for the user
|
||||
from app.crud.session import session as session_crud
|
||||
from app.repositories.session import session_repo as session_crud
|
||||
from app.models.user_session import UserSession
|
||||
|
||||
async with SessionLocal() as session:
|
||||
@@ -10,7 +10,8 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from sqlalchemy.exc import DataError, OperationalError
|
||||
|
||||
from app.crud.user import user as user_crud
|
||||
from app.core.repository_exceptions import IntegrityConstraintError
|
||||
from app.repositories.user import user_repo as user_crud
|
||||
from app.schemas.users import UserCreate
|
||||
|
||||
|
||||
@@ -119,7 +120,7 @@ class TestBaseCRUDUpdateFailures:
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(ValueError, match="Database operation failed"):
|
||||
with pytest.raises(IntegrityConstraintError, match="Database operation failed"):
|
||||
await user_crud.update(
|
||||
session, db_obj=user, obj_in={"first_name": "Updated"}
|
||||
)
|
||||
@@ -141,7 +142,7 @@ class TestBaseCRUDUpdateFailures:
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(ValueError, match="Database operation failed"):
|
||||
with pytest.raises(IntegrityConstraintError, match="Database operation failed"):
|
||||
await user_crud.update(
|
||||
session, db_obj=user, obj_in={"first_name": "Updated"}
|
||||
)
|
||||
@@ -7,7 +7,10 @@ from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from app.crud.oauth import oauth_account, oauth_client, oauth_state
|
||||
from app.core.repository_exceptions import DuplicateEntryError
|
||||
from app.repositories.oauth_account import oauth_account_repo as oauth_account
|
||||
from app.repositories.oauth_client import oauth_client_repo as oauth_client
|
||||
from app.repositories.oauth_state import oauth_state_repo as oauth_state
|
||||
from app.schemas.oauth import OAuthAccountCreate, OAuthClientCreate, OAuthStateCreate
|
||||
|
||||
|
||||
@@ -60,7 +63,7 @@ class TestOAuthAccountCRUD:
|
||||
|
||||
# SQLite returns different error message than PostgreSQL
|
||||
with pytest.raises(
|
||||
ValueError, match="(already linked|UNIQUE constraint failed)"
|
||||
DuplicateEntryError, match="(already linked|UNIQUE constraint failed|Failed to create)"
|
||||
):
|
||||
await oauth_account.create_account(session, obj_in=account_data2)
|
||||
|
||||
@@ -256,13 +259,13 @@ class TestOAuthAccountCRUD:
|
||||
updated = await oauth_account.update_tokens(
|
||||
session,
|
||||
account=account,
|
||||
access_token_encrypted="new_access_token",
|
||||
refresh_token_encrypted="new_refresh_token",
|
||||
access_token="new_access_token",
|
||||
refresh_token="new_refresh_token",
|
||||
token_expires_at=new_expires,
|
||||
)
|
||||
|
||||
assert updated.access_token_encrypted == "new_access_token"
|
||||
assert updated.refresh_token_encrypted == "new_refresh_token"
|
||||
assert updated.access_token == "new_access_token"
|
||||
assert updated.refresh_token == "new_refresh_token"
|
||||
|
||||
|
||||
class TestOAuthStateCRUD:
|
||||
@@ -9,7 +9,8 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.crud.organization import organization as organization_crud
|
||||
from app.core.repository_exceptions import DuplicateEntryError, IntegrityConstraintError
|
||||
from app.repositories.organization import organization_repo as organization_crud
|
||||
from app.models.organization import Organization
|
||||
from app.models.user_organization import OrganizationRole, UserOrganization
|
||||
from app.schemas.organizations import OrganizationCreate
|
||||
@@ -87,7 +88,7 @@ class TestCreate:
|
||||
# Try to create second with same slug
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org_in = OrganizationCreate(name="Org 2", slug="duplicate-slug")
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
with pytest.raises(DuplicateEntryError, match="already exists"):
|
||||
await organization_crud.create(session, obj_in=org_in)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -295,7 +296,7 @@ class TestAddUser:
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="already a member"):
|
||||
with pytest.raises(DuplicateEntryError, match="already a member"):
|
||||
await organization_crud.add_user(
|
||||
session, organization_id=org_id, user_id=async_test_user.id
|
||||
)
|
||||
@@ -972,7 +973,7 @@ class TestOrganizationExceptionHandlers:
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(session, "rollback", new_callable=AsyncMock):
|
||||
org_in = OrganizationCreate(name="Test", slug="test")
|
||||
with pytest.raises(ValueError, match="Database integrity error"):
|
||||
with pytest.raises(IntegrityConstraintError, match="Database integrity error"):
|
||||
await organization_crud.create(session, obj_in=org_in)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -1058,7 +1059,7 @@ class TestOrganizationExceptionHandlers:
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(session, "rollback", new_callable=AsyncMock):
|
||||
with pytest.raises(
|
||||
ValueError, match="Failed to add user to organization"
|
||||
IntegrityConstraintError, match="Failed to add user to organization"
|
||||
):
|
||||
await organization_crud.add_user(
|
||||
session,
|
||||
@@ -8,7 +8,8 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.crud.session import session as session_crud
|
||||
from app.core.repository_exceptions import InvalidInputError
|
||||
from app.repositories.session import session_repo as session_crud
|
||||
from app.models.user_session import UserSession
|
||||
from app.schemas.sessions import SessionCreate
|
||||
|
||||
@@ -503,7 +504,7 @@ class TestCleanupExpiredForUser:
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="Invalid user ID format"):
|
||||
with pytest.raises(InvalidInputError, match="Invalid user ID format"):
|
||||
await session_crud.cleanup_expired_for_user(
|
||||
session, user_id="not-a-valid-uuid"
|
||||
)
|
||||
@@ -10,7 +10,8 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
from app.crud.session import session as session_crud
|
||||
from app.core.repository_exceptions import IntegrityConstraintError
|
||||
from app.repositories.session import session_repo as session_crud
|
||||
from app.models.user_session import UserSession
|
||||
from app.schemas.sessions import SessionCreate
|
||||
|
||||
@@ -102,7 +103,7 @@ class TestSessionCRUDCreateSessionFailures:
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Failed to create session"):
|
||||
with pytest.raises(IntegrityConstraintError, match="Failed to create session"):
|
||||
await session_crud.create_session(session, obj_in=session_data)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
@@ -133,7 +134,7 @@ class TestSessionCRUDCreateSessionFailures:
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Failed to create session"):
|
||||
with pytest.raises(IntegrityConstraintError, match="Failed to create session"):
|
||||
await session_crud.create_session(session, obj_in=session_data)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
@@ -5,7 +5,8 @@ Comprehensive tests for async user CRUD operations.
|
||||
|
||||
import pytest
|
||||
|
||||
from app.crud.user import user as user_crud
|
||||
from app.core.repository_exceptions import DuplicateEntryError, InvalidInputError
|
||||
from app.repositories.user import user_repo as user_crud
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
|
||||
|
||||
@@ -93,7 +94,7 @@ class TestCreate:
|
||||
last_name="User",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
with pytest.raises(DuplicateEntryError) as exc_info:
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
assert "already exists" in str(exc_info.value).lower()
|
||||
@@ -330,7 +331,7 @@ class TestGetMultiWithTotal:
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
with pytest.raises(InvalidInputError) as exc_info:
|
||||
await user_crud.get_multi_with_total(session, skip=-1, limit=10)
|
||||
|
||||
assert "skip must be non-negative" in str(exc_info.value)
|
||||
@@ -341,7 +342,7 @@ class TestGetMultiWithTotal:
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
with pytest.raises(InvalidInputError) as exc_info:
|
||||
await user_crud.get_multi_with_total(session, skip=0, limit=-1)
|
||||
|
||||
assert "limit must be non-negative" in str(exc_info.value)
|
||||
@@ -352,7 +353,7 @@ class TestGetMultiWithTotal:
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
with pytest.raises(InvalidInputError) as exc_info:
|
||||
await user_crud.get_multi_with_total(session, skip=0, limit=1001)
|
||||
|
||||
assert "Maximum limit is 1000" in str(exc_info.value)
|
||||
@@ -10,6 +10,7 @@ from app.core.auth import (
|
||||
get_password_hash,
|
||||
verify_password,
|
||||
)
|
||||
from app.core.exceptions import DuplicateError
|
||||
from app.models.user import User
|
||||
from app.schemas.users import Token, UserCreate
|
||||
from app.services.auth_service import AuthenticationError, AuthService
|
||||
@@ -152,9 +153,9 @@ class TestAuthServiceUserCreation:
|
||||
last_name="User",
|
||||
)
|
||||
|
||||
# Should raise AuthenticationError
|
||||
# Should raise DuplicateError for duplicate email
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(AuthenticationError):
|
||||
with pytest.raises(DuplicateError):
|
||||
await AuthService.create_user(db=session, user_data=user_data)
|
||||
|
||||
|
||||
|
||||
@@ -269,18 +269,18 @@ class TestClientValidation:
|
||||
async def test_validate_client_legacy_sha256_hash(
|
||||
self, db, confidential_client_legacy_hash
|
||||
):
|
||||
"""Test validating a client with legacy SHA-256 hash (backward compatibility)."""
|
||||
"""Test that legacy SHA-256 hash is rejected with clear error message."""
|
||||
client, secret = confidential_client_legacy_hash
|
||||
validated = await service.validate_client(db, client.client_id, secret)
|
||||
assert validated.client_id == client.client_id
|
||||
with pytest.raises(service.InvalidClientError, match="deprecated hash format"):
|
||||
await service.validate_client(db, client.client_id, secret)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_client_legacy_sha256_wrong_secret(
|
||||
self, db, confidential_client_legacy_hash
|
||||
):
|
||||
"""Test legacy SHA-256 client rejects wrong secret."""
|
||||
"""Test that legacy SHA-256 client with wrong secret is rejected."""
|
||||
client, _ = confidential_client_legacy_hash
|
||||
with pytest.raises(service.InvalidClientError, match="Invalid client secret"):
|
||||
with pytest.raises(service.InvalidClientError, match="deprecated hash format"):
|
||||
await service.validate_client(db, client.client_id, "wrong_secret")
|
||||
|
||||
def test_validate_redirect_uri_success(self, public_client):
|
||||
|
||||
@@ -11,7 +11,8 @@ from uuid import uuid4
|
||||
import pytest
|
||||
|
||||
from app.core.exceptions import AuthenticationError
|
||||
from app.crud.oauth import oauth_account, oauth_state
|
||||
from app.repositories.oauth_account import oauth_account_repo as oauth_account
|
||||
from app.repositories.oauth_state import oauth_state_repo as oauth_state
|
||||
from app.schemas.oauth import OAuthAccountCreate, OAuthStateCreate
|
||||
from app.services.oauth_service import OAUTH_PROVIDERS, OAuthService
|
||||
|
||||
|
||||
447
backend/tests/services/test_organization_service.py
Normal file
447
backend/tests/services/test_organization_service.py
Normal file
@@ -0,0 +1,447 @@
|
||||
# tests/services/test_organization_service.py
|
||||
"""Tests for the OrganizationService class."""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from app.core.exceptions import NotFoundError
|
||||
from app.models.user_organization import OrganizationRole
|
||||
from app.schemas.organizations import OrganizationCreate, OrganizationUpdate
|
||||
from app.services.organization_service import OrganizationService, organization_service
|
||||
|
||||
|
||||
def _make_org_create(name=None, slug=None) -> OrganizationCreate:
|
||||
"""Helper to create an OrganizationCreate schema with unique defaults."""
|
||||
unique = uuid.uuid4().hex[:8]
|
||||
return OrganizationCreate(
|
||||
name=name or f"Test Org {unique}",
|
||||
slug=slug or f"test-org-{unique}",
|
||||
description="A test organization",
|
||||
is_active=True,
|
||||
settings={},
|
||||
)
|
||||
|
||||
|
||||
class TestGetOrganization:
|
||||
"""Tests for OrganizationService.get_organization method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_organization_found(self, async_test_db, async_test_user):
|
||||
"""Test getting an existing organization by ID returns the org."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
created = await organization_service.create_organization(
|
||||
session, obj_in=_make_org_create()
|
||||
)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await organization_service.get_organization(
|
||||
session, str(created.id)
|
||||
)
|
||||
assert result is not None
|
||||
assert result.id == created.id
|
||||
assert result.slug == created.slug
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_organization_not_found(self, async_test_db):
|
||||
"""Test getting a non-existent organization raises NotFoundError."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(NotFoundError):
|
||||
await organization_service.get_organization(
|
||||
session, str(uuid.uuid4())
|
||||
)
|
||||
|
||||
|
||||
class TestCreateOrganization:
|
||||
"""Tests for OrganizationService.create_organization method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_organization(self, async_test_db, async_test_user):
|
||||
"""Test creating a new organization returns the created org with correct fields."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
obj_in = _make_org_create()
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await organization_service.create_organization(
|
||||
session, obj_in=obj_in
|
||||
)
|
||||
assert result is not None
|
||||
assert result.name == obj_in.name
|
||||
assert result.slug == obj_in.slug
|
||||
assert result.description == obj_in.description
|
||||
assert result.is_active is True
|
||||
|
||||
|
||||
class TestUpdateOrganization:
|
||||
"""Tests for OrganizationService.update_organization method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_organization(self, async_test_db, async_test_user):
|
||||
"""Test updating an organization name."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
created = await organization_service.create_organization(
|
||||
session, obj_in=_make_org_create()
|
||||
)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = await organization_service.get_organization(session, str(created.id))
|
||||
updated = await organization_service.update_organization(
|
||||
session,
|
||||
org=org,
|
||||
obj_in=OrganizationUpdate(name="Updated Org Name"),
|
||||
)
|
||||
assert updated.name == "Updated Org Name"
|
||||
assert updated.id == created.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_organization_with_dict(self, async_test_db, async_test_user):
|
||||
"""Test updating an organization using a dict."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
created = await organization_service.create_organization(
|
||||
session, obj_in=_make_org_create()
|
||||
)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = await organization_service.get_organization(session, str(created.id))
|
||||
updated = await organization_service.update_organization(
|
||||
session,
|
||||
org=org,
|
||||
obj_in={"description": "Updated description"},
|
||||
)
|
||||
assert updated.description == "Updated description"
|
||||
|
||||
|
||||
class TestRemoveOrganization:
|
||||
"""Tests for OrganizationService.remove_organization method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_organization(self, async_test_db, async_test_user):
|
||||
"""Test permanently deleting an organization."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
created = await organization_service.create_organization(
|
||||
session, obj_in=_make_org_create()
|
||||
)
|
||||
org_id = str(created.id)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
await organization_service.remove_organization(session, org_id)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(NotFoundError):
|
||||
await organization_service.get_organization(session, org_id)
|
||||
|
||||
|
||||
class TestGetMemberCount:
|
||||
"""Tests for OrganizationService.get_member_count method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_member_count_empty(self, async_test_db, async_test_user):
|
||||
"""Test member count for org with no members is zero."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
created = await organization_service.create_organization(
|
||||
session, obj_in=_make_org_create()
|
||||
)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await organization_service.get_member_count(
|
||||
session, organization_id=created.id
|
||||
)
|
||||
assert count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_member_count_with_member(self, async_test_db, async_test_user):
|
||||
"""Test member count increases after adding a member."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
created = await organization_service.create_organization(
|
||||
session, obj_in=_make_org_create()
|
||||
)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
await organization_service.add_member(
|
||||
session,
|
||||
organization_id=created.id,
|
||||
user_id=async_test_user.id,
|
||||
)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await organization_service.get_member_count(
|
||||
session, organization_id=created.id
|
||||
)
|
||||
assert count == 1
|
||||
|
||||
|
||||
class TestGetMultiWithMemberCounts:
|
||||
"""Tests for OrganizationService.get_multi_with_member_counts method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_member_counts(self, async_test_db, async_test_user):
|
||||
"""Test listing organizations with member counts returns tuple."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
await organization_service.create_organization(
|
||||
session, obj_in=_make_org_create()
|
||||
)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
orgs, count = await organization_service.get_multi_with_member_counts(
|
||||
session, skip=0, limit=10
|
||||
)
|
||||
assert isinstance(orgs, list)
|
||||
assert isinstance(count, int)
|
||||
assert count >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_member_counts_search(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test listing organizations with a search filter."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
unique = uuid.uuid4().hex[:8]
|
||||
org_name = f"Searchable Org {unique}"
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
await organization_service.create_organization(
|
||||
session,
|
||||
obj_in=OrganizationCreate(
|
||||
name=org_name,
|
||||
slug=f"searchable-org-{unique}",
|
||||
is_active=True,
|
||||
settings={},
|
||||
),
|
||||
)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
orgs, count = await organization_service.get_multi_with_member_counts(
|
||||
session, skip=0, limit=10, search=f"Searchable Org {unique}"
|
||||
)
|
||||
assert count >= 1
|
||||
# Each element is a dict with key "organization" (an Organization obj) and "member_count"
|
||||
names = [o["organization"].name for o in orgs]
|
||||
assert org_name in names
|
||||
|
||||
|
||||
class TestGetUserOrganizationsWithDetails:
|
||||
"""Tests for OrganizationService.get_user_organizations_with_details method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_organizations_with_details(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test getting organizations for a user returns list of dicts."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
created = await organization_service.create_organization(
|
||||
session, obj_in=_make_org_create()
|
||||
)
|
||||
await organization_service.add_member(
|
||||
session,
|
||||
organization_id=created.id,
|
||||
user_id=async_test_user.id,
|
||||
)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
orgs = await organization_service.get_user_organizations_with_details(
|
||||
session, user_id=async_test_user.id
|
||||
)
|
||||
assert isinstance(orgs, list)
|
||||
assert len(orgs) >= 1
|
||||
|
||||
|
||||
class TestGetOrganizationMembers:
|
||||
"""Tests for OrganizationService.get_organization_members method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_organization_members(self, async_test_db, async_test_user):
|
||||
"""Test getting organization members returns paginated results."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
created = await organization_service.create_organization(
|
||||
session, obj_in=_make_org_create()
|
||||
)
|
||||
await organization_service.add_member(
|
||||
session,
|
||||
organization_id=created.id,
|
||||
user_id=async_test_user.id,
|
||||
)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
members, count = await organization_service.get_organization_members(
|
||||
session, organization_id=created.id, skip=0, limit=10
|
||||
)
|
||||
assert isinstance(members, list)
|
||||
assert isinstance(count, int)
|
||||
assert count >= 1
|
||||
|
||||
|
||||
class TestAddMember:
|
||||
"""Tests for OrganizationService.add_member method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_member_default_role(self, async_test_db, async_test_user):
|
||||
"""Test adding a user to an org with default MEMBER role."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
created = await organization_service.create_organization(
|
||||
session, obj_in=_make_org_create()
|
||||
)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
membership = await organization_service.add_member(
|
||||
session,
|
||||
organization_id=created.id,
|
||||
user_id=async_test_user.id,
|
||||
)
|
||||
assert membership is not None
|
||||
assert membership.user_id == async_test_user.id
|
||||
assert membership.organization_id == created.id
|
||||
assert membership.role == OrganizationRole.MEMBER
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_member_admin_role(self, async_test_db, async_test_user):
|
||||
"""Test adding a user to an org with ADMIN role."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
created = await organization_service.create_organization(
|
||||
session, obj_in=_make_org_create()
|
||||
)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
membership = await organization_service.add_member(
|
||||
session,
|
||||
organization_id=created.id,
|
||||
user_id=async_test_user.id,
|
||||
role=OrganizationRole.ADMIN,
|
||||
)
|
||||
assert membership.role == OrganizationRole.ADMIN
|
||||
|
||||
|
||||
class TestRemoveMember:
|
||||
"""Tests for OrganizationService.remove_member method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_member(self, async_test_db, async_test_user):
|
||||
"""Test removing a member from an org returns True."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
created = await organization_service.create_organization(
|
||||
session, obj_in=_make_org_create()
|
||||
)
|
||||
await organization_service.add_member(
|
||||
session,
|
||||
organization_id=created.id,
|
||||
user_id=async_test_user.id,
|
||||
)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
removed = await organization_service.remove_member(
|
||||
session,
|
||||
organization_id=created.id,
|
||||
user_id=async_test_user.id,
|
||||
)
|
||||
assert removed is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_member_not_found(self, async_test_db, async_test_user):
|
||||
"""Test removing a non-member returns False."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
created = await organization_service.create_organization(
|
||||
session, obj_in=_make_org_create()
|
||||
)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
removed = await organization_service.remove_member(
|
||||
session,
|
||||
organization_id=created.id,
|
||||
user_id=async_test_user.id,
|
||||
)
|
||||
assert removed is False
|
||||
|
||||
|
||||
class TestGetUserRoleInOrg:
|
||||
"""Tests for OrganizationService.get_user_role_in_org method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_role_in_org(self, async_test_db, async_test_user):
|
||||
"""Test getting a user's role in an org they belong to."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
created = await organization_service.create_organization(
|
||||
session, obj_in=_make_org_create()
|
||||
)
|
||||
await organization_service.add_member(
|
||||
session,
|
||||
organization_id=created.id,
|
||||
user_id=async_test_user.id,
|
||||
role=OrganizationRole.MEMBER,
|
||||
)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
role = await organization_service.get_user_role_in_org(
|
||||
session,
|
||||
user_id=async_test_user.id,
|
||||
organization_id=created.id,
|
||||
)
|
||||
assert role == OrganizationRole.MEMBER
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_role_in_org_not_member(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test getting role for a user not in the org returns None."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
created = await organization_service.create_organization(
|
||||
session, obj_in=_make_org_create()
|
||||
)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
role = await organization_service.get_user_role_in_org(
|
||||
session,
|
||||
user_id=async_test_user.id,
|
||||
organization_id=created.id,
|
||||
)
|
||||
assert role is None
|
||||
|
||||
|
||||
class TestGetOrgDistribution:
|
||||
"""Tests for OrganizationService.get_org_distribution method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_distribution_empty(self, async_test_db):
|
||||
"""Test org distribution with no memberships returns empty list."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await organization_service.get_org_distribution(session, limit=6)
|
||||
assert isinstance(result, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_distribution_with_members(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test org distribution returns org name and member count."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
created = await organization_service.create_organization(
|
||||
session, obj_in=_make_org_create()
|
||||
)
|
||||
await organization_service.add_member(
|
||||
session,
|
||||
organization_id=created.id,
|
||||
user_id=async_test_user.id,
|
||||
)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await organization_service.get_org_distribution(session, limit=6)
|
||||
assert isinstance(result, list)
|
||||
assert len(result) >= 1
|
||||
entry = result[0]
|
||||
assert "name" in entry
|
||||
assert "value" in entry
|
||||
assert entry["value"] >= 1
|
||||
292
backend/tests/services/test_session_service.py
Normal file
292
backend/tests/services/test_session_service.py
Normal file
@@ -0,0 +1,292 @@
|
||||
# tests/services/test_session_service.py
|
||||
"""Tests for the SessionService class."""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from app.schemas.sessions import SessionCreate
|
||||
from app.services.session_service import SessionService, session_service
|
||||
|
||||
|
||||
def _make_session_create(user_id, jti=None) -> SessionCreate:
|
||||
"""Helper to build a SessionCreate with sensible defaults."""
|
||||
now = datetime.now(UTC)
|
||||
return SessionCreate(
|
||||
user_id=user_id,
|
||||
refresh_token_jti=jti or str(uuid.uuid4()),
|
||||
ip_address="127.0.0.1",
|
||||
user_agent="pytest/test",
|
||||
device_name="Test Device",
|
||||
device_id="test-device-id",
|
||||
last_used_at=now,
|
||||
expires_at=now + timedelta(days=7),
|
||||
location_city="TestCity",
|
||||
location_country="TestCountry",
|
||||
)
|
||||
|
||||
|
||||
class TestCreateSession:
|
||||
"""Tests for SessionService.create_session method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session(self, async_test_db, async_test_user):
|
||||
"""Test creating a session returns a UserSession with correct fields."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
obj_in = _make_session_create(async_test_user.id)
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_service.create_session(session, obj_in=obj_in)
|
||||
assert result is not None
|
||||
assert result.user_id == async_test_user.id
|
||||
assert result.refresh_token_jti == obj_in.refresh_token_jti
|
||||
assert result.is_active is True
|
||||
assert result.ip_address == "127.0.0.1"
|
||||
|
||||
|
||||
class TestGetSession:
|
||||
"""Tests for SessionService.get_session method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_session_found(self, async_test_db, async_test_user):
|
||||
"""Test getting a session by ID returns the session."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
obj_in = _make_session_create(async_test_user.id)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
created = await session_service.create_session(session, obj_in=obj_in)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_service.get_session(session, str(created.id))
|
||||
assert result is not None
|
||||
assert result.id == created.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_session_not_found(self, async_test_db):
|
||||
"""Test getting a non-existent session returns None."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_service.get_session(session, str(uuid.uuid4()))
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGetUserSessions:
|
||||
"""Tests for SessionService.get_user_sessions method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_sessions_active_only(self, async_test_db, async_test_user):
|
||||
"""Test getting active sessions for a user returns only active sessions."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
obj_in = _make_session_create(async_test_user.id)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
await session_service.create_session(session, obj_in=obj_in)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
sessions = await session_service.get_user_sessions(
|
||||
session, user_id=str(async_test_user.id), active_only=True
|
||||
)
|
||||
assert isinstance(sessions, list)
|
||||
assert len(sessions) >= 1
|
||||
for s in sessions:
|
||||
assert s.is_active is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_sessions_all(self, async_test_db, async_test_user):
|
||||
"""Test getting all sessions (active and inactive) for a user."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
obj_in = _make_session_create(async_test_user.id)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
created = await session_service.create_session(session, obj_in=obj_in)
|
||||
await session_service.deactivate(session, session_id=str(created.id))
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
sessions = await session_service.get_user_sessions(
|
||||
session, user_id=str(async_test_user.id), active_only=False
|
||||
)
|
||||
assert isinstance(sessions, list)
|
||||
assert len(sessions) >= 1
|
||||
|
||||
|
||||
class TestGetActiveByJti:
|
||||
"""Tests for SessionService.get_active_by_jti method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_by_jti_found(self, async_test_db, async_test_user):
|
||||
"""Test getting an active session by JTI returns the session."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
jti = str(uuid.uuid4())
|
||||
obj_in = _make_session_create(async_test_user.id, jti=jti)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
await session_service.create_session(session, obj_in=obj_in)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_service.get_active_by_jti(session, jti=jti)
|
||||
assert result is not None
|
||||
assert result.refresh_token_jti == jti
|
||||
assert result.is_active is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_by_jti_not_found(self, async_test_db):
|
||||
"""Test getting an active session by non-existent JTI returns None."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_service.get_active_by_jti(
|
||||
session, jti=str(uuid.uuid4())
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGetByJti:
|
||||
"""Tests for SessionService.get_by_jti method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_jti_active(self, async_test_db, async_test_user):
|
||||
"""Test getting a session (active or inactive) by JTI."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
jti = str(uuid.uuid4())
|
||||
obj_in = _make_session_create(async_test_user.id, jti=jti)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
await session_service.create_session(session, obj_in=obj_in)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_service.get_by_jti(session, jti=jti)
|
||||
assert result is not None
|
||||
assert result.refresh_token_jti == jti
|
||||
|
||||
|
||||
class TestDeactivate:
|
||||
"""Tests for SessionService.deactivate method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deactivate_session(self, async_test_db, async_test_user):
|
||||
"""Test deactivating a session sets is_active to False."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
obj_in = _make_session_create(async_test_user.id)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
created = await session_service.create_session(session, obj_in=obj_in)
|
||||
session_id = str(created.id)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
deactivated = await session_service.deactivate(
|
||||
session, session_id=session_id
|
||||
)
|
||||
assert deactivated is not None
|
||||
assert deactivated.is_active is False
|
||||
|
||||
|
||||
class TestDeactivateAllUserSessions:
|
||||
"""Tests for SessionService.deactivate_all_user_sessions method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deactivate_all_user_sessions(self, async_test_db, async_test_user):
|
||||
"""Test deactivating all sessions for a user returns count deactivated."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
await session_service.create_session(
|
||||
session, obj_in=_make_session_create(async_test_user.id)
|
||||
)
|
||||
await session_service.create_session(
|
||||
session, obj_in=_make_session_create(async_test_user.id)
|
||||
)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await session_service.deactivate_all_user_sessions(
|
||||
session, user_id=str(async_test_user.id)
|
||||
)
|
||||
assert count >= 2
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
active_sessions = await session_service.get_user_sessions(
|
||||
session, user_id=str(async_test_user.id), active_only=True
|
||||
)
|
||||
assert len(active_sessions) == 0
|
||||
|
||||
|
||||
class TestUpdateRefreshToken:
|
||||
"""Tests for SessionService.update_refresh_token method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_refresh_token(self, async_test_db, async_test_user):
|
||||
"""Test rotating a session's refresh token updates JTI and expiry."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
obj_in = _make_session_create(async_test_user.id)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
created = await session_service.create_session(session, obj_in=obj_in)
|
||||
session_id = str(created.id)
|
||||
|
||||
new_jti = str(uuid.uuid4())
|
||||
new_expires_at = datetime.now(UTC) + timedelta(days=14)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_service.get_session(session, session_id)
|
||||
updated = await session_service.update_refresh_token(
|
||||
session,
|
||||
session=result,
|
||||
new_jti=new_jti,
|
||||
new_expires_at=new_expires_at,
|
||||
)
|
||||
assert updated.refresh_token_jti == new_jti
|
||||
|
||||
|
||||
class TestCleanupExpiredForUser:
|
||||
"""Tests for SessionService.cleanup_expired_for_user method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_for_user(self, async_test_db, async_test_user):
|
||||
"""Test cleaning up expired inactive sessions returns count removed."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
now = datetime.now(UTC)
|
||||
# Create a session that is already expired
|
||||
obj_in = SessionCreate(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=str(uuid.uuid4()),
|
||||
ip_address="127.0.0.1",
|
||||
user_agent="pytest/test",
|
||||
last_used_at=now - timedelta(days=8),
|
||||
expires_at=now - timedelta(days=1),
|
||||
)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
created = await session_service.create_session(session, obj_in=obj_in)
|
||||
session_id = str(created.id)
|
||||
|
||||
# Deactivate it so it qualifies for cleanup (requires is_active=False AND expired)
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
await session_service.deactivate(session, session_id=session_id)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await session_service.cleanup_expired_for_user(
|
||||
session, user_id=str(async_test_user.id)
|
||||
)
|
||||
assert isinstance(count, int)
|
||||
assert count >= 1
|
||||
|
||||
|
||||
class TestGetAllSessions:
|
||||
"""Tests for SessionService.get_all_sessions method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_sessions(self, async_test_db, async_test_user):
|
||||
"""Test getting all sessions with pagination returns tuple of list and count."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
obj_in = _make_session_create(async_test_user.id)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
await session_service.create_session(session, obj_in=obj_in)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
sessions, count = await session_service.get_all_sessions(
|
||||
session, skip=0, limit=10, active_only=True, with_user=False
|
||||
)
|
||||
assert isinstance(sessions, list)
|
||||
assert isinstance(count, int)
|
||||
assert count >= 1
|
||||
assert len(sessions) >= 1
|
||||
214
backend/tests/services/test_user_service.py
Normal file
214
backend/tests/services/test_user_service.py
Normal file
@@ -0,0 +1,214 @@
|
||||
# tests/services/test_user_service.py
|
||||
"""Tests for the UserService class."""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.exceptions import NotFoundError
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
from app.services.user_service import UserService, user_service
|
||||
|
||||
|
||||
class TestGetUser:
|
||||
"""Tests for UserService.get_user method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_found(self, async_test_db, async_test_user):
|
||||
"""Test getting an existing user by ID returns the user."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await user_service.get_user(session, str(async_test_user.id))
|
||||
assert result is not None
|
||||
assert result.id == async_test_user.id
|
||||
assert result.email == async_test_user.email
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_not_found(self, async_test_db):
|
||||
"""Test getting a non-existent user raises NotFoundError."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
non_existent_id = str(uuid.uuid4())
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(NotFoundError):
|
||||
await user_service.get_user(session, non_existent_id)
|
||||
|
||||
|
||||
class TestGetByEmail:
|
||||
"""Tests for UserService.get_by_email method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_email_found(self, async_test_db, async_test_user):
|
||||
"""Test getting an existing user by email returns the user."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await user_service.get_by_email(session, async_test_user.email)
|
||||
assert result is not None
|
||||
assert result.id == async_test_user.id
|
||||
assert result.email == async_test_user.email
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_email_not_found(self, async_test_db):
|
||||
"""Test getting a user by non-existent email returns None."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await user_service.get_by_email(session, "nonexistent@example.com")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestCreateUser:
|
||||
"""Tests for UserService.create_user method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user(self, async_test_db):
|
||||
"""Test creating a new user with valid data."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
unique_email = f"test_{uuid.uuid4()}@example.com"
|
||||
user_data = UserCreate(
|
||||
email=unique_email,
|
||||
password="TestPassword123!",
|
||||
first_name="New",
|
||||
last_name="User",
|
||||
)
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await user_service.create_user(session, user_data)
|
||||
assert result is not None
|
||||
assert result.email == unique_email
|
||||
assert result.first_name == "New"
|
||||
assert result.last_name == "User"
|
||||
assert result.is_active is True
|
||||
|
||||
|
||||
class TestUpdateUser:
|
||||
"""Tests for UserService.update_user method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user(self, async_test_db, async_test_user):
|
||||
"""Test updating a user's first_name."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await user_service.get_user(session, str(async_test_user.id))
|
||||
updated = await user_service.update_user(
|
||||
session,
|
||||
user=user,
|
||||
obj_in=UserUpdate(first_name="Updated"),
|
||||
)
|
||||
assert updated.first_name == "Updated"
|
||||
assert updated.id == async_test_user.id
|
||||
|
||||
|
||||
class TestSoftDeleteUser:
|
||||
"""Tests for UserService.soft_delete_user method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_soft_delete_user(self, async_test_db, async_test_user):
|
||||
"""Test soft-deleting a user sets deleted_at."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
await user_service.soft_delete_user(session, str(async_test_user.id))
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
assert user is not None
|
||||
assert user.deleted_at is not None
|
||||
|
||||
|
||||
class TestListUsers:
|
||||
"""Tests for UserService.list_users method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_users(self, async_test_db, async_test_user):
|
||||
"""Test listing users with pagination returns correct results."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
users, count = await user_service.list_users(session, skip=0, limit=10)
|
||||
assert isinstance(users, list)
|
||||
assert isinstance(count, int)
|
||||
assert count >= 1
|
||||
assert len(users) >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_users_with_search(self, async_test_db, async_test_user):
|
||||
"""Test listing users with email fragment search returns matching users."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
# Search by partial email fragment of the test user
|
||||
email_fragment = async_test_user.email.split("@")[0]
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
users, count = await user_service.list_users(
|
||||
session, skip=0, limit=10, search=email_fragment
|
||||
)
|
||||
assert isinstance(users, list)
|
||||
assert count >= 1
|
||||
emails = [u.email for u in users]
|
||||
assert async_test_user.email in emails
|
||||
|
||||
|
||||
class TestBulkUpdateStatus:
|
||||
"""Tests for UserService.bulk_update_status method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_update_status(self, async_test_db, async_test_user):
|
||||
"""Test bulk activating users returns correct count."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_service.bulk_update_status(
|
||||
session,
|
||||
user_ids=[async_test_user.id],
|
||||
is_active=True,
|
||||
)
|
||||
assert count >= 1
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
assert user is not None
|
||||
assert user.is_active is True
|
||||
|
||||
|
||||
class TestBulkSoftDelete:
|
||||
"""Tests for UserService.bulk_soft_delete method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_soft_delete(self, async_test_db, async_test_user):
|
||||
"""Test bulk soft-deleting users returns correct count."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_service.bulk_soft_delete(
|
||||
session,
|
||||
user_ids=[async_test_user.id],
|
||||
)
|
||||
assert count >= 1
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
assert user is not None
|
||||
assert user.deleted_at is not None
|
||||
|
||||
|
||||
class TestGetStats:
|
||||
"""Tests for UserService.get_stats method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_stats(self, async_test_db, async_test_user):
|
||||
"""Test get_stats returns dict with expected keys and correct counts."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
stats = await user_service.get_stats(session)
|
||||
assert "total_users" in stats
|
||||
assert "active_count" in stats
|
||||
assert "inactive_count" in stats
|
||||
assert "all_users" in stats
|
||||
assert stats["total_users"] >= 1
|
||||
assert stats["active_count"] >= 1
|
||||
assert isinstance(stats["all_users"], list)
|
||||
assert len(stats["all_users"]) >= 1
|
||||
Reference in New Issue
Block a user