Add pyproject.toml for consolidated project configuration and replace Black, isort, and Flake8 with Ruff

- Introduced `pyproject.toml` to centralize backend tool configurations (e.g., Ruff, mypy, coverage, pytest).
- Replaced Black, isort, and Flake8 with Ruff for linting, formatting, and import sorting.
- Updated `requirements.txt` to include Ruff and remove replaced tools.
- Added `Makefile` to streamline development workflows with commands for linting, formatting, type-checking, testing, and cleanup.
This commit is contained in:
2025-11-10 11:55:15 +01:00
parent a5c671c133
commit c589b565f0
86 changed files with 4572 additions and 3956 deletions

View File

@@ -1,15 +1,16 @@
# tests/api/dependencies/test_auth_dependencies.py
import pytest
import pytest_asyncio
import uuid
from unittest.mock import patch
import pytest
import pytest_asyncio
from fastapi import HTTPException
from app.api.dependencies.auth import (
get_current_user,
get_current_active_user,
get_current_superuser,
get_optional_current_user
get_current_user,
get_optional_current_user,
)
from app.core.auth import TokenExpiredError, TokenInvalidError, get_password_hash
from app.models.user import User
@@ -24,7 +25,7 @@ def mock_token():
@pytest_asyncio.fixture
async def async_mock_user(async_test_db):
"""Async fixture to create and return a mock User instance."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
mock_user = User(
id=uuid.uuid4(),
@@ -47,12 +48,14 @@ class TestGetCurrentUser:
"""Tests for get_current_user dependency"""
@pytest.mark.asyncio
async def test_get_current_user_success(self, async_test_db, async_mock_user, mock_token):
async def test_get_current_user_success(
self, async_test_db, async_mock_user, mock_token
):
"""Test successfully getting the current user"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to return user_id that matches our mock_user
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.return_value.user_id = async_mock_user.id
# Call the dependency
@@ -65,12 +68,12 @@ class TestGetCurrentUser:
@pytest.mark.asyncio
async def test_get_current_user_nonexistent(self, async_test_db, mock_token):
"""Test when the token contains a user ID that doesn't exist"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to return a non-existent user ID
nonexistent_id = uuid.UUID("11111111-1111-1111-1111-111111111111")
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.return_value.user_id = nonexistent_id
# Should raise HTTPException with 404 status
@@ -81,19 +84,24 @@ class TestGetCurrentUser:
assert "User not found" in exc_info.value.detail
@pytest.mark.asyncio
async def test_get_current_user_inactive(self, async_test_db, async_mock_user, mock_token):
async def test_get_current_user_inactive(
self, async_test_db, async_mock_user, mock_token
):
"""Test when the user is inactive"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Get the user in this session and make it inactive
from sqlalchemy import select
result = await session.execute(select(User).where(User.id == async_mock_user.id))
result = await session.execute(
select(User).where(User.id == async_mock_user.id)
)
user_in_session = result.scalar_one_or_none()
user_in_session.is_active = False
await session.commit()
# Mock get_token_data
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.return_value.user_id = async_mock_user.id
# Should raise HTTPException with 403 status
@@ -106,10 +114,10 @@ class TestGetCurrentUser:
@pytest.mark.asyncio
async def test_get_current_user_expired_token(self, async_test_db, mock_token):
"""Test with an expired token"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to raise TokenExpiredError
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.side_effect = TokenExpiredError("Token expired")
# Should raise HTTPException with 401 status
@@ -122,10 +130,10 @@ class TestGetCurrentUser:
@pytest.mark.asyncio
async def test_get_current_user_invalid_token(self, async_test_db, mock_token):
"""Test with an invalid token"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to raise TokenInvalidError
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.side_effect = TokenInvalidError("Invalid token")
# Should raise HTTPException with 401 status
@@ -194,12 +202,14 @@ class TestGetOptionalCurrentUser:
"""Tests for get_optional_current_user dependency"""
@pytest.mark.asyncio
async def test_get_optional_current_user_with_token(self, async_test_db, async_mock_user, mock_token):
async def test_get_optional_current_user_with_token(
self, async_test_db, async_mock_user, mock_token
):
"""Test getting optional user with a valid token"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock get_token_data
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.return_value.user_id = async_mock_user.id
# Call the dependency
@@ -212,7 +222,7 @@ class TestGetOptionalCurrentUser:
@pytest.mark.asyncio
async def test_get_optional_current_user_no_token(self, async_test_db):
"""Test getting optional user with no token"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Call the dependency with no token
user = await get_optional_current_user(db=session, token=None)
@@ -221,12 +231,14 @@ class TestGetOptionalCurrentUser:
assert user is None
@pytest.mark.asyncio
async def test_get_optional_current_user_invalid_token(self, async_test_db, mock_token):
async def test_get_optional_current_user_invalid_token(
self, async_test_db, mock_token
):
"""Test getting optional user with an invalid token"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to raise TokenInvalidError
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.side_effect = TokenInvalidError("Invalid token")
# Call the dependency
@@ -236,12 +248,14 @@ class TestGetOptionalCurrentUser:
assert user is None
@pytest.mark.asyncio
async def test_get_optional_current_user_expired_token(self, async_test_db, mock_token):
async def test_get_optional_current_user_expired_token(
self, async_test_db, mock_token
):
"""Test getting optional user with an expired token"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to raise TokenExpiredError
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.side_effect = TokenExpiredError("Token expired")
# Call the dependency
@@ -251,19 +265,24 @@ class TestGetOptionalCurrentUser:
assert user is None
@pytest.mark.asyncio
async def test_get_optional_current_user_inactive(self, async_test_db, async_mock_user, mock_token):
async def test_get_optional_current_user_inactive(
self, async_test_db, async_mock_user, mock_token
):
"""Test getting optional user when user is inactive"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Get the user in this session and make it inactive
from sqlalchemy import select
result = await session.execute(select(User).where(User.id == async_mock_user.id))
result = await session.execute(
select(User).where(User.id == async_mock_user.id)
)
user_in_session = result.scalar_one_or_none()
user_in_session.is_active = False
await session.commit()
# Mock get_token_data
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.return_value.user_id = async_mock_user.id
# Call the dependency

View File

@@ -1,13 +1,12 @@
# tests/api/routes/test_health.py
from datetime import datetime
from unittest.mock import patch
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from fastapi import status
from fastapi.testclient import TestClient
from datetime import datetime
from sqlalchemy.exc import OperationalError
from app.main import app
from app.core.database import get_db
@pytest.fixture
@@ -121,7 +120,10 @@ class TestHealthEndpoint:
response = client.get("/health")
# Should succeed without authentication
assert response.status_code in [status.HTTP_200_OK, status.HTTP_503_SERVICE_UNAVAILABLE]
assert response.status_code in [
status.HTTP_200_OK,
status.HTTP_503_SERVICE_UNAVAILABLE,
]
def test_health_check_idempotent(self, client):
"""Test that multiple health checks return consistent results"""
@@ -142,7 +144,10 @@ class TestHealthEndpoint:
assert data1["environment"] == data2["environment"]
# Same database check status
assert data1["checks"]["database"]["status"] == data2["checks"]["database"]["status"]
assert (
data1["checks"]["database"]["status"]
== data2["checks"]["database"]["status"]
)
def test_health_check_content_type(self, client):
"""Test that health check returns JSON content type"""

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -2,6 +2,7 @@
"""
Tests for authentication endpoints.
"""
import pytest
import pytest_asyncio
from fastapi import status
@@ -19,8 +20,8 @@ class TestRegisterEndpoint:
"email": "newuser@example.com",
"password": "NewPassword123!",
"first_name": "New",
"last_name": "User"
}
"last_name": "User",
},
)
assert response.status_code == status.HTTP_201_CREATED
@@ -36,8 +37,8 @@ class TestRegisterEndpoint:
"email": async_test_user.email,
"password": "TestPassword123!",
"first_name": "Test",
"last_name": "User"
}
"last_name": "User",
},
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
@@ -51,8 +52,8 @@ class TestRegisterEndpoint:
"email": "test@example.com",
"password": "weak",
"first_name": "Test",
"last_name": "User"
}
"last_name": "User",
},
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@@ -66,10 +67,7 @@ class TestLoginEndpoint:
"""Test successful login."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": "testuser@example.com",
"password": "TestPassword123!"
}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
assert response.status_code == status.HTTP_200_OK
@@ -82,10 +80,7 @@ class TestLoginEndpoint:
"""Test login with invalid password."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": "testuser@example.com",
"password": "WrongPassword123!"
}
json={"email": "testuser@example.com", "password": "WrongPassword123!"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -95,10 +90,7 @@ class TestLoginEndpoint:
"""Test login with non-existent user."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": "nonexistent@example.com",
"password": "TestPassword123!"
}
json={"email": "nonexistent@example.com", "password": "TestPassword123!"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -106,27 +98,25 @@ class TestLoginEndpoint:
@pytest.mark.asyncio
async def test_login_inactive_user(self, client, async_test_db):
"""Test login with inactive user."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
from app.models.user import User
from app.core.auth import get_password_hash
from app.models.user import User
inactive_user = User(
email="inactive@example.com",
password_hash=get_password_hash("TestPassword123!"),
first_name="Inactive",
last_name="User",
is_active=False
is_active=False,
)
session.add(inactive_user)
await session.commit()
response = await client.post(
"/api/v1/auth/login",
json={
"email": "inactive@example.com",
"password": "TestPassword123!"
}
json={"email": "inactive@example.com", "password": "TestPassword123!"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -140,10 +130,7 @@ class TestRefreshTokenEndpoint:
"""Get a refresh token for testing."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": "testuser@example.com",
"password": "TestPassword123!"
}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
return response.json()["refresh_token"]
@@ -151,8 +138,7 @@ class TestRefreshTokenEndpoint:
async def test_refresh_token_success(self, client, refresh_token):
"""Test successful token refresh."""
response = await client.post(
"/api/v1/auth/refresh",
json={"refresh_token": refresh_token}
"/api/v1/auth/refresh", json={"refresh_token": refresh_token}
)
assert response.status_code == status.HTTP_200_OK
@@ -164,8 +150,7 @@ class TestRefreshTokenEndpoint:
async def test_refresh_token_invalid(self, client):
"""Test refresh with invalid token."""
response = await client.post(
"/api/v1/auth/refresh",
json={"refresh_token": "invalid.token.here"}
"/api/v1/auth/refresh", json={"refresh_token": "invalid.token.here"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -179,13 +164,13 @@ class TestLogoutEndpoint:
"""Get tokens for testing."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": "testuser@example.com",
"password": "TestPassword123!"
}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
data = response.json()
return {"access_token": data["access_token"], "refresh_token": data["refresh_token"]}
return {
"access_token": data["access_token"],
"refresh_token": data["refresh_token"],
}
@pytest.mark.asyncio
async def test_logout_success(self, client, tokens):
@@ -193,7 +178,7 @@ class TestLogoutEndpoint:
response = await client.post(
"/api/v1/auth/logout",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={"refresh_token": tokens["refresh_token"]}
json={"refresh_token": tokens["refresh_token"]},
)
assert response.status_code == status.HTTP_200_OK
@@ -202,8 +187,7 @@ class TestLogoutEndpoint:
async def test_logout_without_auth(self, client):
"""Test logout without authentication."""
response = await client.post(
"/api/v1/auth/logout",
json={"refresh_token": "some.token"}
"/api/v1/auth/logout", json={"refresh_token": "some.token"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -215,8 +199,7 @@ class TestPasswordResetRequest:
async def test_password_reset_request_success(self, client, async_test_user):
"""Test password reset request with existing user."""
response = await client.post(
"/api/v1/auth/password-reset/request",
json={"email": async_test_user.email}
"/api/v1/auth/password-reset/request", json={"email": async_test_user.email}
)
assert response.status_code == status.HTTP_200_OK
@@ -228,7 +211,7 @@ class TestPasswordResetRequest:
"""Test password reset request with non-existent email."""
response = await client.post(
"/api/v1/auth/password-reset/request",
json={"email": "nonexistent@example.com"}
json={"email": "nonexistent@example.com"},
)
assert response.status_code == status.HTTP_200_OK
@@ -244,10 +227,7 @@ class TestPasswordResetConfirm:
"""Test password reset with invalid token."""
response = await client.post(
"/api/v1/auth/password-reset/confirm",
json={
"token": "invalid.token.here",
"new_password": "NewPassword123!"
}
json={"token": "invalid.token.here", "new_password": "NewPassword123!"},
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
@@ -261,20 +241,20 @@ class TestLogoutAll:
"""Get tokens for testing."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": "testuser@example.com",
"password": "TestPassword123!"
}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
data = response.json()
return {"access_token": data["access_token"], "refresh_token": data["refresh_token"]}
return {
"access_token": data["access_token"],
"refresh_token": data["refresh_token"],
}
@pytest.mark.asyncio
async def test_logout_all_success(self, client, tokens):
"""Test logout from all devices."""
response = await client.post(
"/api/v1/auth/logout-all",
headers={"Authorization": f"Bearer {tokens['access_token']}"}
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -298,10 +278,7 @@ class TestOAuthLogin:
"""Test successful OAuth login."""
response = await client.post(
"/api/v1/auth/login/oauth",
data={
"username": "testuser@example.com",
"password": "TestPassword123!"
}
data={"username": "testuser@example.com", "password": "TestPassword123!"},
)
assert response.status_code == status.HTTP_200_OK
@@ -315,10 +292,7 @@ class TestOAuthLogin:
"""Test OAuth login with invalid credentials."""
response = await client.post(
"/api/v1/auth/login/oauth",
data={
"username": "testuser@example.com",
"password": "WrongPassword"
}
data={"username": "testuser@example.com", "password": "WrongPassword"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED

View File

@@ -1,15 +1,16 @@
# tests/api/dependencies/test_auth_dependencies.py
import pytest
import pytest_asyncio
import uuid
from unittest.mock import patch
import pytest
import pytest_asyncio
from fastapi import HTTPException
from app.api.dependencies.auth import (
get_current_user,
get_current_active_user,
get_current_superuser,
get_optional_current_user
get_current_user,
get_optional_current_user,
)
from app.core.auth import TokenExpiredError, TokenInvalidError, get_password_hash
from app.models.user import User
@@ -24,7 +25,7 @@ def mock_token():
@pytest_asyncio.fixture
async def async_mock_user(async_test_db):
"""Async fixture to create and return a mock User instance."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
mock_user = User(
id=uuid.uuid4(),
@@ -47,12 +48,14 @@ class TestGetCurrentUser:
"""Tests for get_current_user dependency"""
@pytest.mark.asyncio
async def test_get_current_user_success(self, async_test_db, async_mock_user, mock_token):
async def test_get_current_user_success(
self, async_test_db, async_mock_user, mock_token
):
"""Test successfully getting the current user"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to return user_id that matches our mock_user
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.return_value.user_id = async_mock_user.id
# Call the dependency
@@ -65,12 +68,12 @@ class TestGetCurrentUser:
@pytest.mark.asyncio
async def test_get_current_user_nonexistent(self, async_test_db, mock_token):
"""Test when the token contains a user ID that doesn't exist"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to return a non-existent user ID
nonexistent_id = uuid.UUID("11111111-1111-1111-1111-111111111111")
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.return_value.user_id = nonexistent_id
# Should raise HTTPException with 404 status
@@ -81,19 +84,24 @@ class TestGetCurrentUser:
assert "User not found" in exc_info.value.detail
@pytest.mark.asyncio
async def test_get_current_user_inactive(self, async_test_db, async_mock_user, mock_token):
async def test_get_current_user_inactive(
self, async_test_db, async_mock_user, mock_token
):
"""Test when the user is inactive"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Get the user in this session and make it inactive
from sqlalchemy import select
result = await session.execute(select(User).where(User.id == async_mock_user.id))
result = await session.execute(
select(User).where(User.id == async_mock_user.id)
)
user_in_session = result.scalar_one_or_none()
user_in_session.is_active = False
await session.commit()
# Mock get_token_data
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.return_value.user_id = async_mock_user.id
# Should raise HTTPException with 403 status
@@ -106,10 +114,10 @@ class TestGetCurrentUser:
@pytest.mark.asyncio
async def test_get_current_user_expired_token(self, async_test_db, mock_token):
"""Test with an expired token"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to raise TokenExpiredError
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.side_effect = TokenExpiredError("Token expired")
# Should raise HTTPException with 401 status
@@ -122,10 +130,10 @@ class TestGetCurrentUser:
@pytest.mark.asyncio
async def test_get_current_user_invalid_token(self, async_test_db, mock_token):
"""Test with an invalid token"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to raise TokenInvalidError
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.side_effect = TokenInvalidError("Invalid token")
# Should raise HTTPException with 401 status
@@ -194,12 +202,14 @@ class TestGetOptionalCurrentUser:
"""Tests for get_optional_current_user dependency"""
@pytest.mark.asyncio
async def test_get_optional_current_user_with_token(self, async_test_db, async_mock_user, mock_token):
async def test_get_optional_current_user_with_token(
self, async_test_db, async_mock_user, mock_token
):
"""Test getting optional user with a valid token"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock get_token_data
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.return_value.user_id = async_mock_user.id
# Call the dependency
@@ -212,7 +222,7 @@ class TestGetOptionalCurrentUser:
@pytest.mark.asyncio
async def test_get_optional_current_user_no_token(self, async_test_db):
"""Test getting optional user with no token"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Call the dependency with no token
user = await get_optional_current_user(db=session, token=None)
@@ -221,12 +231,14 @@ class TestGetOptionalCurrentUser:
assert user is None
@pytest.mark.asyncio
async def test_get_optional_current_user_invalid_token(self, async_test_db, mock_token):
async def test_get_optional_current_user_invalid_token(
self, async_test_db, mock_token
):
"""Test getting optional user with an invalid token"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to raise TokenInvalidError
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.side_effect = TokenInvalidError("Invalid token")
# Call the dependency
@@ -236,12 +248,14 @@ class TestGetOptionalCurrentUser:
assert user is None
@pytest.mark.asyncio
async def test_get_optional_current_user_expired_token(self, async_test_db, mock_token):
async def test_get_optional_current_user_expired_token(
self, async_test_db, mock_token
):
"""Test getting optional user with an expired token"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to raise TokenExpiredError
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.side_effect = TokenExpiredError("Token expired")
# Call the dependency
@@ -251,19 +265,24 @@ class TestGetOptionalCurrentUser:
assert user is None
@pytest.mark.asyncio
async def test_get_optional_current_user_inactive(self, async_test_db, async_mock_user, mock_token):
async def test_get_optional_current_user_inactive(
self, async_test_db, async_mock_user, mock_token
):
"""Test getting optional user when user is inactive"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Get the user in this session and make it inactive
from sqlalchemy import select
result = await session.execute(select(User).where(User.id == async_mock_user.id))
result = await session.execute(
select(User).where(User.id == async_mock_user.id)
)
user_in_session = result.scalar_one_or_none()
user_in_session.is_active = False
await session.commit()
# Mock get_token_data
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.return_value.user_id = async_mock_user.id
# Call the dependency

View File

@@ -2,21 +2,21 @@
"""
Tests for authentication endpoints.
"""
from unittest.mock import patch
import pytest
import pytest_asyncio
from unittest.mock import patch, MagicMock
from fastapi import status
from sqlalchemy import select
from app.models.user import User
from app.schemas.users import UserCreate
# Disable rate limiting for tests
@pytest.fixture(autouse=True)
def disable_rate_limit():
"""Disable rate limiting for all tests in this module."""
with patch('app.api.routes.auth.limiter.enabled', False):
with patch("app.api.routes.auth.limiter.enabled", False):
yield
@@ -32,8 +32,8 @@ class TestRegisterEndpoint:
"email": "newuser@example.com",
"password": "SecurePassword123!",
"first_name": "New",
"last_name": "User"
}
"last_name": "User",
},
)
assert response.status_code == status.HTTP_201_CREATED
@@ -54,8 +54,8 @@ class TestRegisterEndpoint:
"email": async_test_user.email,
"password": "SecurePassword123!",
"first_name": "Duplicate",
"last_name": "User"
}
"last_name": "User",
},
)
# Security: Returns 400 with generic message to prevent email enumeration
@@ -73,8 +73,8 @@ class TestRegisterEndpoint:
"email": "weakpass@example.com",
"password": "weak",
"first_name": "Weak",
"last_name": "Pass"
}
"last_name": "Pass",
},
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@@ -82,7 +82,7 @@ class TestRegisterEndpoint:
@pytest.mark.asyncio
async def test_register_unexpected_error(self, client):
"""Test registration with unexpected error."""
with patch('app.services.auth_service.AuthService.create_user') as mock_create:
with patch("app.services.auth_service.AuthService.create_user") as mock_create:
mock_create.side_effect = Exception("Unexpected error")
response = await client.post(
@@ -91,8 +91,8 @@ class TestRegisterEndpoint:
"email": "error@example.com",
"password": "SecurePassword123!",
"first_name": "Error",
"last_name": "User"
}
"last_name": "User",
},
)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
@@ -106,10 +106,7 @@ class TestLoginEndpoint:
"""Test successful login."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": async_test_user.email,
"password": "TestPassword123!"
}
json={"email": async_test_user.email, "password": "TestPassword123!"},
)
assert response.status_code == status.HTTP_200_OK
@@ -123,10 +120,7 @@ class TestLoginEndpoint:
"""Test login with wrong password."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": async_test_user.email,
"password": "WrongPassword123"
}
json={"email": async_test_user.email, "password": "WrongPassword123"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -136,10 +130,7 @@ class TestLoginEndpoint:
"""Test login with non-existent email."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": "nonexistent@example.com",
"password": "Password123!"
}
json={"email": "nonexistent@example.com", "password": "Password123!"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -147,20 +138,19 @@ class TestLoginEndpoint:
@pytest.mark.asyncio
async def test_login_inactive_user(self, client, async_test_user, async_test_db):
"""Test login with inactive user."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Get the user in this session and make it inactive
result = await session.execute(select(User).where(User.id == async_test_user.id))
result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user_in_session = result.scalar_one_or_none()
user_in_session.is_active = False
await session.commit()
response = await client.post(
"/api/v1/auth/login",
json={
"email": async_test_user.email,
"password": "TestPassword123!"
}
json={"email": async_test_user.email, "password": "TestPassword123!"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -168,15 +158,14 @@ class TestLoginEndpoint:
@pytest.mark.asyncio
async def test_login_unexpected_error(self, client, async_test_user):
"""Test login with unexpected error."""
with patch('app.services.auth_service.AuthService.authenticate_user') as mock_auth:
with patch(
"app.services.auth_service.AuthService.authenticate_user"
) as mock_auth:
mock_auth.side_effect = Exception("Database error")
response = await client.post(
"/api/v1/auth/login",
json={
"email": async_test_user.email,
"password": "TestPassword123!"
}
json={"email": async_test_user.email, "password": "TestPassword123!"},
)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
@@ -190,10 +179,7 @@ class TestOAuthLoginEndpoint:
"""Test successful OAuth login."""
response = await client.post(
"/api/v1/auth/login/oauth",
data={
"username": async_test_user.email,
"password": "TestPassword123!"
}
data={"username": async_test_user.email, "password": "TestPassword123!"},
)
assert response.status_code == status.HTTP_200_OK
@@ -206,31 +192,29 @@ class TestOAuthLoginEndpoint:
"""Test OAuth login with wrong credentials."""
response = await client.post(
"/api/v1/auth/login/oauth",
data={
"username": async_test_user.email,
"password": "WrongPassword"
}
data={"username": async_test_user.email, "password": "WrongPassword"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@pytest.mark.asyncio
async def test_oauth_login_inactive_user(self, client, async_test_user, async_test_db):
async def test_oauth_login_inactive_user(
self, client, async_test_user, async_test_db
):
"""Test OAuth login with inactive user."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Get the user in this session and make it inactive
result = await session.execute(select(User).where(User.id == async_test_user.id))
result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user_in_session = result.scalar_one_or_none()
user_in_session.is_active = False
await session.commit()
response = await client.post(
"/api/v1/auth/login/oauth",
data={
"username": async_test_user.email,
"password": "TestPassword123!"
}
data={"username": async_test_user.email, "password": "TestPassword123!"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -238,15 +222,17 @@ class TestOAuthLoginEndpoint:
@pytest.mark.asyncio
async def test_oauth_login_unexpected_error(self, client, async_test_user):
"""Test OAuth login with unexpected error."""
with patch('app.services.auth_service.AuthService.authenticate_user') as mock_auth:
with patch(
"app.services.auth_service.AuthService.authenticate_user"
) as mock_auth:
mock_auth.side_effect = Exception("Unexpected error")
response = await client.post(
"/api/v1/auth/login/oauth",
data={
"username": async_test_user.email,
"password": "TestPassword123!"
}
"password": "TestPassword123!",
},
)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
@@ -261,17 +247,13 @@ class TestRefreshTokenEndpoint:
# First, login to get a refresh token
login_response = await client.post(
"/api/v1/auth/login",
json={
"email": async_test_user.email,
"password": "TestPassword123!"
}
json={"email": async_test_user.email, "password": "TestPassword123!"},
)
refresh_token = login_response.json()["refresh_token"]
# Now refresh the token
response = await client.post(
"/api/v1/auth/refresh",
json={"refresh_token": refresh_token}
"/api/v1/auth/refresh", json={"refresh_token": refresh_token}
)
assert response.status_code == status.HTTP_200_OK
@@ -284,12 +266,13 @@ class TestRefreshTokenEndpoint:
"""Test refresh with expired token."""
from app.core.auth import TokenExpiredError
with patch('app.services.auth_service.AuthService.refresh_tokens') as mock_refresh:
with patch(
"app.services.auth_service.AuthService.refresh_tokens"
) as mock_refresh:
mock_refresh.side_effect = TokenExpiredError("Token expired")
response = await client.post(
"/api/v1/auth/refresh",
json={"refresh_token": "some_token"}
"/api/v1/auth/refresh", json={"refresh_token": "some_token"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -298,8 +281,7 @@ class TestRefreshTokenEndpoint:
async def test_refresh_token_invalid(self, client):
"""Test refresh with invalid token."""
response = await client.post(
"/api/v1/auth/refresh",
json={"refresh_token": "invalid_token"}
"/api/v1/auth/refresh", json={"refresh_token": "invalid_token"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -310,19 +292,17 @@ class TestRefreshTokenEndpoint:
# Get a valid refresh token first
login_response = await client.post(
"/api/v1/auth/login",
json={
"email": async_test_user.email,
"password": "TestPassword123!"
}
json={"email": async_test_user.email, "password": "TestPassword123!"},
)
refresh_token = login_response.json()["refresh_token"]
with patch('app.services.auth_service.AuthService.refresh_tokens') as mock_refresh:
with patch(
"app.services.auth_service.AuthService.refresh_tokens"
) as mock_refresh:
mock_refresh.side_effect = Exception("Unexpected error")
response = await client.post(
"/api/v1/auth/refresh",
json={"refresh_token": refresh_token}
"/api/v1/auth/refresh", json={"refresh_token": refresh_token}
)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR

View File

@@ -2,8 +2,10 @@
"""
Tests for auth route exception handlers and error paths.
"""
from unittest.mock import patch
import pytest
from unittest.mock import patch, AsyncMock
from fastapi import status
@@ -11,16 +13,18 @@ class TestLoginSessionCreationFailure:
"""Test login when session creation fails."""
@pytest.mark.asyncio
async def test_login_succeeds_despite_session_creation_failure(self, client, async_test_user):
async def test_login_succeeds_despite_session_creation_failure(
self, client, async_test_user
):
"""Test that login succeeds even if session creation fails."""
# Mock session creation to fail
with patch('app.api.routes.auth.session_crud.create_session', side_effect=Exception("Session creation failed")):
with patch(
"app.api.routes.auth.session_crud.create_session",
side_effect=Exception("Session creation failed"),
):
response = await client.post(
"/api/v1/auth/login",
json={
"email": "testuser@example.com",
"password": "TestPassword123!"
}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
# Login should still succeed, just without session record
@@ -34,15 +38,20 @@ class TestOAuthLoginSessionCreationFailure:
"""Test OAuth login when session creation fails."""
@pytest.mark.asyncio
async def test_oauth_login_succeeds_despite_session_failure(self, client, async_test_user):
async def test_oauth_login_succeeds_despite_session_failure(
self, client, async_test_user
):
"""Test OAuth login succeeds even if session creation fails."""
with patch('app.api.routes.auth.session_crud.create_session', side_effect=Exception("Session failed")):
with patch(
"app.api.routes.auth.session_crud.create_session",
side_effect=Exception("Session failed"),
):
response = await client.post(
"/api/v1/auth/login/oauth",
data={
"username": "testuser@example.com",
"password": "TestPassword123!"
}
"password": "TestPassword123!",
},
)
assert response.status_code == status.HTTP_200_OK
@@ -54,23 +63,24 @@ class TestRefreshTokenSessionUpdateFailure:
"""Test refresh token when session update fails."""
@pytest.mark.asyncio
async def test_refresh_token_succeeds_despite_session_update_failure(self, client, async_test_user):
async def test_refresh_token_succeeds_despite_session_update_failure(
self, client, async_test_user
):
"""Test that token refresh succeeds even if session update fails."""
# First login to get tokens
response = await client.post(
"/api/v1/auth/login",
json={
"email": "testuser@example.com",
"password": "TestPassword123!"
}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
tokens = response.json()
# Mock session update to fail
with patch('app.api.routes.auth.session_crud.update_refresh_token', side_effect=Exception("Update failed")):
with patch(
"app.api.routes.auth.session_crud.update_refresh_token",
side_effect=Exception("Update failed"),
):
response = await client.post(
"/api/v1/auth/refresh",
json={"refresh_token": tokens["refresh_token"]}
"/api/v1/auth/refresh", json={"refresh_token": tokens["refresh_token"]}
)
# Should still succeed - tokens are issued before update
@@ -83,15 +93,14 @@ class TestLogoutWithExpiredToken:
"""Test logout with expired/invalid token."""
@pytest.mark.asyncio
async def test_logout_with_invalid_token_still_succeeds(self, client, async_test_user):
async def test_logout_with_invalid_token_still_succeeds(
self, client, async_test_user
):
"""Test logout succeeds even with invalid refresh token."""
# Login first
response = await client.post(
"/api/v1/auth/login",
json={
"email": "testuser@example.com",
"password": "TestPassword123!"
}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
access_token = response.json()["access_token"]
@@ -99,7 +108,7 @@ class TestLogoutWithExpiredToken:
response = await client.post(
"/api/v1/auth/logout",
headers={"Authorization": f"Bearer {access_token}"},
json={"refresh_token": "invalid.token.here"}
json={"refresh_token": "invalid.token.here"},
)
# Should succeed (idempotent)
@@ -116,19 +125,16 @@ class TestLogoutWithNonExistentSession:
"""Test logout succeeds even if session not found."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": "testuser@example.com",
"password": "TestPassword123!"
}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
tokens = response.json()
# Mock session lookup to return None
with patch('app.api.routes.auth.session_crud.get_by_jti', return_value=None):
with patch("app.api.routes.auth.session_crud.get_by_jti", return_value=None):
response = await client.post(
"/api/v1/auth/logout",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={"refresh_token": tokens["refresh_token"]}
json={"refresh_token": tokens["refresh_token"]},
)
# Should succeed (idempotent)
@@ -139,23 +145,25 @@ class TestLogoutUnexpectedError:
"""Test logout with unexpected errors."""
@pytest.mark.asyncio
async def test_logout_with_unexpected_error_returns_success(self, client, async_test_user):
async def test_logout_with_unexpected_error_returns_success(
self, client, async_test_user
):
"""Test logout returns success even on unexpected errors."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": "testuser@example.com",
"password": "TestPassword123!"
}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
tokens = response.json()
# Mock to raise unexpected error
with patch('app.api.routes.auth.session_crud.get_by_jti', side_effect=Exception("Unexpected error")):
with patch(
"app.api.routes.auth.session_crud.get_by_jti",
side_effect=Exception("Unexpected error"),
):
response = await client.post(
"/api/v1/auth/logout",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={"refresh_token": tokens["refresh_token"]}
json={"refresh_token": tokens["refresh_token"]},
)
# Should still return success (don't expose errors)
@@ -172,18 +180,18 @@ class TestLogoutAllUnexpectedError:
"""Test logout-all handles database errors."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": "testuser@example.com",
"password": "TestPassword123!"
}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
access_token = response.json()["access_token"]
# Mock to raise database error
with patch('app.api.routes.auth.session_crud.deactivate_all_user_sessions', side_effect=Exception("DB error")):
with patch(
"app.api.routes.auth.session_crud.deactivate_all_user_sessions",
side_effect=Exception("DB error"),
):
response = await client.post(
"/api/v1/auth/logout-all",
headers={"Authorization": f"Bearer {access_token}"}
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
@@ -193,7 +201,9 @@ class TestPasswordResetConfirmSessionInvalidation:
"""Test password reset invalidates sessions."""
@pytest.mark.asyncio
async def test_password_reset_continues_despite_session_invalidation_failure(self, client, async_test_user):
async def test_password_reset_continues_despite_session_invalidation_failure(
self, client, async_test_user
):
"""Test password reset succeeds even if session invalidation fails."""
# Create a valid password reset token
from app.utils.security import create_password_reset_token
@@ -201,13 +211,13 @@ class TestPasswordResetConfirmSessionInvalidation:
token = create_password_reset_token(async_test_user.email)
# Mock session invalidation to fail
with patch('app.api.routes.auth.session_crud.deactivate_all_user_sessions', side_effect=Exception("Invalidation failed")):
with patch(
"app.api.routes.auth.session_crud.deactivate_all_user_sessions",
side_effect=Exception("Invalidation failed"),
):
response = await client.post(
"/api/v1/auth/password-reset/confirm",
json={
"token": token,
"new_password": "NewPassword123!"
}
json={"token": token, "new_password": "NewPassword123!"},
)
# Should still succeed - password was reset

View File

@@ -2,22 +2,22 @@
"""
Tests for password reset endpoints.
"""
from unittest.mock import patch
import pytest
import pytest_asyncio
from unittest.mock import patch, AsyncMock, MagicMock
from fastapi import status
from sqlalchemy import select
from app.schemas.users import PasswordResetRequest, PasswordResetConfirm
from app.utils.security import create_password_reset_token
from app.models.user import User
from app.utils.security import create_password_reset_token
# Disable rate limiting for tests
@pytest.fixture(autouse=True)
def disable_rate_limit():
"""Disable rate limiting for all tests in this module."""
with patch('app.api.routes.auth.limiter.enabled', False):
with patch("app.api.routes.auth.limiter.enabled", False):
yield
@@ -27,12 +27,14 @@ class TestPasswordResetRequest:
@pytest.mark.asyncio
async def test_password_reset_request_valid_email(self, client, async_test_user):
"""Test password reset request with valid email."""
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send:
with patch(
"app.api.routes.auth.email_service.send_password_reset_email"
) as mock_send:
mock_send.return_value = True
response = await client.post(
"/api/v1/auth/password-reset/request",
json={"email": async_test_user.email}
json={"email": async_test_user.email},
)
assert response.status_code == status.HTTP_200_OK
@@ -50,10 +52,12 @@ class TestPasswordResetRequest:
@pytest.mark.asyncio
async def test_password_reset_request_nonexistent_email(self, client):
"""Test password reset request with non-existent email."""
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send:
with patch(
"app.api.routes.auth.email_service.send_password_reset_email"
) as mock_send:
response = await client.post(
"/api/v1/auth/password-reset/request",
json={"email": "nonexistent@example.com"}
json={"email": "nonexistent@example.com"},
)
# Should still return success to prevent email enumeration
@@ -65,20 +69,26 @@ class TestPasswordResetRequest:
mock_send.assert_not_called()
@pytest.mark.asyncio
async def test_password_reset_request_inactive_user(self, client, async_test_db, async_test_user):
async def test_password_reset_request_inactive_user(
self, client, async_test_db, async_test_user
):
"""Test password reset request with inactive user."""
# Deactivate user
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id))
result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user_in_session = result.scalar_one_or_none()
user_in_session.is_active = False
await session.commit()
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send:
with patch(
"app.api.routes.auth.email_service.send_password_reset_email"
) as mock_send:
response = await client.post(
"/api/v1/auth/password-reset/request",
json={"email": async_test_user.email}
json={"email": async_test_user.email},
)
# Should still return success to prevent email enumeration
@@ -93,8 +103,7 @@ class TestPasswordResetRequest:
async def test_password_reset_request_invalid_email_format(self, client):
"""Test password reset request with invalid email format."""
response = await client.post(
"/api/v1/auth/password-reset/request",
json={"email": "not-an-email"}
"/api/v1/auth/password-reset/request", json={"email": "not-an-email"}
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@@ -102,22 +111,23 @@ class TestPasswordResetRequest:
@pytest.mark.asyncio
async def test_password_reset_request_missing_email(self, client):
"""Test password reset request without email."""
response = await client.post(
"/api/v1/auth/password-reset/request",
json={}
)
response = await client.post("/api/v1/auth/password-reset/request", json={})
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@pytest.mark.asyncio
async def test_password_reset_request_email_service_error(self, client, async_test_user):
async def test_password_reset_request_email_service_error(
self, client, async_test_user
):
"""Test password reset when email service fails."""
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send:
with patch(
"app.api.routes.auth.email_service.send_password_reset_email"
) as mock_send:
mock_send.side_effect = Exception("SMTP Error")
response = await client.post(
"/api/v1/auth/password-reset/request",
json={"email": async_test_user.email}
json={"email": async_test_user.email},
)
# Should still return success even if email fails
@@ -128,14 +138,16 @@ class TestPasswordResetRequest:
@pytest.mark.asyncio
async def test_password_reset_request_rate_limiting(self, client, async_test_user):
"""Test that password reset requests are rate limited."""
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send:
with patch(
"app.api.routes.auth.email_service.send_password_reset_email"
) as mock_send:
mock_send.return_value = True
# Make multiple requests quickly (3/minute limit)
for _ in range(3):
response = await client.post(
"/api/v1/auth/password-reset/request",
json={"email": async_test_user.email}
json={"email": async_test_user.email},
)
assert response.status_code == status.HTTP_200_OK
@@ -144,7 +156,9 @@ class TestPasswordResetConfirm:
"""Tests for POST /auth/password-reset/confirm endpoint."""
@pytest.mark.asyncio
async def test_password_reset_confirm_valid_token(self, client, async_test_user, async_test_db):
async def test_password_reset_confirm_valid_token(
self, client, async_test_user, async_test_db
):
"""Test password reset confirmation with valid token."""
# Generate valid token
token = create_password_reset_token(async_test_user.email)
@@ -152,10 +166,7 @@ class TestPasswordResetConfirm:
response = await client.post(
"/api/v1/auth/password-reset/confirm",
json={
"token": token,
"new_password": new_password
}
json={"token": token, "new_password": new_password},
)
assert response.status_code == status.HTTP_200_OK
@@ -164,11 +175,14 @@ class TestPasswordResetConfirm:
assert "successfully" in data["message"].lower()
# Verify user can login with new password
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id))
result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
updated_user = result.scalar_one_or_none()
from app.core.auth import verify_password
assert verify_password(new_password, updated_user.password_hash) is True
@pytest.mark.asyncio
@@ -184,10 +198,7 @@ class TestPasswordResetConfirm:
response = await client.post(
"/api/v1/auth/password-reset/confirm",
json={
"token": token,
"new_password": "NewSecure123!"
}
json={"token": token, "new_password": "NewSecure123!"},
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
@@ -202,10 +213,7 @@ class TestPasswordResetConfirm:
"""Test password reset confirmation with invalid token."""
response = await client.post(
"/api/v1/auth/password-reset/confirm",
json={
"token": "invalid_token_xyz",
"new_password": "NewSecure123!"
}
json={"token": "invalid_token_xyz", "new_password": "NewSecure123!"},
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
@@ -222,19 +230,18 @@ class TestPasswordResetConfirm:
# Create valid token and tamper with it
token = create_password_reset_token(async_test_user.email)
decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
decoded = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
token_data = json.loads(decoded)
token_data["payload"]["email"] = "hacker@example.com"
# Re-encode tampered token
tampered = base64.urlsafe_b64encode(json.dumps(token_data).encode('utf-8')).decode('utf-8')
tampered = base64.urlsafe_b64encode(
json.dumps(token_data).encode("utf-8")
).decode("utf-8")
response = await client.post(
"/api/v1/auth/password-reset/confirm",
json={
"token": tampered,
"new_password": "NewSecure123!"
}
json={"token": tampered, "new_password": "NewSecure123!"},
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
@@ -247,10 +254,7 @@ class TestPasswordResetConfirm:
response = await client.post(
"/api/v1/auth/password-reset/confirm",
json={
"token": token,
"new_password": "NewSecure123!"
}
json={"token": token, "new_password": "NewSecure123!"},
)
assert response.status_code == status.HTTP_404_NOT_FOUND
@@ -260,12 +264,16 @@ class TestPasswordResetConfirm:
assert "not found" in error_msg
@pytest.mark.asyncio
async def test_password_reset_confirm_inactive_user(self, client, async_test_user, async_test_db):
async def test_password_reset_confirm_inactive_user(
self, client, async_test_user, async_test_db
):
"""Test password reset confirmation for inactive user."""
# Deactivate user
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id))
result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user_in_session = result.scalar_one_or_none()
user_in_session.is_active = False
await session.commit()
@@ -274,10 +282,7 @@ class TestPasswordResetConfirm:
response = await client.post(
"/api/v1/auth/password-reset/confirm",
json={
"token": token,
"new_password": "NewSecure123!"
}
json={"token": token, "new_password": "NewSecure123!"},
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
@@ -301,10 +306,7 @@ class TestPasswordResetConfirm:
for weak_password in weak_passwords:
response = await client.post(
"/api/v1/auth/password-reset/confirm",
json={
"token": token,
"new_password": weak_password
}
json={"token": token, "new_password": weak_password},
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@@ -315,15 +317,14 @@ class TestPasswordResetConfirm:
# Missing token
response = await client.post(
"/api/v1/auth/password-reset/confirm",
json={"new_password": "NewSecure123!"}
json={"new_password": "NewSecure123!"},
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
# Missing password
token = create_password_reset_token("test@example.com")
response = await client.post(
"/api/v1/auth/password-reset/confirm",
json={"token": token}
"/api/v1/auth/password-reset/confirm", json={"token": token}
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@@ -333,15 +334,12 @@ class TestPasswordResetConfirm:
token = create_password_reset_token(async_test_user.email)
# Mock the database commit to raise an exception
with patch('app.api.routes.auth.user_crud.get_by_email') as mock_get:
with patch("app.api.routes.auth.user_crud.get_by_email") as mock_get:
mock_get.side_effect = Exception("Database error")
response = await client.post(
"/api/v1/auth/password-reset/confirm",
json={
"token": token,
"new_password": "NewSecure123!"
}
json={"token": token, "new_password": "NewSecure123!"},
)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
@@ -351,18 +349,22 @@ class TestPasswordResetConfirm:
assert "error" in error_msg or "resetting" in error_msg
@pytest.mark.asyncio
async def test_password_reset_full_flow(self, client, async_test_user, async_test_db):
async def test_password_reset_full_flow(
self, client, async_test_user, async_test_db
):
"""Test complete password reset flow."""
original_password = async_test_user.password_hash
new_password = "BrandNew123!"
# Step 1: Request password reset
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send:
with patch(
"app.api.routes.auth.email_service.send_password_reset_email"
) as mock_send:
mock_send.return_value = True
response = await client.post(
"/api/v1/auth/password-reset/request",
json={"email": async_test_user.email}
json={"email": async_test_user.email},
)
assert response.status_code == status.HTTP_200_OK
@@ -374,29 +376,24 @@ class TestPasswordResetConfirm:
# Step 2: Confirm password reset
response = await client.post(
"/api/v1/auth/password-reset/confirm",
json={
"token": reset_token,
"new_password": new_password
}
json={"token": reset_token, "new_password": new_password},
)
assert response.status_code == status.HTTP_200_OK
# Step 3: Verify old password doesn't work
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id))
result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
updated_user = result.scalar_one_or_none()
from app.core.auth import verify_password
assert updated_user.password_hash != original_password
# Step 4: Verify new password works
response = await client.post(
"/api/v1/auth/login",
json={
"email": async_test_user.email,
"password": new_password
}
json={"email": async_test_user.email, "password": new_password},
)
assert response.status_code == status.HTTP_200_OK

View File

@@ -8,11 +8,10 @@ Critical security tests covering:
These tests prevent real-world attack scenarios.
"""
import pytest
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.auth import create_refresh_token
from app.crud.session import session as session_crud
from app.models.user import User
@@ -30,10 +29,7 @@ class TestRevokedSessionSecurity:
@pytest.mark.asyncio
async def test_refresh_token_rejected_after_logout(
self,
client: AsyncClient,
async_test_db,
async_test_user: User
self, client: AsyncClient, async_test_db, async_test_user: User
):
"""
Test that refresh tokens are rejected after session is deactivated.
@@ -45,10 +41,10 @@ class TestRevokedSessionSecurity:
4. Attacker tries to use stolen refresh token
5. System MUST reject it (session revoked)
"""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Step 1: Create a session and refresh token for the user
async with SessionLocal() as session:
async with SessionLocal():
# Login to get tokens
response = await client.post(
"/api/v1/auth/login",
@@ -64,8 +60,7 @@ class TestRevokedSessionSecurity:
# Step 2: Verify refresh token works before logout
response = await client.post(
"/api/v1/auth/refresh",
json={"refresh_token": refresh_token}
"/api/v1/auth/refresh", json={"refresh_token": refresh_token}
)
assert response.status_code == 200, "Refresh should work before logout"
@@ -73,14 +68,13 @@ class TestRevokedSessionSecurity:
response = await client.post(
"/api/v1/auth/logout",
headers={"Authorization": f"Bearer {access_token}"},
json={"refresh_token": refresh_token}
json={"refresh_token": refresh_token},
)
assert response.status_code == 200, "Logout should succeed"
# Step 4: Attacker tries to use stolen refresh token
response = await client.post(
"/api/v1/auth/refresh",
json={"refresh_token": refresh_token}
"/api/v1/auth/refresh", json={"refresh_token": refresh_token}
)
# Step 5: System MUST reject (covers lines 261-262)
@@ -93,10 +87,7 @@ class TestRevokedSessionSecurity:
@pytest.mark.asyncio
async def test_refresh_token_rejected_for_deleted_session(
self,
client: AsyncClient,
async_test_db,
async_test_user: User
self, client: AsyncClient, async_test_db, async_test_user: User
):
"""
Test that tokens for deleted sessions are rejected.
@@ -104,7 +95,7 @@ class TestRevokedSessionSecurity:
Attack Scenario:
Admin deletes a session from database, but attacker has the token.
"""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Step 1: Login to create a session
response = await client.post(
@@ -120,6 +111,7 @@ class TestRevokedSessionSecurity:
# Step 2: Manually delete the session from database (simulating admin action)
from app.core.auth import decode_token
token_data = decode_token(refresh_token, verify_type="refresh")
jti = token_data.jti
@@ -132,15 +124,17 @@ class TestRevokedSessionSecurity:
# Step 3: Try to use the refresh token
response = await client.post(
"/api/v1/auth/refresh",
json={"refresh_token": refresh_token}
"/api/v1/auth/refresh", json={"refresh_token": refresh_token}
)
# Should reject (session doesn't exist)
assert response.status_code == 401
data = response.json()
if "errors" in data:
assert "revoked" in data["errors"][0]["message"].lower() or "session" in data["errors"][0]["message"].lower()
assert (
"revoked" in data["errors"][0]["message"].lower()
or "session" in data["errors"][0]["message"].lower()
)
else:
assert "revoked" in data.get("detail", "").lower()
@@ -162,7 +156,7 @@ class TestSessionHijackingSecurity:
client: AsyncClient,
async_test_db,
async_test_user: User,
async_test_superuser: User
async_test_superuser: User,
):
"""
Test that users cannot logout other users' sessions.
@@ -173,7 +167,7 @@ class TestSessionHijackingSecurity:
3. User A tries to logout User B's session
4. System MUST reject (cross-user attack)
"""
test_engine, SessionLocal = async_test_db
_test_engine, _SessionLocal = async_test_db
# Step 1: User A logs in
response = await client.post(
@@ -202,8 +196,10 @@ class TestSessionHijackingSecurity:
# Step 3: User A tries to logout User B's session using User B's refresh token
response = await client.post(
"/api/v1/auth/logout",
headers={"Authorization": f"Bearer {user_a_access}"}, # User A's access token
json={"refresh_token": user_b_refresh} # But User B's refresh token
headers={
"Authorization": f"Bearer {user_a_access}"
}, # User A's access token
json={"refresh_token": user_b_refresh}, # But User B's refresh token
)
# Step 4: System MUST reject (covers lines 509-513)
@@ -217,9 +213,7 @@ class TestSessionHijackingSecurity:
@pytest.mark.asyncio
async def test_users_can_logout_their_own_sessions(
self,
client: AsyncClient,
async_test_user: User
self, client: AsyncClient, async_test_user: User
):
"""
Sanity check: Users CAN logout their own sessions.
@@ -241,6 +235,8 @@ class TestSessionHijackingSecurity:
response = await client.post(
"/api/v1/auth/logout",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={"refresh_token": tokens["refresh_token"]}
json={"refresh_token": tokens["refresh_token"]},
)
assert response.status_code == 200, (
"Users should be able to logout their own sessions"
)
assert response.status_code == 200, "Users should be able to logout their own sessions"

View File

@@ -5,16 +5,18 @@ Tests for organization routes (user endpoints).
These test the routes in app/api/routes/organizations.py which allow
users to view and manage organizations they belong to.
"""
from unittest.mock import patch
from uuid import uuid4
import pytest
import pytest_asyncio
from fastapi import status
from uuid import uuid4
from unittest.mock import patch, AsyncMock
from app.core.auth import get_password_hash
from app.models.organization import Organization
from app.models.user import User
from app.models.user_organization import UserOrganization, OrganizationRole
from app.core.auth import get_password_hash
from app.models.user_organization import OrganizationRole, UserOrganization
@pytest_asyncio.fixture
@@ -22,10 +24,7 @@ async def user_token(client, async_test_user):
"""Get access token for regular user."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": "testuser@example.com",
"password": "TestPassword123!"
}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
assert response.status_code == 200
return response.json()["access_token"]
@@ -34,7 +33,7 @@ async def user_token(client, async_test_user):
@pytest_asyncio.fixture
async def second_user(async_test_db):
"""Create a second test user."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user = User(
id=uuid4(),
@@ -56,12 +55,12 @@ async def second_user(async_test_db):
@pytest_asyncio.fixture
async def test_org_with_user_member(async_test_db, async_test_user):
"""Create a test organization with async_test_user as a member."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
org = Organization(
name="Member Org",
slug="member-org",
description="Test organization where user is a member"
description="Test organization where user is a member",
)
session.add(org)
await session.commit()
@@ -72,7 +71,7 @@ async def test_org_with_user_member(async_test_db, async_test_user):
user_id=async_test_user.id,
organization_id=org.id,
role=OrganizationRole.MEMBER,
is_active=True
is_active=True,
)
session.add(membership)
await session.commit()
@@ -83,12 +82,12 @@ async def test_org_with_user_member(async_test_db, async_test_user):
@pytest_asyncio.fixture
async def test_org_with_user_admin(async_test_db, async_test_user):
"""Create a test organization with async_test_user as an admin."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
org = Organization(
name="Admin Org",
slug="admin-org",
description="Test organization where user is an admin"
description="Test organization where user is an admin",
)
session.add(org)
await session.commit()
@@ -99,7 +98,7 @@ async def test_org_with_user_admin(async_test_db, async_test_user):
user_id=async_test_user.id,
organization_id=org.id,
role=OrganizationRole.ADMIN,
is_active=True
is_active=True,
)
session.add(membership)
await session.commit()
@@ -110,12 +109,12 @@ async def test_org_with_user_admin(async_test_db, async_test_user):
@pytest_asyncio.fixture
async def test_org_with_user_owner(async_test_db, async_test_user):
"""Create a test organization with async_test_user as owner."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
org = Organization(
name="Owner Org",
slug="owner-org",
description="Test organization where user is owner"
description="Test organization where user is owner",
)
session.add(org)
await session.commit()
@@ -126,7 +125,7 @@ async def test_org_with_user_owner(async_test_db, async_test_user):
user_id=async_test_user.id,
organization_id=org.id,
role=OrganizationRole.OWNER,
is_active=True
is_active=True,
)
session.add(membership)
await session.commit()
@@ -136,21 +135,18 @@ async def test_org_with_user_owner(async_test_db, async_test_user):
# ===== GET /api/v1/organizations/me =====
class TestGetMyOrganizations:
"""Tests for GET /api/v1/organizations/me endpoint."""
@pytest.mark.asyncio
async def test_get_my_organizations_success(
self,
client,
user_token,
test_org_with_user_member,
test_org_with_user_admin
self, client, user_token, test_org_with_user_member, test_org_with_user_admin
):
"""Test successfully getting user's organizations (covers lines 54-79)."""
response = await client.get(
"/api/v1/organizations/me",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -167,21 +163,15 @@ class TestGetMyOrganizations:
@pytest.mark.asyncio
async def test_get_my_organizations_filter_active(
self,
client,
async_test_db,
async_test_user,
user_token
self, client, async_test_db, async_test_user, user_token
):
"""Test filtering organizations by active status."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create active org
async with AsyncTestingSessionLocal() as session:
active_org = Organization(
name="Active Org",
slug="active-org-filter",
is_active=True
name="Active Org", slug="active-org-filter", is_active=True
)
session.add(active_org)
await session.commit()
@@ -192,14 +182,14 @@ class TestGetMyOrganizations:
user_id=async_test_user.id,
organization_id=active_org.id,
role=OrganizationRole.MEMBER,
is_active=True
is_active=True,
)
session.add(membership)
await session.commit()
response = await client.get(
"/api/v1/organizations/me?is_active=true",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -209,7 +199,7 @@ class TestGetMyOrganizations:
@pytest.mark.asyncio
async def test_get_my_organizations_empty(self, client, async_test_db):
"""Test getting organizations when user has none."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create user with no org memberships
async with AsyncTestingSessionLocal() as session:
@@ -219,7 +209,7 @@ class TestGetMyOrganizations:
password_hash=get_password_hash("TestPassword123!"),
first_name="No",
last_name="Org",
is_active=True
is_active=True,
)
session.add(user)
await session.commit()
@@ -227,13 +217,12 @@ class TestGetMyOrganizations:
# Login to get token
login_response = await client.post(
"/api/v1/auth/login",
json={"email": "noorg@example.com", "password": "TestPassword123!"}
json={"email": "noorg@example.com", "password": "TestPassword123!"},
)
token = login_response.json()["access_token"]
response = await client.get(
"/api/v1/organizations/me",
headers={"Authorization": f"Bearer {token}"}
"/api/v1/organizations/me", headers={"Authorization": f"Bearer {token}"}
)
assert response.status_code == status.HTTP_200_OK
@@ -243,20 +232,18 @@ class TestGetMyOrganizations:
# ===== GET /api/v1/organizations/{organization_id} =====
class TestGetOrganization:
"""Tests for GET /api/v1/organizations/{organization_id} endpoint."""
@pytest.mark.asyncio
async def test_get_organization_success(
self,
client,
user_token,
test_org_with_user_member
self, client, user_token, test_org_with_user_member
):
"""Test successfully getting organization details (covers lines 103-122)."""
response = await client.get(
f"/api/v1/organizations/{test_org_with_user_member.id}",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -272,7 +259,7 @@ class TestGetOrganization:
fake_org_id = uuid4()
response = await client.get(
f"/api/v1/organizations/{fake_org_id}",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
# Permission dependency checks membership before endpoint logic
@@ -283,20 +270,14 @@ class TestGetOrganization:
@pytest.mark.asyncio
async def test_get_organization_not_member(
self,
client,
async_test_db,
async_test_user
self, client, async_test_db, async_test_user
):
"""Test getting organization where user is not a member fails."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create org without adding user
async with AsyncTestingSessionLocal() as session:
org = Organization(
name="Not Member Org",
slug="not-member-org"
)
org = Organization(name="Not Member Org", slug="not-member-org")
session.add(org)
await session.commit()
await session.refresh(org)
@@ -305,13 +286,13 @@ class TestGetOrganization:
# Login as user
login_response = await client.post(
"/api/v1/auth/login",
json={"email": "testuser@example.com", "password": "TestPassword123!"}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
token = login_response.json()["access_token"]
response = await client.get(
f"/api/v1/organizations/{org_id}",
headers={"Authorization": f"Bearer {token}"}
headers={"Authorization": f"Bearer {token}"},
)
# Should fail permission check
@@ -320,6 +301,7 @@ class TestGetOrganization:
# ===== GET /api/v1/organizations/{organization_id}/members =====
class TestGetOrganizationMembers:
"""Tests for GET /api/v1/organizations/{organization_id}/members endpoint."""
@@ -331,10 +313,10 @@ class TestGetOrganizationMembers:
async_test_user,
second_user,
user_token,
test_org_with_user_member
test_org_with_user_member,
):
"""Test successfully getting organization members (covers lines 150-168)."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Add second user to org
async with AsyncTestingSessionLocal() as session:
@@ -342,14 +324,14 @@ class TestGetOrganizationMembers:
user_id=second_user.id,
organization_id=test_org_with_user_member.id,
role=OrganizationRole.MEMBER,
is_active=True
is_active=True,
)
session.add(membership)
await session.commit()
response = await client.get(
f"/api/v1/organizations/{test_org_with_user_member.id}/members",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -360,15 +342,12 @@ class TestGetOrganizationMembers:
@pytest.mark.asyncio
async def test_get_organization_members_with_pagination(
self,
client,
user_token,
test_org_with_user_member
self, client, user_token, test_org_with_user_member
):
"""Test pagination parameters."""
response = await client.get(
f"/api/v1/organizations/{test_org_with_user_member.id}/members?page=1&limit=10",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -385,10 +364,10 @@ class TestGetOrganizationMembers:
async_test_user,
second_user,
user_token,
test_org_with_user_member
test_org_with_user_member,
):
"""Test filtering members by active status."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Add second user as inactive member
async with AsyncTestingSessionLocal() as session:
@@ -396,7 +375,7 @@ class TestGetOrganizationMembers:
user_id=second_user.id,
organization_id=test_org_with_user_member.id,
role=OrganizationRole.MEMBER,
is_active=False
is_active=False,
)
session.add(membership)
await session.commit()
@@ -404,7 +383,7 @@ class TestGetOrganizationMembers:
# Filter for active only
response = await client.get(
f"/api/v1/organizations/{test_org_with_user_member.id}/members?is_active=true",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -416,31 +395,26 @@ class TestGetOrganizationMembers:
# ===== PUT /api/v1/organizations/{organization_id} =====
class TestUpdateOrganization:
"""Tests for PUT /api/v1/organizations/{organization_id} endpoint."""
@pytest.mark.asyncio
async def test_update_organization_as_admin_success(
self,
client,
async_test_user,
test_org_with_user_admin
self, client, async_test_user, test_org_with_user_admin
):
"""Test successfully updating organization as admin (covers lines 193-215)."""
# Login as admin user
login_response = await client.post(
"/api/v1/auth/login",
json={"email": "testuser@example.com", "password": "TestPassword123!"}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
admin_token = login_response.json()["access_token"]
response = await client.put(
f"/api/v1/organizations/{test_org_with_user_admin.id}",
json={
"name": "Updated Admin Org",
"description": "Updated description"
},
headers={"Authorization": f"Bearer {admin_token}"}
json={"name": "Updated Admin Org", "description": "Updated description"},
headers={"Authorization": f"Bearer {admin_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -450,23 +424,20 @@ class TestUpdateOrganization:
@pytest.mark.asyncio
async def test_update_organization_as_owner_success(
self,
client,
async_test_user,
test_org_with_user_owner
self, client, async_test_user, test_org_with_user_owner
):
"""Test successfully updating organization as owner."""
# Login as owner user
login_response = await client.post(
"/api/v1/auth/login",
json={"email": "testuser@example.com", "password": "TestPassword123!"}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
owner_token = login_response.json()["access_token"]
response = await client.put(
f"/api/v1/organizations/{test_org_with_user_owner.id}",
json={"name": "Updated Owner Org"},
headers={"Authorization": f"Bearer {owner_token}"}
headers={"Authorization": f"Bearer {owner_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -475,16 +446,13 @@ class TestUpdateOrganization:
@pytest.mark.asyncio
async def test_update_organization_as_member_fails(
self,
client,
user_token,
test_org_with_user_member
self, client, user_token, test_org_with_user_member
):
"""Test updating organization as regular member fails."""
response = await client.put(
f"/api/v1/organizations/{test_org_with_user_member.id}",
json={"name": "Should Fail"},
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
# Should fail permission check (need admin or owner)
@@ -492,15 +460,13 @@ class TestUpdateOrganization:
@pytest.mark.asyncio
async def test_update_organization_not_found(
self,
client,
test_org_with_user_admin
self, client, test_org_with_user_admin
):
"""Test updating nonexistent organization returns 403 (permission check first)."""
# Login as admin
login_response = await client.post(
"/api/v1/auth/login",
json={"email": "testuser@example.com", "password": "TestPassword123!"}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
admin_token = login_response.json()["access_token"]
@@ -508,7 +474,7 @@ class TestUpdateOrganization:
response = await client.put(
f"/api/v1/organizations/{fake_org_id}",
json={"name": "Updated"},
headers={"Authorization": f"Bearer {admin_token}"}
headers={"Authorization": f"Bearer {admin_token}"},
)
# Permission dependency checks admin role before endpoint logic
@@ -520,6 +486,7 @@ class TestUpdateOrganization:
# ===== Authentication Tests =====
class TestOrganizationAuthentication:
"""Test authentication requirements for organization endpoints."""
@@ -548,14 +515,14 @@ class TestOrganizationAuthentication:
"""Test unauthenticated access to update fails."""
fake_id = uuid4()
response = await client.put(
f"/api/v1/organizations/{fake_id}",
json={"name": "Test"}
f"/api/v1/organizations/{fake_id}", json={"name": "Test"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
# ===== Exception Handler Tests (Database Error Scenarios) =====
class TestOrganizationExceptionHandlers:
"""
Test exception handlers in organization endpoints.
@@ -566,86 +533,74 @@ class TestOrganizationExceptionHandlers:
@pytest.mark.asyncio
async def test_get_my_organizations_database_error(
self,
client,
user_token,
test_org_with_user_member
self, client, user_token, test_org_with_user_member
):
"""Test generic exception handler in get_my_organizations (covers lines 81-83)."""
with patch(
"app.crud.organization.organization.get_user_organizations_with_details",
side_effect=Exception("Database connection lost")
side_effect=Exception("Database connection lost"),
):
# The exception handler logs and re-raises, so we expect the exception
# to propagate (which proves the handler executed)
with pytest.raises(Exception, match="Database connection lost"):
await client.get(
"/api/v1/organizations/me",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
@pytest.mark.asyncio
async def test_get_organization_database_error(
self,
client,
user_token,
test_org_with_user_member
self, client, user_token, test_org_with_user_member
):
"""Test generic exception handler in get_organization (covers lines 124-128)."""
with patch(
"app.crud.organization.organization.get",
side_effect=Exception("Database timeout")
side_effect=Exception("Database timeout"),
):
with pytest.raises(Exception, match="Database timeout"):
await client.get(
f"/api/v1/organizations/{test_org_with_user_member.id}",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
@pytest.mark.asyncio
async def test_get_organization_members_database_error(
self,
client,
user_token,
test_org_with_user_member
self, client, user_token, test_org_with_user_member
):
"""Test generic exception handler in get_organization_members (covers lines 170-172)."""
with patch(
"app.crud.organization.organization.get_organization_members",
side_effect=Exception("Connection pool exhausted")
side_effect=Exception("Connection pool exhausted"),
):
with pytest.raises(Exception, match="Connection pool exhausted"):
await client.get(
f"/api/v1/organizations/{test_org_with_user_member.id}/members",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
@pytest.mark.asyncio
async def test_update_organization_database_error(
self,
client,
async_test_user,
test_org_with_user_admin
self, client, async_test_user, test_org_with_user_admin
):
"""Test generic exception handler in update_organization (covers lines 217-221)."""
# Login as admin user
login_response = await client.post(
"/api/v1/auth/login",
json={"email": "testuser@example.com", "password": "TestPassword123!"}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
admin_token = login_response.json()["access_token"]
with patch(
"app.crud.organization.organization.get",
return_value=test_org_with_user_admin
return_value=test_org_with_user_admin,
):
with patch(
"app.crud.organization.organization.update",
side_effect=Exception("Write lock timeout")
side_effect=Exception("Write lock timeout"),
):
with pytest.raises(Exception, match="Write lock timeout"):
await client.put(
f"/api/v1/organizations/{test_org_with_user_admin.id}",
json={"name": "Should Fail"},
headers={"Authorization": f"Bearer {admin_token}"}
headers={"Authorization": f"Bearer {admin_token}"},
)

View File

@@ -5,15 +5,17 @@ Tests for permission dependencies - CRITICAL SECURITY PATHS.
These tests ensure superusers can bypass organization checks correctly,
and that regular users are properly blocked.
"""
from uuid import uuid4
import pytest
import pytest_asyncio
from fastapi import status
from uuid import uuid4
from app.core.auth import get_password_hash
from app.models.organization import Organization
from app.models.user import User
from app.models.user_organization import UserOrganization, OrganizationRole
from app.core.auth import get_password_hash
from app.models.user_organization import OrganizationRole, UserOrganization
@pytest_asyncio.fixture
@@ -21,10 +23,7 @@ async def superuser_token(client, async_test_superuser):
"""Get access token for superuser."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": "superuser@example.com",
"password": "SuperPassword123!"
}
json={"email": "superuser@example.com", "password": "SuperPassword123!"},
)
assert response.status_code == 200
return response.json()["access_token"]
@@ -35,10 +34,7 @@ async def regular_user_token(client, async_test_user):
"""Get access token for regular user."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": "testuser@example.com",
"password": "TestPassword123!"
}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
assert response.status_code == 200
return response.json()["access_token"]
@@ -47,12 +43,12 @@ async def regular_user_token(client, async_test_user):
@pytest_asyncio.fixture
async def test_org_no_members(async_test_db):
"""Create a test organization with NO members."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
org = Organization(
name="No Members Org",
slug="no-members-org",
description="Test org with no members"
description="Test org with no members",
)
session.add(org)
await session.commit()
@@ -63,12 +59,12 @@ async def test_org_no_members(async_test_db):
@pytest_asyncio.fixture
async def test_org_with_member(async_test_db, async_test_user):
"""Create a test organization with async_test_user as member (not admin)."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
org = Organization(
name="Member Only Org",
slug="member-only-org",
description="Test org where user is just a member"
description="Test org where user is just a member",
)
session.add(org)
await session.commit()
@@ -79,7 +75,7 @@ async def test_org_with_member(async_test_db, async_test_user):
user_id=async_test_user.id,
organization_id=org.id,
role=OrganizationRole.MEMBER,
is_active=True
is_active=True,
)
session.add(membership)
await session.commit()
@@ -89,6 +85,7 @@ async def test_org_with_member(async_test_db, async_test_user):
# ===== CRITICAL SECURITY TESTS: Superuser Bypass =====
class TestSuperuserBypass:
"""
CRITICAL: Test that superusers can bypass organization checks.
@@ -99,10 +96,7 @@ class TestSuperuserBypass:
@pytest.mark.asyncio
async def test_superuser_can_access_org_not_member_of(
self,
client,
superuser_token,
test_org_no_members
self, client, superuser_token, test_org_no_members
):
"""
CRITICAL: Superuser should bypass membership check (covers line 175).
@@ -111,7 +105,7 @@ class TestSuperuserBypass:
"""
response = await client.get(
f"/api/v1/organizations/{test_org_no_members.id}",
headers={"Authorization": f"Bearer {superuser_token}"}
headers={"Authorization": f"Bearer {superuser_token}"},
)
# Superuser should succeed even though they're not a member
@@ -121,15 +115,12 @@ class TestSuperuserBypass:
@pytest.mark.asyncio
async def test_regular_user_cannot_access_org_not_member_of(
self,
client,
regular_user_token,
test_org_no_members
self, client, regular_user_token, test_org_no_members
):
"""Regular user should be blocked from org they're not a member of."""
response = await client.get(
f"/api/v1/organizations/{test_org_no_members.id}",
headers={"Authorization": f"Bearer {regular_user_token}"}
headers={"Authorization": f"Bearer {regular_user_token}"},
)
# Regular user should fail permission check
@@ -137,10 +128,7 @@ class TestSuperuserBypass:
@pytest.mark.asyncio
async def test_superuser_can_update_org_not_admin_of(
self,
client,
superuser_token,
test_org_no_members
self, client, superuser_token, test_org_no_members
):
"""
CRITICAL: Superuser should bypass admin check (covers line 99).
@@ -150,7 +138,7 @@ class TestSuperuserBypass:
response = await client.put(
f"/api/v1/organizations/{test_org_no_members.id}",
json={"name": "Updated by Superuser"},
headers={"Authorization": f"Bearer {superuser_token}"}
headers={"Authorization": f"Bearer {superuser_token}"},
)
# Superuser should succeed in updating org
@@ -160,16 +148,13 @@ class TestSuperuserBypass:
@pytest.mark.asyncio
async def test_regular_member_cannot_update_org(
self,
client,
regular_user_token,
test_org_with_member
self, client, regular_user_token, test_org_with_member
):
"""Regular member (not admin) should NOT be able to update org."""
response = await client.put(
f"/api/v1/organizations/{test_org_with_member.id}",
json={"name": "Should Fail"},
headers={"Authorization": f"Bearer {regular_user_token}"}
headers={"Authorization": f"Bearer {regular_user_token}"},
)
# Member should fail - need admin or owner role
@@ -177,15 +162,12 @@ class TestSuperuserBypass:
@pytest.mark.asyncio
async def test_superuser_can_list_org_members_not_member_of(
self,
client,
superuser_token,
test_org_no_members
self, client, superuser_token, test_org_no_members
):
"""CRITICAL: Superuser should bypass membership check to list members."""
response = await client.get(
f"/api/v1/organizations/{test_org_no_members.id}/members",
headers={"Authorization": f"Bearer {superuser_token}"}
headers={"Authorization": f"Bearer {superuser_token}"},
)
# Superuser should succeed
@@ -197,13 +179,14 @@ class TestSuperuserBypass:
# ===== Edge Cases and Security Tests =====
class TestPermissionEdgeCases:
"""Test edge cases in permission system."""
@pytest.mark.asyncio
async def test_inactive_user_blocked(self, client, async_test_db):
"""Test that inactive users are blocked."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create inactive user
async with AsyncTestingSessionLocal() as session:
@@ -213,7 +196,7 @@ class TestPermissionEdgeCases:
password_hash=get_password_hash("TestPassword123!"),
first_name="Inactive",
last_name="User",
is_active=False # INACTIVE
is_active=False, # INACTIVE
)
session.add(user)
await session.commit()
@@ -222,7 +205,7 @@ class TestPermissionEdgeCases:
# But accessing protected endpoints should fail
login_response = await client.post(
"/api/v1/auth/login",
json={"email": "inactive@example.com", "password": "TestPassword123!"}
json={"email": "inactive@example.com", "password": "TestPassword123!"},
)
# Login might fail for inactive users depending on auth implementation
@@ -231,18 +214,18 @@ class TestPermissionEdgeCases:
# Try to access protected endpoint
response = await client.get(
"/api/v1/users/me",
headers={"Authorization": f"Bearer {token}"}
"/api/v1/users/me", headers={"Authorization": f"Bearer {token}"}
)
# Should be blocked
assert response.status_code in [status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN]
assert response.status_code in [
status.HTTP_401_UNAUTHORIZED,
status.HTTP_403_FORBIDDEN,
]
@pytest.mark.asyncio
async def test_nonexistent_organization_returns_403_not_404(
self,
client,
regular_user_token
self, client, regular_user_token
):
"""
Test that accessing nonexistent org returns 403, not 404.
@@ -254,7 +237,7 @@ class TestPermissionEdgeCases:
fake_org_id = uuid4()
response = await client.get(
f"/api/v1/organizations/{fake_org_id}",
headers={"Authorization": f"Bearer {regular_user_token}"}
headers={"Authorization": f"Bearer {regular_user_token}"},
)
# Should get 403 (not a member), not 404 (doesn't exist)
@@ -264,18 +247,16 @@ class TestPermissionEdgeCases:
# ===== Admin Role Tests =====
class TestAdminRolePermissions:
"""Test admin role can perform admin actions."""
@pytest_asyncio.fixture
async def test_org_with_admin(self, async_test_db, async_test_user):
"""Create org where user is ADMIN."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
org = Organization(
name="Admin Org",
slug="admin-org"
)
org = Organization(name="Admin Org", slug="admin-org")
session.add(org)
await session.commit()
await session.refresh(org)
@@ -284,7 +265,7 @@ class TestAdminRolePermissions:
user_id=async_test_user.id,
organization_id=org.id,
role=OrganizationRole.ADMIN,
is_active=True
is_active=True,
)
session.add(membership)
await session.commit()
@@ -293,16 +274,13 @@ class TestAdminRolePermissions:
@pytest.mark.asyncio
async def test_admin_can_update_org(
self,
client,
regular_user_token,
test_org_with_admin
self, client, regular_user_token, test_org_with_admin
):
"""Admin should be able to update organization."""
response = await client.put(
f"/api/v1/organizations/{test_org_with_admin.id}",
json={"name": "Updated by Admin"},
headers={"Authorization": f"Bearer {regular_user_token}"}
headers={"Authorization": f"Bearer {regular_user_token}"},
)
assert response.status_code == status.HTTP_200_OK

View File

@@ -7,13 +7,13 @@ Critical security tests covering:
These tests prevent unauthorized access and privilege escalation.
"""
import pytest
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.user import User
from app.models.organization import Organization
from app.crud.user import user as user_crud
from app.models.organization import Organization
from app.models.user import User
class TestInactiveUserBlocking:
@@ -29,11 +29,7 @@ class TestInactiveUserBlocking:
@pytest.mark.asyncio
async def test_inactive_user_cannot_access_protected_endpoints(
self,
client: AsyncClient,
async_test_db,
async_test_user: User,
user_token: str
self, client: AsyncClient, async_test_db, async_test_user: User, user_token: str
):
"""
Test that inactive users are blocked from protected endpoints.
@@ -44,12 +40,11 @@ class TestInactiveUserBlocking:
3. User tries to access protected endpoint with valid token
4. System MUST reject (account inactive)
"""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Step 1: Verify user can access endpoint while active
response = await client.get(
"/api/v1/users/me",
headers={"Authorization": f"Bearer {user_token}"}
"/api/v1/users/me", headers={"Authorization": f"Bearer {user_token}"}
)
assert response.status_code == 200, "Active user should have access"
@@ -61,8 +56,7 @@ class TestInactiveUserBlocking:
# Step 3: User tries to access endpoint with same token
response = await client.get(
"/api/v1/users/me",
headers={"Authorization": f"Bearer {user_token}"}
"/api/v1/users/me", headers={"Authorization": f"Bearer {user_token}"}
)
# Step 4: System MUST reject (covers lines 52-57)
@@ -75,18 +69,14 @@ class TestInactiveUserBlocking:
@pytest.mark.asyncio
async def test_inactive_user_blocked_from_organization_endpoints(
self,
client: AsyncClient,
async_test_db,
async_test_user: User,
user_token: str
self, client: AsyncClient, async_test_db, async_test_user: User, user_token: str
):
"""
Test that inactive users can't access organization endpoints.
Ensures the inactive check applies to ALL protected endpoints.
"""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Deactivate user
async with SessionLocal() as session:
@@ -97,7 +87,7 @@ class TestInactiveUserBlocking:
# Try to list organizations
response = await client.get(
"/api/v1/organizations/me",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
# Must be blocked
@@ -122,7 +112,7 @@ class TestSuperuserPrivilegeEscalation:
client: AsyncClient,
async_test_db,
async_test_superuser: User,
superuser_token: str
superuser_token: str,
):
"""
Test that superusers automatically get OWNER role in organizations.
@@ -131,14 +121,11 @@ class TestSuperuserPrivilegeEscalation:
Superusers can manage any organization without being explicitly added.
This is for platform administration.
"""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Step 1: Create an organization (owned by someone else)
async with SessionLocal() as session:
org = Organization(
name="Test Organization",
slug="test-org"
)
org = Organization(name="Test Organization", slug="test-org")
session.add(org)
await session.commit()
await session.refresh(org)
@@ -148,7 +135,7 @@ class TestSuperuserPrivilegeEscalation:
# (They're not a member, but should auto-get OWNER role)
response = await client.get(
f"/api/v1/organizations/{org_id}",
headers={"Authorization": f"Bearer {superuser_token}"}
headers={"Authorization": f"Bearer {superuser_token}"},
)
# Step 3: Should have access (covers lines 154-157)
@@ -161,21 +148,18 @@ class TestSuperuserPrivilegeEscalation:
client: AsyncClient,
async_test_db,
async_test_superuser: User,
superuser_token: str
superuser_token: str,
):
"""
Test that superusers have full management access to all organizations.
Ensures the OWNER role privilege escalation works end-to-end.
"""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create an organization
async with SessionLocal() as session:
org = Organization(
name="Test Organization",
slug="test-org"
)
org = Organization(name="Test Organization", slug="test-org")
session.add(org)
await session.commit()
await session.refresh(org)
@@ -185,34 +169,29 @@ class TestSuperuserPrivilegeEscalation:
response = await client.put(
f"/api/v1/organizations/{org_id}",
headers={"Authorization": f"Bearer {superuser_token}"},
json={"name": "Updated Name"}
json={"name": "Updated Name"},
)
# Should succeed (superuser has OWNER privileges)
assert response.status_code in [200, 404], "Superuser should be able to manage any org"
assert response.status_code in [200, 404], (
"Superuser should be able to manage any org"
)
# Note: Might be 404 if org endpoints require membership, but the role check passes
@pytest.mark.asyncio
async def test_regular_user_does_not_get_owner_role(
self,
client: AsyncClient,
async_test_db,
async_test_user: User,
user_token: str
self, client: AsyncClient, async_test_db, async_test_user: User, user_token: str
):
"""
Sanity check: Regular users don't get automatic OWNER role.
Ensures the superuser check is working correctly (line 154).
"""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create an organization
async with SessionLocal() as session:
org = Organization(
name="Test Organization",
slug="test-org"
)
org = Organization(name="Test Organization", slug="test-org")
session.add(org)
await session.commit()
await session.refresh(org)
@@ -221,8 +200,10 @@ class TestSuperuserPrivilegeEscalation:
# Regular user tries to access it (not a member)
response = await client.get(
f"/api/v1/organizations/{org_id}",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
# Should be denied (not a member, not a superuser)
assert response.status_code in [403, 404], "Regular user shouldn't access non-member org"
assert response.status_code in [403, 404], (
"Regular user shouldn't access non-member org"
)

View File

@@ -1,7 +1,8 @@
# tests/api/test_security_headers.py
from unittest.mock import patch
import pytest
from fastapi.testclient import TestClient
from unittest.mock import patch
from app.main import app
@@ -11,8 +12,10 @@ def client():
"""Create a FastAPI test client for the main app (module-scoped for speed)."""
# Mock get_db to avoid database connection issues
with patch("app.core.database.get_db") as mock_get_db:
async def mock_session_generator():
from unittest.mock import MagicMock, AsyncMock
from unittest.mock import AsyncMock, MagicMock
mock_session = MagicMock()
mock_session.execute = AsyncMock(return_value=None)
mock_session.close = AsyncMock(return_value=None)
@@ -77,8 +80,10 @@ class TestSecurityHeaders:
"""Test that HSTS header is set in production (covers line 95)"""
with patch("app.core.config.settings.ENVIRONMENT", "production"):
with patch("app.core.database.get_db") as mock_get_db:
async def mock_session_generator():
from unittest.mock import MagicMock, AsyncMock
from unittest.mock import AsyncMock, MagicMock
mock_session = MagicMock()
mock_session.execute = AsyncMock(return_value=None)
mock_session.close = AsyncMock(return_value=None)
@@ -88,20 +93,26 @@ class TestSecurityHeaders:
# Need to reimport app to pick up the new settings
from importlib import reload
import app.main
reload(app.main)
test_client = TestClient(app.main.app)
response = test_client.get("/health")
assert "Strict-Transport-Security" in response.headers
assert "max-age=31536000" in response.headers["Strict-Transport-Security"]
assert (
"max-age=31536000" in response.headers["Strict-Transport-Security"]
)
def test_csp_strict_mode(self):
"""Test CSP strict mode (covers line 121)"""
with patch("app.core.config.settings.CSP_MODE", "strict"):
with patch("app.core.database.get_db") as mock_get_db:
async def mock_session_generator():
from unittest.mock import MagicMock, AsyncMock
from unittest.mock import AsyncMock, MagicMock
mock_session = MagicMock()
mock_session.execute = AsyncMock(return_value=None)
mock_session.close = AsyncMock(return_value=None)
@@ -110,7 +121,9 @@ class TestSecurityHeaders:
mock_get_db.side_effect = lambda: mock_session_generator()
from importlib import reload
import app.main
reload(app.main)
test_client = TestClient(app.main.app)
@@ -136,8 +149,10 @@ class TestRootEndpoint:
def test_root_endpoint(self):
"""Test root endpoint returns HTML (covers line 174)"""
with patch("app.core.database.get_db") as mock_get_db:
async def mock_session_generator():
from unittest.mock import MagicMock, AsyncMock
from unittest.mock import AsyncMock, MagicMock
mock_session = MagicMock()
mock_session.execute = AsyncMock(return_value=None)
mock_session.close = AsyncMock(return_value=None)

View File

@@ -2,23 +2,23 @@
"""
Comprehensive tests for session management API endpoints.
"""
from datetime import UTC, datetime, timedelta
from unittest.mock import patch
from uuid import uuid4
import pytest
import pytest_asyncio
from datetime import datetime, timedelta, timezone
from uuid import uuid4
from unittest.mock import patch
from fastapi import status
from app.models.user_session import UserSession
from app.schemas.users import UserCreate
# Disable rate limiting for tests
@pytest.fixture(autouse=True)
def disable_rate_limit():
"""Disable rate limiting for all tests in this module."""
with patch('app.api.routes.sessions.limiter.enabled', False):
with patch("app.api.routes.sessions.limiter.enabled", False):
yield
@@ -27,10 +27,7 @@ async def user_token(client, async_test_user):
"""Create and return an access token for async_test_user."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": "testuser@example.com",
"password": "TestPassword123!"
}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
assert response.status_code == 200
return response.json()["access_token"]
@@ -39,7 +36,7 @@ async def user_token(client, async_test_user):
@pytest_asyncio.fixture
async def async_test_user2(async_test_db):
"""Create a second test user."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
from app.crud.user import user as user_crud
@@ -49,7 +46,7 @@ async def async_test_user2(async_test_db):
email="testuser2@example.com",
password="TestPassword123!",
first_name="Test",
last_name="User2"
last_name="User2",
)
user = await user_crud.create(session, obj_in=user_data)
await session.commit()
@@ -61,9 +58,11 @@ class TestListMySessions:
"""Tests for GET /api/v1/sessions/me endpoint."""
@pytest.mark.asyncio
async def test_list_my_sessions_success(self, client, async_test_user, async_test_db, user_token):
async def test_list_my_sessions_success(
self, client, async_test_user, async_test_db, user_token
):
"""Test successfully listing user's active sessions."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create some sessions for the user
async with SessionLocal() as session:
@@ -75,8 +74,8 @@ class TestListMySessions:
ip_address="192.168.1.100",
user_agent="Mozilla/5.0 (iPhone)",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
# Active session 2
s2 = UserSession(
@@ -86,8 +85,8 @@ class TestListMySessions:
ip_address="192.168.1.101",
user_agent="Mozilla/5.0 (Macintosh)",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) - timedelta(hours=1)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC) - timedelta(hours=1),
)
# Inactive session (should not appear)
s3 = UserSession(
@@ -97,16 +96,15 @@ class TestListMySessions:
ip_address="192.168.1.102",
user_agent="Mozilla/5.0",
is_active=False,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) - timedelta(days=1)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC) - timedelta(days=1),
)
session.add_all([s1, s2, s3])
await session.commit()
# Make request
response = await client.get(
"/api/v1/sessions/me",
headers={"Authorization": f"Bearer {user_token}"}
"/api/v1/sessions/me", headers={"Authorization": f"Bearer {user_token}"}
)
assert response.status_code == status.HTTP_200_OK
@@ -128,11 +126,12 @@ class TestListMySessions:
assert data["sessions"][0]["is_current"] is True
@pytest.mark.asyncio
async def test_list_my_sessions_with_login_session(self, client, async_test_user, user_token):
async def test_list_my_sessions_with_login_session(
self, client, async_test_user, user_token
):
"""Test listing sessions shows the login session."""
response = await client.get(
"/api/v1/sessions/me",
headers={"Authorization": f"Bearer {user_token}"}
"/api/v1/sessions/me", headers={"Authorization": f"Bearer {user_token}"}
)
assert response.status_code == status.HTTP_200_OK
@@ -155,9 +154,11 @@ class TestRevokeSession:
"""Tests for DELETE /api/v1/sessions/{session_id} endpoint."""
@pytest.mark.asyncio
async def test_revoke_session_success(self, client, async_test_user, async_test_db, user_token):
async def test_revoke_session_success(
self, client, async_test_user, async_test_db, user_token
):
"""Test successfully revoking a session."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create a session to revoke
async with SessionLocal() as session:
@@ -168,8 +169,8 @@ class TestRevokeSession:
ip_address="192.168.1.103",
user_agent="Mozilla/5.0 (iPad)",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(user_session)
await session.commit()
@@ -179,7 +180,7 @@ class TestRevokeSession:
# Revoke the session
response = await client.delete(
f"/api/v1/sessions/{session_id}",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -191,6 +192,7 @@ class TestRevokeSession:
# Verify session is deactivated
async with SessionLocal() as session:
from app.crud.session import session as session_crud
revoked_session = await session_crud.get(session, id=str(session_id))
assert revoked_session.is_active is False
@@ -200,7 +202,7 @@ class TestRevokeSession:
fake_id = uuid4()
response = await client.delete(
f"/api/v1/sessions/{fake_id}",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_404_NOT_FOUND
@@ -222,7 +224,7 @@ class TestRevokeSession:
self, client, async_test_user, async_test_user2, async_test_db, user_token
):
"""Test that users cannot revoke other users' sessions."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create a session for user2
async with SessionLocal() as session:
@@ -233,8 +235,8 @@ class TestRevokeSession:
ip_address="192.168.1.200",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(other_user_session)
await session.commit()
@@ -244,7 +246,7 @@ class TestRevokeSession:
# Try to revoke it as user1
response = await client.delete(
f"/api/v1/sessions/{session_id}",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_403_FORBIDDEN
@@ -263,7 +265,7 @@ class TestCleanupExpiredSessions:
self, client, async_test_user, async_test_db, user_token
):
"""Test successfully cleaning up expired sessions."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create expired and active sessions using CRUD to avoid greenlet issues
from app.crud.session import session as session_crud
@@ -277,8 +279,8 @@ class TestCleanupExpiredSessions:
device_name="Expired 1",
ip_address="192.168.1.201",
user_agent="Mozilla/5.0",
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
last_used_at=datetime.now(timezone.utc) - timedelta(days=2)
expires_at=datetime.now(UTC) - timedelta(days=1),
last_used_at=datetime.now(UTC) - timedelta(days=2),
)
e1 = await session_crud.create_session(db, obj_in=e1_data)
e1.is_active = False
@@ -291,8 +293,8 @@ class TestCleanupExpiredSessions:
device_name="Expired 2",
ip_address="192.168.1.202",
user_agent="Mozilla/5.0",
expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
last_used_at=datetime.now(timezone.utc) - timedelta(hours=2)
expires_at=datetime.now(UTC) - timedelta(hours=1),
last_used_at=datetime.now(UTC) - timedelta(hours=2),
)
e2 = await session_crud.create_session(db, obj_in=e2_data)
e2.is_active = False
@@ -305,8 +307,8 @@ class TestCleanupExpiredSessions:
device_name="Active",
ip_address="192.168.1.203",
user_agent="Mozilla/5.0",
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
await session_crud.create_session(db, obj_in=a1_data)
await db.commit()
@@ -314,7 +316,7 @@ class TestCleanupExpiredSessions:
# Cleanup expired sessions
response = await client.delete(
"/api/v1/sessions/me/expired",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -329,7 +331,7 @@ class TestCleanupExpiredSessions:
self, client, async_test_user, async_test_db, user_token
):
"""Test cleanup when no sessions are expired."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create only active sessions using CRUD
from app.crud.session import session as session_crud
@@ -342,15 +344,15 @@ class TestCleanupExpiredSessions:
device_name="Active Device",
ip_address="192.168.1.210",
user_agent="Mozilla/5.0",
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
await session_crud.create_session(db, obj_in=a1_data)
await db.commit()
response = await client.delete(
"/api/v1/sessions/me/expired",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -369,13 +371,16 @@ class TestCleanupExpiredSessions:
# Additional tests for better coverage
class TestSessionsAdditionalCases:
"""Additional tests to improve sessions endpoint coverage."""
@pytest.mark.asyncio
async def test_list_sessions_pagination(self, client, async_test_user, async_test_db, user_token):
async def test_list_sessions_pagination(
self, client, async_test_user, async_test_db, user_token
):
"""Test listing sessions with pagination."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create multiple sessions
async with SessionLocal() as session:
@@ -389,15 +394,15 @@ class TestSessionsAdditionalCases:
device_name=f"Device {i}",
ip_address=f"192.168.1.{i}",
user_agent="Mozilla/5.0",
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
await session_crud.create_session(session, obj_in=session_data)
await session.commit()
response = await client.get(
"/api/v1/sessions/me?page=1&limit=3",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -410,16 +415,21 @@ class TestSessionsAdditionalCases:
"""Test revoking session with invalid UUID."""
response = await client.delete(
"/api/v1/sessions/not-a-uuid",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
# Should return 422 for invalid UUID format
assert response.status_code in [status.HTTP_422_UNPROCESSABLE_ENTITY, status.HTTP_404_NOT_FOUND]
assert response.status_code in [
status.HTTP_422_UNPROCESSABLE_ENTITY,
status.HTTP_404_NOT_FOUND,
]
@pytest.mark.asyncio
async def test_cleanup_expired_sessions_with_mixed_states(self, client, async_test_user, async_test_db, user_token):
async def test_cleanup_expired_sessions_with_mixed_states(
self, client, async_test_user, async_test_db, user_token
):
"""Test cleanup with mix of active/inactive and expired/not-expired sessions."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
from app.crud.session import session as session_crud
from app.schemas.sessions import SessionCreate
@@ -432,8 +442,8 @@ class TestSessionsAdditionalCases:
device_name="Expired Inactive",
ip_address="192.168.1.100",
user_agent="Mozilla/5.0",
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
last_used_at=datetime.now(timezone.utc) - timedelta(days=2)
expires_at=datetime.now(UTC) - timedelta(days=1),
last_used_at=datetime.now(UTC) - timedelta(days=2),
)
e1 = await session_crud.create_session(db, obj_in=e1_data)
e1.is_active = False
@@ -446,8 +456,8 @@ class TestSessionsAdditionalCases:
device_name="Expired Active",
ip_address="192.168.1.101",
user_agent="Mozilla/5.0",
expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
last_used_at=datetime.now(timezone.utc) - timedelta(hours=2)
expires_at=datetime.now(UTC) - timedelta(hours=1),
last_used_at=datetime.now(UTC) - timedelta(hours=2),
)
await session_crud.create_session(db, obj_in=e2_data)
@@ -455,7 +465,7 @@ class TestSessionsAdditionalCases:
response = await client.delete(
"/api/v1/sessions/me/expired",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -476,10 +486,12 @@ class TestSessionExceptionHandlers:
from unittest.mock import patch
# Patch decode_token to raise an exception
with patch('app.api.routes.sessions.decode_token', side_effect=Exception("Token decode error")):
with patch(
"app.api.routes.sessions.decode_token",
side_effect=Exception("Token decode error"),
):
response = await client.get(
"/api/v1/sessions/me",
headers={"Authorization": f"Bearer {user_token}"}
"/api/v1/sessions/me", headers={"Authorization": f"Bearer {user_token}"}
)
# Should still succeed (exception is caught and ignored in try/except at line 77)
@@ -489,12 +501,16 @@ class TestSessionExceptionHandlers:
async def test_list_sessions_database_error(self, client, user_token):
"""Test list_sessions handles database errors (covers lines 104-106)."""
from unittest.mock import patch
from app.crud import session as session_module
with patch.object(session_module.session, 'get_user_sessions', side_effect=Exception("Database error")):
with patch.object(
session_module.session,
"get_user_sessions",
side_effect=Exception("Database error"),
):
response = await client.get(
"/api/v1/sessions/me",
headers={"Authorization": f"Bearer {user_token}"}
"/api/v1/sessions/me", headers={"Authorization": f"Bearer {user_token}"}
)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
@@ -503,18 +519,21 @@ class TestSessionExceptionHandlers:
assert data["errors"][0]["message"] == "Failed to retrieve sessions"
@pytest.mark.asyncio
async def test_revoke_session_database_error(self, client, user_token, async_test_db, async_test_user):
async def test_revoke_session_database_error(
self, client, user_token, async_test_db, async_test_user
):
"""Test revoke_session handles database errors (covers lines 181-183)."""
from datetime import datetime, timedelta
from unittest.mock import patch
from uuid import uuid4
from app.crud import session as session_module
# First create a session to revoke
from app.crud.session import session as session_crud
from app.schemas.sessions import SessionCreate
from datetime import datetime, timedelta, timezone
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as db:
session_in = SessionCreate(
@@ -523,17 +542,21 @@ class TestSessionExceptionHandlers:
device_name="Test Device",
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
last_used_at=datetime.now(timezone.utc),
expires_at=datetime.now(timezone.utc) + timedelta(days=60)
last_used_at=datetime.now(UTC),
expires_at=datetime.now(UTC) + timedelta(days=60),
)
user_session = await session_crud.create_session(db, obj_in=session_in)
session_id = user_session.id
# Mock the deactivate method to raise an exception
with patch.object(session_module.session, 'deactivate', side_effect=Exception("Database connection lost")):
with patch.object(
session_module.session,
"deactivate",
side_effect=Exception("Database connection lost"),
):
response = await client.delete(
f"/api/v1/sessions/{session_id}",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
@@ -544,12 +567,17 @@ class TestSessionExceptionHandlers:
async def test_cleanup_expired_sessions_database_error(self, client, user_token):
"""Test cleanup_expired_sessions handles database errors (covers lines 233-236)."""
from unittest.mock import patch
from app.crud import session as session_module
with patch.object(session_module.session, 'cleanup_expired_for_user', side_effect=Exception("Cleanup failed")):
with patch.object(
session_module.session,
"cleanup_expired_for_user",
side_effect=Exception("Cleanup failed"),
):
response = await client.delete(
"/api/v1/sessions/me/expired",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR

View File

@@ -3,32 +3,29 @@
Comprehensive tests for user management endpoints.
These tests focus on finding potential bugs, not just coverage.
"""
import pytest
import pytest_asyncio
from unittest.mock import patch
from fastapi import status
import uuid
from sqlalchemy import select
import uuid
from unittest.mock import patch
import pytest
from fastapi import status
from app.models.user import User
from app.models.user import User
from app.schemas.users import UserUpdate
# Disable rate limiting for tests
@pytest.fixture(autouse=True)
def disable_rate_limit():
"""Disable rate limiting for all tests in this module."""
with patch('app.api.routes.users.limiter.enabled', False):
with patch('app.api.routes.auth.limiter.enabled', False):
with patch("app.api.routes.users.limiter.enabled", False):
with patch("app.api.routes.auth.limiter.enabled", False):
yield
async def get_auth_headers(client, email, password):
"""Helper to get authentication headers."""
response = await client.post(
"/api/v1/auth/login",
json={"email": email, "password": password}
"/api/v1/auth/login", json={"email": email, "password": password}
)
token = response.json()["access_token"]
return {"Authorization": f"Bearer {token}"}
@@ -40,7 +37,9 @@ class TestListUsers:
@pytest.mark.asyncio
async def test_list_users_as_superuser(self, client, async_test_superuser):
"""Test listing users as superuser."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
response = await client.get("/api/v1/users", headers=headers)
@@ -53,16 +52,20 @@ class TestListUsers:
@pytest.mark.asyncio
async def test_list_users_as_regular_user(self, client, async_test_user):
"""Test that regular users cannot list users."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.get("/api/v1/users", headers=headers)
assert response.status_code == status.HTTP_403_FORBIDDEN
@pytest.mark.asyncio
async def test_list_users_pagination(self, client, async_test_superuser, async_test_db):
async def test_list_users_pagination(
self, client, async_test_superuser, async_test_db
):
"""Test pagination works correctly."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users
async with AsyncTestingSessionLocal() as session:
@@ -72,12 +75,14 @@ class TestListUsers:
password_hash="hash",
first_name=f"PagUser{i}",
is_active=True,
is_superuser=False
is_superuser=False,
)
session.add(user)
await session.commit()
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
# Get first page
response = await client.get("/api/v1/users?page=1&limit=5", headers=headers)
@@ -88,9 +93,11 @@ class TestListUsers:
assert data["pagination"]["total"] >= 15
@pytest.mark.asyncio
async def test_list_users_filter_active(self, client, async_test_superuser, async_test_db):
async def test_list_users_filter_active(
self, client, async_test_superuser, async_test_db
):
"""Test filtering by active status."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create active and inactive users
async with AsyncTestingSessionLocal() as session:
@@ -99,19 +106,21 @@ class TestListUsers:
password_hash="hash",
first_name="Active",
is_active=True,
is_superuser=False
is_superuser=False,
)
inactive_user = User(
email="inactivefilter@example.com",
password_hash="hash",
first_name="Inactive",
is_active=False,
is_superuser=False
is_superuser=False,
)
session.add_all([active_user, inactive_user])
await session.commit()
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
# Filter for active users
response = await client.get("/api/v1/users?is_active=true", headers=headers)
@@ -130,9 +139,13 @@ class TestListUsers:
@pytest.mark.asyncio
async def test_list_users_sort_by_email(self, client, async_test_superuser):
"""Test sorting users by email."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
response = await client.get("/api/v1/users?sort_by=email&sort_order=asc", headers=headers)
response = await client.get(
"/api/v1/users?sort_by=email&sort_order=asc", headers=headers
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
emails = [u["email"] for u in data["data"]]
@@ -154,7 +167,9 @@ class TestGetCurrentUserProfile:
@pytest.mark.asyncio
async def test_get_own_profile(self, client, async_test_user):
"""Test getting own profile."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.get("/api/v1/users/me", headers=headers)
@@ -176,12 +191,14 @@ class TestUpdateCurrentUser:
@pytest.mark.asyncio
async def test_update_own_profile(self, client, async_test_user):
"""Test updating own profile."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.patch(
"/api/v1/users/me",
headers=headers,
json={"first_name": "Updated", "last_name": "Name"}
json={"first_name": "Updated", "last_name": "Name"},
)
assert response.status_code == status.HTTP_200_OK
@@ -192,12 +209,12 @@ class TestUpdateCurrentUser:
@pytest.mark.asyncio
async def test_update_profile_phone_number(self, client, async_test_user, test_db):
"""Test updating phone number with validation."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.patch(
"/api/v1/users/me",
headers=headers,
json={"phone_number": "+19876543210"}
"/api/v1/users/me", headers=headers, json={"phone_number": "+19876543210"}
)
assert response.status_code == status.HTTP_200_OK
@@ -207,12 +224,12 @@ class TestUpdateCurrentUser:
@pytest.mark.asyncio
async def test_update_profile_invalid_phone(self, client, async_test_user):
"""Test that invalid phone numbers are rejected."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.patch(
"/api/v1/users/me",
headers=headers,
json={"phone_number": "invalid"}
"/api/v1/users/me", headers=headers, json={"phone_number": "invalid"}
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@@ -220,14 +237,16 @@ class TestUpdateCurrentUser:
@pytest.mark.asyncio
async def test_cannot_elevate_to_superuser(self, client, async_test_user):
"""Test that users cannot make themselves superuser."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
# Note: is_superuser is now in UserUpdate schema with explicit validation
# This tests that Pydantic rejects the attempt at the schema level
response = await client.patch(
"/api/v1/users/me",
headers=headers,
json={"first_name": "Test", "is_superuser": True}
json={"first_name": "Test", "is_superuser": True},
)
# Pydantic validation should reject this at the schema level
@@ -242,10 +261,7 @@ class TestUpdateCurrentUser:
@pytest.mark.asyncio
async def test_update_profile_no_auth(self, client):
"""Test that unauthenticated requests are rejected."""
response = await client.patch(
"/api/v1/users/me",
json={"first_name": "Hacker"}
)
response = await client.patch("/api/v1/users/me", json={"first_name": "Hacker"})
assert response.status_code == status.HTTP_401_UNAUTHORIZED
# Note: Removed test_update_profile_unexpected_error - see comment above
@@ -257,16 +273,22 @@ class TestGetUserById:
@pytest.mark.asyncio
async def test_get_own_profile_by_id(self, client, async_test_user):
"""Test getting own profile by ID."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.get(f"/api/v1/users/{async_test_user.id}", headers=headers)
response = await client.get(
f"/api/v1/users/{async_test_user.id}", headers=headers
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["email"] == async_test_user.email
@pytest.mark.asyncio
async def test_get_other_user_as_regular_user(self, client, async_test_user, test_db):
async def test_get_other_user_as_regular_user(
self, client, async_test_user, test_db
):
"""Test that regular users cannot view other profiles."""
# Create another user
other_user = User(
@@ -274,24 +296,32 @@ class TestGetUserById:
password_hash="hash",
first_name="Other",
is_active=True,
is_superuser=False
is_superuser=False,
)
test_db.add(other_user)
test_db.commit()
test_db.refresh(other_user)
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.get(f"/api/v1/users/{other_user.id}", headers=headers)
assert response.status_code == status.HTTP_403_FORBIDDEN
@pytest.mark.asyncio
async def test_get_other_user_as_superuser(self, client, async_test_superuser, async_test_user):
async def test_get_other_user_as_superuser(
self, client, async_test_superuser, async_test_user
):
"""Test that superusers can view other profiles."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
response = await client.get(f"/api/v1/users/{async_test_user.id}", headers=headers)
response = await client.get(
f"/api/v1/users/{async_test_user.id}", headers=headers
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
@@ -300,7 +330,9 @@ class TestGetUserById:
@pytest.mark.asyncio
async def test_get_nonexistent_user(self, client, async_test_superuser):
"""Test getting non-existent user."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
fake_id = uuid.uuid4()
response = await client.get(f"/api/v1/users/{fake_id}", headers=headers)
@@ -310,7 +342,9 @@ class TestGetUserById:
@pytest.mark.asyncio
async def test_get_user_invalid_uuid(self, client, async_test_superuser):
"""Test getting user with invalid UUID format."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
response = await client.get("/api/v1/users/not-a-uuid", headers=headers)
@@ -323,12 +357,14 @@ class TestUpdateUserById:
@pytest.mark.asyncio
async def test_update_own_profile_by_id(self, client, async_test_user, test_db):
"""Test updating own profile by ID."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.patch(
f"/api/v1/users/{async_test_user.id}",
headers=headers,
json={"first_name": "SelfUpdated"}
json={"first_name": "SelfUpdated"},
)
assert response.status_code == status.HTTP_200_OK
@@ -336,7 +372,9 @@ class TestUpdateUserById:
assert data["first_name"] == "SelfUpdated"
@pytest.mark.asyncio
async def test_update_other_user_as_regular_user(self, client, async_test_user, test_db):
async def test_update_other_user_as_regular_user(
self, client, async_test_user, test_db
):
"""Test that regular users cannot update other profiles."""
# Create another user
other_user = User(
@@ -344,18 +382,20 @@ class TestUpdateUserById:
password_hash="hash",
first_name="Other",
is_active=True,
is_superuser=False
is_superuser=False,
)
test_db.add(other_user)
test_db.commit()
test_db.refresh(other_user)
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.patch(
f"/api/v1/users/{other_user.id}",
headers=headers,
json={"first_name": "Hacked"}
json={"first_name": "Hacked"},
)
assert response.status_code == status.HTTP_403_FORBIDDEN
@@ -365,14 +405,18 @@ class TestUpdateUserById:
assert other_user.first_name == "Other"
@pytest.mark.asyncio
async def test_update_other_user_as_superuser(self, client, async_test_superuser, async_test_user, test_db):
async def test_update_other_user_as_superuser(
self, client, async_test_superuser, async_test_user, test_db
):
"""Test that superusers can update other profiles."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
response = await client.patch(
f"/api/v1/users/{async_test_user.id}",
headers=headers,
json={"first_name": "AdminUpdated"}
json={"first_name": "AdminUpdated"},
)
assert response.status_code == status.HTTP_200_OK
@@ -380,16 +424,20 @@ class TestUpdateUserById:
assert data["first_name"] == "AdminUpdated"
@pytest.mark.asyncio
async def test_regular_user_cannot_modify_superuser_status(self, client, async_test_user):
async def test_regular_user_cannot_modify_superuser_status(
self, client, async_test_user
):
"""Test that regular users cannot change superuser status even if they try."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
# is_superuser not in UserUpdate schema, so it gets ignored by Pydantic
# Just verify the user stays the same
response = await client.patch(
f"/api/v1/users/{async_test_user.id}",
headers=headers,
json={"first_name": "Test"}
json={"first_name": "Test"},
)
assert response.status_code == status.HTTP_200_OK
@@ -397,14 +445,18 @@ class TestUpdateUserById:
assert data["is_superuser"] is False
@pytest.mark.asyncio
async def test_superuser_can_update_users(self, client, async_test_superuser, async_test_user, test_db):
async def test_superuser_can_update_users(
self, client, async_test_superuser, async_test_user, test_db
):
"""Test that superusers can update other users."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
response = await client.patch(
f"/api/v1/users/{async_test_user.id}",
headers=headers,
json={"first_name": "AdminChanged", "is_active": False}
json={"first_name": "AdminChanged", "is_active": False},
)
assert response.status_code == status.HTTP_200_OK
@@ -415,13 +467,13 @@ class TestUpdateUserById:
@pytest.mark.asyncio
async def test_update_nonexistent_user(self, client, async_test_superuser):
"""Test updating non-existent user."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
fake_id = uuid.uuid4()
response = await client.patch(
f"/api/v1/users/{fake_id}",
headers=headers,
json={"first_name": "Ghost"}
f"/api/v1/users/{fake_id}", headers=headers, json={"first_name": "Ghost"}
)
assert response.status_code == status.HTTP_404_NOT_FOUND
@@ -435,15 +487,17 @@ class TestChangePassword:
@pytest.mark.asyncio
async def test_change_password_success(self, client, async_test_user, test_db):
"""Test successful password change."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.patch(
"/api/v1/users/me/password",
headers=headers,
json={
"current_password": "TestPassword123!",
"new_password": "NewPassword123!"
}
"new_password": "NewPassword123!",
},
)
assert response.status_code == status.HTTP_200_OK
@@ -453,25 +507,24 @@ class TestChangePassword:
# Verify can login with new password
login_response = await client.post(
"/api/v1/auth/login",
json={
"email": async_test_user.email,
"password": "NewPassword123!"
}
json={"email": async_test_user.email, "password": "NewPassword123!"},
)
assert login_response.status_code == status.HTTP_200_OK
@pytest.mark.asyncio
async def test_change_password_wrong_current(self, client, async_test_user):
"""Test that wrong current password is rejected."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.patch(
"/api/v1/users/me/password",
headers=headers,
json={
"current_password": "WrongPassword123",
"new_password": "NewPassword123!"
}
"new_password": "NewPassword123!",
},
)
assert response.status_code == status.HTTP_403_FORBIDDEN
@@ -479,15 +532,14 @@ class TestChangePassword:
@pytest.mark.asyncio
async def test_change_password_weak_new_password(self, client, async_test_user):
"""Test that weak new passwords are rejected."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.patch(
"/api/v1/users/me/password",
headers=headers,
json={
"current_password": "TestPassword123!",
"new_password": "weak"
}
json={"current_password": "TestPassword123!", "new_password": "weak"},
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@@ -499,8 +551,8 @@ class TestChangePassword:
"/api/v1/users/me/password",
json={
"current_password": "TestPassword123!",
"new_password": "NewPassword123!"
}
"new_password": "NewPassword123!",
},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -511,9 +563,11 @@ class TestDeleteUser:
"""Tests for DELETE /users/{user_id} endpoint."""
@pytest.mark.asyncio
async def test_delete_user_as_superuser(self, client, async_test_superuser, async_test_db):
async def test_delete_user_as_superuser(
self, client, async_test_superuser, async_test_db
):
"""Test deleting a user as superuser."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create a user to delete
async with AsyncTestingSessionLocal() as session:
@@ -522,14 +576,16 @@ class TestDeleteUser:
password_hash="hash",
first_name="Delete",
is_active=True,
is_superuser=False
is_superuser=False,
)
session.add(user_to_delete)
await session.commit()
await session.refresh(user_to_delete)
user_id = user_to_delete.id
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
response = await client.delete(f"/api/v1/users/{user_id}", headers=headers)
@@ -540,6 +596,7 @@ class TestDeleteUser:
# Verify user is soft-deleted (has deleted_at timestamp)
async with AsyncTestingSessionLocal() as session:
from sqlalchemy import select
result = await session.execute(select(User).where(User.id == user_id))
deleted_user = result.scalar_one_or_none()
assert deleted_user.deleted_at is not None
@@ -547,9 +604,13 @@ class TestDeleteUser:
@pytest.mark.asyncio
async def test_cannot_delete_self(self, client, async_test_superuser):
"""Test that users cannot delete their own account."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
response = await client.delete(f"/api/v1/users/{async_test_superuser.id}", headers=headers)
response = await client.delete(
f"/api/v1/users/{async_test_superuser.id}", headers=headers
)
assert response.status_code == status.HTTP_403_FORBIDDEN
@@ -562,22 +623,28 @@ class TestDeleteUser:
password_hash="hash",
first_name="Protected",
is_active=True,
is_superuser=False
is_superuser=False,
)
test_db.add(other_user)
test_db.commit()
test_db.refresh(other_user)
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.delete(f"/api/v1/users/{other_user.id}", headers=headers)
response = await client.delete(
f"/api/v1/users/{other_user.id}", headers=headers
)
assert response.status_code == status.HTTP_403_FORBIDDEN
@pytest.mark.asyncio
async def test_delete_nonexistent_user(self, client, async_test_superuser):
"""Test deleting non-existent user."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
fake_id = uuid.uuid4()
response = await client.delete(f"/api/v1/users/{fake_id}", headers=headers)

View File

@@ -2,10 +2,12 @@
"""
Tests for user routes.
"""
from uuid import uuid4
import pytest
import pytest_asyncio
from fastapi import status
from uuid import uuid4
@pytest_asyncio.fixture
@@ -13,10 +15,7 @@ async def superuser_token(client, async_test_superuser):
"""Get access token for superuser."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": "superuser@example.com",
"password": "SuperPassword123!"
}
json={"email": "superuser@example.com", "password": "SuperPassword123!"},
)
assert response.status_code == 200
return response.json()["access_token"]
@@ -27,10 +26,7 @@ async def user_token(client, async_test_user):
"""Get access token for regular user."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": "testuser@example.com",
"password": "TestPassword123!"
}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
assert response.status_code == 200
return response.json()["access_token"]
@@ -43,8 +39,7 @@ class TestListUsers:
async def test_list_users_success(self, client, superuser_token):
"""Test listing users successfully (covers lines 87-100)."""
response = await client.get(
"/api/v1/users",
headers={"Authorization": f"Bearer {superuser_token}"}
"/api/v1/users", headers={"Authorization": f"Bearer {superuser_token}"}
)
assert response.status_code == status.HTTP_200_OK
@@ -58,7 +53,7 @@ class TestListUsers:
"""Test listing users with is_superuser filter (covers line 74)."""
response = await client.get(
"/api/v1/users?is_superuser=true",
headers={"Authorization": f"Bearer {superuser_token}"}
headers={"Authorization": f"Bearer {superuser_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -73,8 +68,7 @@ class TestGetCurrentUser:
async def test_get_current_user_success(self, client, async_test_user, user_token):
"""Test getting current user profile."""
response = await client.get(
"/api/v1/users/me",
headers={"Authorization": f"Bearer {user_token}"}
"/api/v1/users/me", headers={"Authorization": f"Bearer {user_token}"}
)
assert response.status_code == status.HTTP_200_OK
@@ -92,7 +86,7 @@ class TestUpdateCurrentUser:
response = await client.patch(
"/api/v1/users/me",
headers={"Authorization": f"Bearer {user_token}"},
json={"first_name": "UpdatedName"}
json={"first_name": "UpdatedName"},
)
assert response.status_code == status.HTTP_200_OK
@@ -104,12 +98,14 @@ class TestUpdateCurrentUser:
"""Test database error handling during update (covers lines 162-169)."""
from unittest.mock import patch
with patch('app.api.routes.users.user_crud.update', side_effect=Exception("DB error")):
with patch(
"app.api.routes.users.user_crud.update", side_effect=Exception("DB error")
):
with pytest.raises(Exception):
await client.patch(
"/api/v1/users/me",
headers={"Authorization": f"Bearer {user_token}"},
json={"first_name": "Updated"}
json={"first_name": "Updated"},
)
@pytest.mark.asyncio
@@ -118,7 +114,7 @@ class TestUpdateCurrentUser:
response = await client.patch(
"/api/v1/users/me",
headers={"Authorization": f"Bearer {user_token}"},
json={"is_superuser": True}
json={"is_superuser": True},
)
# Pydantic validation should reject this at the schema level
@@ -137,12 +133,15 @@ class TestUpdateCurrentUser:
"""Test ValueError handling during update (covers lines 165-166)."""
from unittest.mock import patch
with patch('app.api.routes.users.user_crud.update', side_effect=ValueError("Invalid value")):
with patch(
"app.api.routes.users.user_crud.update",
side_effect=ValueError("Invalid value"),
):
with pytest.raises(ValueError):
await client.patch(
"/api/v1/users/me",
headers={"Authorization": f"Bearer {user_token}"},
json={"first_name": "Updated"}
json={"first_name": "Updated"},
)
@@ -154,7 +153,7 @@ class TestGetUser:
"""Test getting user by ID."""
response = await client.get(
f"/api/v1/users/{async_test_user.id}",
headers={"Authorization": f"Bearer {superuser_token}"}
headers={"Authorization": f"Bearer {superuser_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -167,7 +166,7 @@ class TestGetUser:
fake_id = uuid4()
response = await client.get(
f"/api/v1/users/{fake_id}",
headers={"Authorization": f"Bearer {superuser_token}"}
headers={"Authorization": f"Bearer {superuser_token}"},
)
assert response.status_code == status.HTTP_404_NOT_FOUND
@@ -183,30 +182,34 @@ class TestUpdateUserById:
response = await client.patch(
f"/api/v1/users/{fake_id}",
headers={"Authorization": f"Bearer {superuser_token}"},
json={"first_name": "Updated"}
json={"first_name": "Updated"},
)
assert response.status_code == status.HTTP_404_NOT_FOUND
@pytest.mark.asyncio
async def test_update_user_by_id_non_superuser_cannot_change_superuser_status(self, client, async_test_user, user_token):
async def test_update_user_by_id_non_superuser_cannot_change_superuser_status(
self, client, async_test_user, user_token
):
"""Test non-superuser cannot modify superuser status (Pydantic validation)."""
response = await client.patch(
f"/api/v1/users/{async_test_user.id}",
headers={"Authorization": f"Bearer {user_token}"},
json={"is_superuser": True}
json={"is_superuser": True},
)
# Pydantic validation should reject this at the schema level
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@pytest.mark.asyncio
async def test_update_user_by_id_success(self, client, async_test_user, superuser_token):
async def test_update_user_by_id_success(
self, client, async_test_user, superuser_token
):
"""Test updating user successfully (covers lines 276-278)."""
response = await client.patch(
f"/api/v1/users/{async_test_user.id}",
headers={"Authorization": f"Bearer {superuser_token}"},
json={"first_name": "SuperUpdated"}
json={"first_name": "SuperUpdated"},
)
assert response.status_code == status.HTTP_200_OK
@@ -214,29 +217,37 @@ class TestUpdateUserById:
assert data["first_name"] == "SuperUpdated"
@pytest.mark.asyncio
async def test_update_user_by_id_value_error(self, client, async_test_user, superuser_token):
async def test_update_user_by_id_value_error(
self, client, async_test_user, superuser_token
):
"""Test ValueError handling (covers lines 280-281)."""
from unittest.mock import patch
with patch('app.api.routes.users.user_crud.update', side_effect=ValueError("Invalid")):
with patch(
"app.api.routes.users.user_crud.update", side_effect=ValueError("Invalid")
):
with pytest.raises(ValueError):
await client.patch(
f"/api/v1/users/{async_test_user.id}",
headers={"Authorization": f"Bearer {superuser_token}"},
json={"first_name": "Updated"}
json={"first_name": "Updated"},
)
@pytest.mark.asyncio
async def test_update_user_by_id_unexpected_error(self, client, async_test_user, superuser_token):
async def test_update_user_by_id_unexpected_error(
self, client, async_test_user, superuser_token
):
"""Test unexpected error handling (covers lines 283-284)."""
from unittest.mock import patch
with patch('app.api.routes.users.user_crud.update', side_effect=Exception("Unexpected")):
with patch(
"app.api.routes.users.user_crud.update", side_effect=Exception("Unexpected")
):
with pytest.raises(Exception):
await client.patch(
f"/api/v1/users/{async_test_user.id}",
headers={"Authorization": f"Bearer {superuser_token}"},
json={"first_name": "Updated"}
json={"first_name": "Updated"},
)
@@ -246,18 +257,18 @@ class TestChangePassword:
@pytest.mark.asyncio
async def test_change_password_success(self, client, async_test_db):
"""Test changing password successfully."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create a fresh user
async with AsyncTestingSessionLocal() as session:
from app.models.user import User
from app.core.auth import get_password_hash
from app.models.user import User
new_user = User(
email="changepass@example.com",
password_hash=get_password_hash("OldPassword123!"),
first_name="Change",
last_name="Pass"
last_name="Pass",
)
session.add(new_user)
await session.commit()
@@ -265,10 +276,7 @@ class TestChangePassword:
# Login
login_response = await client.post(
"/api/v1/auth/login",
json={
"email": "changepass@example.com",
"password": "OldPassword123!"
}
json={"email": "changepass@example.com", "password": "OldPassword123!"},
)
token = login_response.json()["access_token"]
@@ -278,8 +286,8 @@ class TestChangePassword:
headers={"Authorization": f"Bearer {token}"},
json={
"current_password": "OldPassword123!",
"new_password": "NewPassword456!"
}
"new_password": "NewPassword456!",
},
)
assert response.status_code == status.HTTP_200_OK
@@ -289,10 +297,7 @@ class TestChangePassword:
# Verify new password works
login_response = await client.post(
"/api/v1/auth/login",
json={
"email": "changepass@example.com",
"password": "NewPassword456!"
}
json={"email": "changepass@example.com", "password": "NewPassword456!"},
)
assert login_response.status_code == status.HTTP_200_OK
@@ -306,7 +311,7 @@ class TestDeleteUserById:
fake_id = uuid4()
response = await client.delete(
f"/api/v1/users/{fake_id}",
headers={"Authorization": f"Bearer {superuser_token}"}
headers={"Authorization": f"Bearer {superuser_token}"},
)
assert response.status_code == status.HTTP_404_NOT_FOUND
@@ -314,18 +319,18 @@ class TestDeleteUserById:
@pytest.mark.asyncio
async def test_delete_user_success(self, client, async_test_db, superuser_token):
"""Test deleting user successfully (covers lines 383-388)."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create a user to delete
async with AsyncTestingSessionLocal() as session:
from app.models.user import User
from app.core.auth import get_password_hash
from app.models.user import User
user_to_delete = User(
email=f"delete{uuid4().hex[:8]}@example.com",
password_hash=get_password_hash("Password123!"),
first_name="Delete",
last_name="Me"
last_name="Me",
)
session.add(user_to_delete)
await session.commit()
@@ -334,7 +339,7 @@ class TestDeleteUserById:
response = await client.delete(
f"/api/v1/users/{user_id}",
headers={"Authorization": f"Bearer {superuser_token}"}
headers={"Authorization": f"Bearer {superuser_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -342,25 +347,35 @@ class TestDeleteUserById:
assert data["success"] is True
@pytest.mark.asyncio
async def test_delete_user_value_error(self, client, async_test_user, superuser_token):
async def test_delete_user_value_error(
self, client, async_test_user, superuser_token
):
"""Test ValueError handling during delete (covers lines 390-391)."""
from unittest.mock import patch
with patch('app.api.routes.users.user_crud.soft_delete', side_effect=ValueError("Cannot delete")):
with patch(
"app.api.routes.users.user_crud.soft_delete",
side_effect=ValueError("Cannot delete"),
):
with pytest.raises(ValueError):
await client.delete(
f"/api/v1/users/{async_test_user.id}",
headers={"Authorization": f"Bearer {superuser_token}"}
headers={"Authorization": f"Bearer {superuser_token}"},
)
@pytest.mark.asyncio
async def test_delete_user_unexpected_error(self, client, async_test_user, superuser_token):
async def test_delete_user_unexpected_error(
self, client, async_test_user, superuser_token
):
"""Test unexpected error handling during delete (covers lines 393-394)."""
from unittest.mock import patch
with patch('app.api.routes.users.user_crud.soft_delete', side_effect=Exception("Unexpected")):
with patch(
"app.api.routes.users.user_crud.soft_delete",
side_effect=Exception("Unexpected"),
):
with pytest.raises(Exception):
await client.delete(
f"/api/v1/users/{async_test_user.id}",
headers={"Authorization": f"Bearer {superuser_token}"}
headers={"Authorization": f"Bearer {superuser_token}"},
)

View File

@@ -1,28 +1,32 @@
# tests/conftest.py
import os
import uuid
from datetime import datetime, timezone
import pytest
import pytest_asyncio
from httpx import AsyncClient, ASGITransport
from httpx import ASGITransport, AsyncClient
# Set IS_TEST environment variable BEFORE importing app
# This prevents the scheduler from starting during tests
os.environ["IS_TEST"] = "True"
from app.main import app
from app.core.database import get_db
from app.models.user import User
from app.core.auth import get_password_hash
from app.utils.test_utils import setup_test_db, teardown_test_db, setup_async_test_db, teardown_async_test_db
from app.core.database import get_db
from app.main import app
from app.models.user import User
from app.utils.test_utils import (
setup_async_test_db,
setup_test_db,
teardown_async_test_db,
teardown_test_db,
)
@pytest.fixture(scope="function")
def db_session():
"""
Creates a fresh SQLite in-memory database for each test function.
Yields a SQLAlchemy session that can be used for testing.
"""
# Set up the database
@@ -46,6 +50,7 @@ async def async_test_db():
yield test_engine, AsyncTestingSessionLocal
await teardown_async_test_db(test_engine)
@pytest.fixture
def user_create_data():
return {
@@ -55,7 +60,7 @@ def user_create_data():
"last_name": "User",
"phone_number": "+1234567890",
"is_superuser": False,
"preferences": None
"preferences": None,
}
@@ -102,7 +107,7 @@ async def client(async_test_db):
This overrides the get_db dependency to use the test database.
"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async def override_get_db():
async with AsyncTestingSessionLocal() as session:
@@ -176,7 +181,7 @@ async def async_test_user(async_test_db):
Password: TestPassword123
"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user = User(
id=uuid.uuid4(),
@@ -202,7 +207,7 @@ async def async_test_superuser(async_test_db):
Password: SuperPassword123
"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user = User(
id=uuid.uuid4(),
@@ -256,4 +261,4 @@ async def superuser_token(client, async_test_superuser):
)
assert response.status_code == 200, f"Login failed: {response.text}"
tokens = response.json()
return tokens["access_token"]
return tokens["access_token"]

View File

@@ -1,20 +1,20 @@
# tests/core/test_auth.py
import uuid
from datetime import UTC, datetime, timedelta
import pytest
from datetime import datetime, timedelta, timezone
from jose import jwt
from pydantic import ValidationError
from app.core.auth import (
verify_password,
get_password_hash,
TokenExpiredError,
TokenInvalidError,
TokenMissingClaimError,
create_access_token,
create_refresh_token,
decode_token,
get_password_hash,
get_token_data,
TokenExpiredError,
TokenInvalidError,
TokenMissingClaimError
verify_password,
)
from app.core.config import settings
@@ -58,15 +58,13 @@ class TestTokenCreation:
custom_claims = {
"email": "test@example.com",
"first_name": "Test",
"is_superuser": True
"is_superuser": True,
}
token = create_access_token(subject=user_id, claims=custom_claims)
# Decode token to verify claims
payload = jwt.decode(
token,
settings.SECRET_KEY,
algorithms=[settings.ALGORITHM]
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)
# Check standard claims
@@ -87,9 +85,7 @@ class TestTokenCreation:
# Decode token to verify claims
payload = jwt.decode(
token,
settings.SECRET_KEY,
algorithms=[settings.ALGORITHM]
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)
# Check standard claims
@@ -105,23 +101,18 @@ class TestTokenCreation:
expires = timedelta(minutes=5)
# Create token with specific expiration
token = create_access_token(
subject=user_id,
expires_delta=expires
)
token = create_access_token(subject=user_id, expires_delta=expires)
# Decode token
payload = jwt.decode(
token,
settings.SECRET_KEY,
algorithms=[settings.ALGORITHM]
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)
# Get actual expiration time from token
expiration = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
expiration = datetime.fromtimestamp(payload["exp"], tz=UTC)
# Calculate expected expiration (approximately)
now = datetime.now(timezone.utc)
now = datetime.now(UTC)
expected_expiration = now + expires
# Difference should be small (less than 1 second)
@@ -148,7 +139,7 @@ class TestTokenDecoding:
user_id = str(uuid.uuid4())
# Create a token that's already expired by directly manipulating the payload
now = datetime.now(timezone.utc)
now = datetime.now(UTC)
expired_time = now - timedelta(hours=1) # 1 hour in the past
# Create the expired token manually
@@ -157,13 +148,11 @@ class TestTokenDecoding:
"exp": int(expired_time.timestamp()), # Set expiration in the past
"iat": int(now.timestamp()),
"jti": str(uuid.uuid4()),
"type": "access"
"type": "access",
}
expired_token = jwt.encode(
payload,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM
)
# Attempting to decode should raise TokenExpiredError
@@ -180,20 +169,16 @@ class TestTokenDecoding:
def test_decode_token_with_missing_sub(self):
"""Test that a token without 'sub' claim raises TokenMissingClaimError"""
# Create a token without a subject
now = datetime.now(timezone.utc)
now = datetime.now(UTC)
payload = {
"exp": int((now + timedelta(minutes=30)).timestamp()),
"iat": int(now.timestamp()),
"jti": str(uuid.uuid4()),
"type": "access"
"type": "access",
# No 'sub' claim
}
token = jwt.encode(
payload,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
)
token = jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
with pytest.raises(TokenMissingClaimError):
decode_token(token)
@@ -211,20 +196,16 @@ class TestTokenDecoding:
"""Test that a token with invalid payload structure raises TokenInvalidError"""
# Create a token with an invalid payload structure - missing 'sub' which is required
# but including 'exp' to avoid the expiration check
now = datetime.now(timezone.utc)
now = datetime.now(UTC)
payload = {
# Missing "sub" field which is required
"exp": int((now + timedelta(minutes=30)).timestamp()),
"iat": int(now.timestamp()),
"jti": str(uuid.uuid4()),
"invalid_field": "test"
"invalid_field": "test",
}
token = jwt.encode(
payload,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
)
token = jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
# Should raise TokenMissingClaimError due to missing 'sub'
with pytest.raises(TokenMissingClaimError):
@@ -236,11 +217,7 @@ class TestTokenDecoding:
"exp": int((now + timedelta(minutes=30)).timestamp()),
}
token = jwt.encode(
payload,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
)
token = jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
# Should raise TokenInvalidError due to ValidationError
with pytest.raises(TokenInvalidError):
@@ -249,12 +226,9 @@ class TestTokenDecoding:
def test_get_token_data(self):
"""Test extracting TokenData from a token"""
user_id = uuid.uuid4()
token = create_access_token(
subject=str(user_id),
claims={"is_superuser": True}
)
token = create_access_token(subject=str(user_id), claims={"is_superuser": True})
token_data = get_token_data(token)
assert token_data.user_id == user_id
assert token_data.is_superuser is True
assert token_data.is_superuser is True

View File

@@ -8,11 +8,11 @@ Critical security tests covering:
These tests cover critical security vulnerabilities that could be exploited.
"""
import pytest
from jose import jwt
from datetime import datetime, timedelta, timezone
from app.core.auth import decode_token, create_access_token, TokenInvalidError
from app.core.auth import TokenInvalidError, create_access_token, decode_token
from app.core.config import settings
@@ -46,13 +46,14 @@ class TestJWTAlgorithmSecurityAttacks:
"""
# Create a payload that would normally be valid (using timestamps)
import time
now = int(time.time())
payload = {
"sub": "user123",
"exp": now + 3600, # 1 hour from now
"iat": now,
"type": "access"
"type": "access",
}
# Craft a malicious token with "alg: none"
@@ -61,13 +62,13 @@ class TestJWTAlgorithmSecurityAttacks:
import json
header = {"alg": "none", "typ": "JWT"}
header_encoded = base64.urlsafe_b64encode(
json.dumps(header).encode()
).decode().rstrip("=")
header_encoded = (
base64.urlsafe_b64encode(json.dumps(header).encode()).decode().rstrip("=")
)
payload_encoded = base64.urlsafe_b64encode(
json.dumps(payload).encode()
).decode().rstrip("=")
payload_encoded = (
base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=")
)
# Token with no signature (algorithm "none")
malicious_token = f"{header_encoded}.{payload_encoded}."
@@ -85,22 +86,17 @@ class TestJWTAlgorithmSecurityAttacks:
import time
now = int(time.time())
payload = {
"sub": "user123",
"exp": now + 3600,
"iat": now,
"type": "access"
}
payload = {"sub": "user123", "exp": now + 3600, "iat": now, "type": "access"}
# Try uppercase "NONE"
header = {"alg": "NONE", "typ": "JWT"}
header_encoded = base64.urlsafe_b64encode(
json.dumps(header).encode()
).decode().rstrip("=")
header_encoded = (
base64.urlsafe_b64encode(json.dumps(header).encode()).decode().rstrip("=")
)
payload_encoded = base64.urlsafe_b64encode(
json.dumps(payload).encode()
).decode().rstrip("=")
payload_encoded = (
base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=")
)
malicious_token = f"{header_encoded}.{payload_encoded}."
@@ -121,15 +117,11 @@ class TestJWTAlgorithmSecurityAttacks:
before our defensive checks at line 212. This is good for security!
"""
import time
now = int(time.time())
# Create a valid payload
payload = {
"sub": "user123",
"exp": now + 3600,
"iat": now,
"type": "access"
}
payload = {"sub": "user123", "exp": now + 3600, "iat": now, "type": "access"}
# Encode with wrong algorithm (RS256 instead of HS256)
# This simulates an attacker trying algorithm substitution
@@ -137,9 +129,7 @@ class TestJWTAlgorithmSecurityAttacks:
try:
malicious_token = jwt.encode(
payload,
settings.SECRET_KEY,
algorithm=wrong_algorithm
payload, settings.SECRET_KEY, algorithm=wrong_algorithm
)
# Should reject the token (library catches mismatch)
@@ -156,21 +146,15 @@ class TestJWTAlgorithmSecurityAttacks:
Prevents algorithm downgrade/upgrade attacks.
"""
import time
now = int(time.time())
payload = {
"sub": "user123",
"exp": now + 3600,
"iat": now,
"type": "access"
}
payload = {"sub": "user123", "exp": now + 3600, "iat": now, "type": "access"}
# Create token with HS384 instead of HS256
try:
malicious_token = jwt.encode(
payload,
settings.SECRET_KEY,
algorithm="HS384"
payload, settings.SECRET_KEY, algorithm="HS384"
)
with pytest.raises(TokenInvalidError):
@@ -223,20 +207,15 @@ class TestJWTSecurityEdgeCases:
# Create token without "alg" in header
header = {"typ": "JWT"} # Missing "alg"
payload = {
"sub": "user123",
"exp": now + 3600,
"iat": now,
"type": "access"
}
payload = {"sub": "user123", "exp": now + 3600, "iat": now, "type": "access"}
header_encoded = base64.urlsafe_b64encode(
json.dumps(header).encode()
).decode().rstrip("=")
header_encoded = (
base64.urlsafe_b64encode(json.dumps(header).encode()).decode().rstrip("=")
)
payload_encoded = base64.urlsafe_b64encode(
json.dumps(payload).encode()
).decode().rstrip("=")
payload_encoded = (
base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=")
)
malicious_token = f"{header_encoded}.{payload_encoded}.fake_signature"
@@ -253,15 +232,20 @@ class TestJWTSecurityEdgeCases:
"""Test token with malformed JSON in payload."""
import base64
header = {"alg": "HS256", "typ": "JWT"}
header_encoded = base64.urlsafe_b64encode(
b'{"alg":"HS256","typ":"JWT"}'
).decode().rstrip("=")
header_encoded = (
base64.urlsafe_b64encode(b'{"alg":"HS256","typ":"JWT"}')
.decode()
.rstrip("=")
)
# Invalid JSON (missing closing brace)
invalid_payload_encoded = base64.urlsafe_b64encode(
b'{"sub":"user123"' # Invalid JSON
).decode().rstrip("=")
invalid_payload_encoded = (
base64.urlsafe_b64encode(
b'{"sub":"user123"' # Invalid JSON
)
.decode()
.rstrip("=")
)
malicious_token = f"{header_encoded}.{invalid_payload_encoded}.fake_sig"

View File

@@ -1,6 +1,7 @@
# tests/core/test_config.py
import pytest
from pydantic import ValidationError
from app.core.config import Settings
@@ -22,11 +23,15 @@ class TestSecretKeyValidation:
with pytest.raises(ValidationError) as exc_info:
Settings(SECRET_KEY=default_key, ENVIRONMENT="production")
assert "must be set to a secure random value in production" in str(exc_info.value)
assert "must be set to a secure random value in production" in str(
exc_info.value
)
def test_default_secret_key_in_development_allows_with_warning(self, caplog):
"""Test that default SECRET_KEY in development is allowed but warns"""
settings = Settings(SECRET_KEY="your_secret_key_here" + "x" * 14, ENVIRONMENT="development")
settings = Settings(
SECRET_KEY="your_secret_key_here" + "x" * 14, ENVIRONMENT="development"
)
assert settings.SECRET_KEY == "your_secret_key_here" + "x" * 14
# Note: The warning happens during validation, which we've seen works
@@ -44,19 +49,13 @@ class TestSuperuserPasswordValidation:
def test_none_password_accepted(self):
"""Test that None password is accepted (optional field)"""
settings = Settings(
SECRET_KEY="a" * 32,
FIRST_SUPERUSER_PASSWORD=None
)
settings = Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD=None)
assert settings.FIRST_SUPERUSER_PASSWORD is None
def test_password_too_short_raises_error(self):
"""Test that password shorter than 12 characters raises error"""
with pytest.raises(ValidationError) as exc_info:
Settings(
SECRET_KEY="a" * 32,
FIRST_SUPERUSER_PASSWORD="Short1"
)
Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD="Short1")
assert "must be at least 12 characters" in str(exc_info.value)
@@ -64,14 +63,11 @@ class TestSuperuserPasswordValidation:
"""Test that common weak passwords are rejected"""
# Test with the exact weak passwords from the validator
# These are in the weak_passwords set and should be rejected
weak_passwords = ['123456789012'] # Exactly 12 chars, in the weak set
weak_passwords = ["123456789012"] # Exactly 12 chars, in the weak set
for weak_pwd in weak_passwords:
with pytest.raises(ValidationError) as exc_info:
Settings(
SECRET_KEY="a" * 32,
FIRST_SUPERUSER_PASSWORD=weak_pwd
)
Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD=weak_pwd)
# Should get "too weak" message
error_str = str(exc_info.value)
assert "too weak" in error_str
@@ -79,30 +75,21 @@ class TestSuperuserPasswordValidation:
def test_password_without_lowercase_rejected(self):
"""Test that password without lowercase is rejected"""
with pytest.raises(ValidationError) as exc_info:
Settings(
SECRET_KEY="a" * 32,
FIRST_SUPERUSER_PASSWORD="ALLUPPERCASE123"
)
Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD="ALLUPPERCASE123")
assert "must contain lowercase, uppercase, and digits" in str(exc_info.value)
def test_password_without_uppercase_rejected(self):
"""Test that password without uppercase is rejected"""
with pytest.raises(ValidationError) as exc_info:
Settings(
SECRET_KEY="a" * 32,
FIRST_SUPERUSER_PASSWORD="alllowercase123"
)
Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD="alllowercase123")
assert "must contain lowercase, uppercase, and digits" in str(exc_info.value)
def test_password_without_digit_rejected(self):
"""Test that password without digit is rejected"""
with pytest.raises(ValidationError) as exc_info:
Settings(
SECRET_KEY="a" * 32,
FIRST_SUPERUSER_PASSWORD="NoDigitsHere"
)
Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD="NoDigitsHere")
assert "must contain lowercase, uppercase, and digits" in str(exc_info.value)
@@ -110,8 +97,7 @@ class TestSuperuserPasswordValidation:
"""Test that strong password is accepted"""
strong_password = "StrongPassword123!"
settings = Settings(
SECRET_KEY="a" * 32,
FIRST_SUPERUSER_PASSWORD=strong_password
SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD=strong_password
)
assert settings.FIRST_SUPERUSER_PASSWORD == strong_password
@@ -150,7 +136,7 @@ class TestDatabaseConfiguration:
POSTGRES_HOST="testhost",
POSTGRES_PORT="5432",
POSTGRES_DB="testdb",
DATABASE_URL=None # Don't use explicit URL
DATABASE_URL=None, # Don't use explicit URL
)
expected_url = "postgresql://testuser:testpass@testhost:5432/testdb"
@@ -159,10 +145,7 @@ class TestDatabaseConfiguration:
def test_explicit_database_url_used_when_set(self):
"""Test that explicit DATABASE_URL is used when provided"""
explicit_url = "postgresql://explicit:pass@host:5432/db"
settings = Settings(
SECRET_KEY="a" * 32,
DATABASE_URL=explicit_url
)
settings = Settings(SECRET_KEY="a" * 32, DATABASE_URL=explicit_url)
assert settings.database_url == explicit_url

View File

@@ -6,8 +6,10 @@ Critical security tests covering:
These tests prevent security misconfigurations.
"""
import pytest
import os
import pytest
from pydantic import ValidationError
@@ -43,6 +45,7 @@ class TestSecretKeySecurityValidation:
# Import Settings class fresh (to pick up new env var)
# The ValidationError should be raised during reload when Settings() is instantiated
import importlib
from app.core import config
# Reload will raise ValidationError because Settings() is instantiated at module level
@@ -58,7 +61,9 @@ class TestSecretKeySecurityValidation:
# Reload config to restore original settings
import importlib
from app.core import config
importlib.reload(config)
def test_secret_key_exactly_32_characters_accepted(self):
@@ -75,7 +80,9 @@ class TestSecretKeySecurityValidation:
os.environ["SECRET_KEY"] = key_32
import importlib
from app.core import config
importlib.reload(config)
# Should work
@@ -89,7 +96,9 @@ class TestSecretKeySecurityValidation:
os.environ.pop("SECRET_KEY", None)
import importlib
from app.core import config
importlib.reload(config)
def test_secret_key_long_enough_accepted(self):
@@ -106,7 +115,9 @@ class TestSecretKeySecurityValidation:
os.environ["SECRET_KEY"] = key_64
import importlib
from app.core import config
importlib.reload(config)
# Should work
@@ -120,7 +131,9 @@ class TestSecretKeySecurityValidation:
os.environ.pop("SECRET_KEY", None)
import importlib
from app.core import config
importlib.reload(config)
def test_default_secret_key_meets_requirements(self):
@@ -132,4 +145,6 @@ class TestSecretKeySecurityValidation:
from app.core.config import settings
# Current settings should have valid SECRET_KEY
assert len(settings.SECRET_KEY) >= 32, "Default SECRET_KEY must be at least 32 chars"
assert len(settings.SECRET_KEY) >= 32, (
"Default SECRET_KEY must be at least 32 chars"
)

View File

@@ -9,18 +9,19 @@ Covers:
- init_async_db
- close_async_db
"""
from unittest.mock import patch
import pytest
import pytest_asyncio
from unittest.mock import patch, MagicMock, AsyncMock
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import (
get_async_database_url,
get_db,
async_transaction_scope,
check_async_database_health,
init_async_db,
close_async_db,
get_async_database_url,
get_db,
init_async_db,
)
@@ -88,12 +89,13 @@ class TestAsyncTransactionScope:
async def test_transaction_scope_commits_on_success(self, async_test_db):
"""Test that successful operations are committed (covers line 138)."""
# Mock the transaction scope to use test database
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
with patch('app.core.database.SessionLocal', SessionLocal):
with patch("app.core.database.SessionLocal", SessionLocal):
async with async_transaction_scope() as db:
# Execute a simple query to verify transaction works
from sqlalchemy import text
result = await db.execute(text("SELECT 1"))
assert result is not None
# Transaction should be committed (covers line 138 debug log)
@@ -101,12 +103,13 @@ class TestAsyncTransactionScope:
@pytest.mark.asyncio
async def test_transaction_scope_rollback_on_error(self, async_test_db):
"""Test that transaction rolls back on exception."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
with patch('app.core.database.SessionLocal', SessionLocal):
with patch("app.core.database.SessionLocal", SessionLocal):
with pytest.raises(RuntimeError, match="Test error"):
async with async_transaction_scope() as db:
from sqlalchemy import text
await db.execute(text("SELECT 1"))
raise RuntimeError("Test error")
@@ -117,9 +120,9 @@ class TestCheckAsyncDatabaseHealth:
@pytest.mark.asyncio
async def test_database_health_check_success(self, async_test_db):
"""Test health check returns True on success (covers line 156)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
with patch('app.core.database.SessionLocal', SessionLocal):
with patch("app.core.database.SessionLocal", SessionLocal):
result = await check_async_database_health()
assert result is True
@@ -127,7 +130,7 @@ class TestCheckAsyncDatabaseHealth:
async def test_database_health_check_failure(self):
"""Test health check returns False on database error."""
# Mock async_transaction_scope to raise an error
with patch('app.core.database.async_transaction_scope') as mock_scope:
with patch("app.core.database.async_transaction_scope") as mock_scope:
mock_scope.side_effect = Exception("Database connection failed")
result = await check_async_database_health()
@@ -140,10 +143,10 @@ class TestInitAsyncDb:
@pytest.mark.asyncio
async def test_init_async_db_creates_tables(self, async_test_db):
"""Test init_async_db creates tables (covers lines 174-176)."""
test_engine, SessionLocal = async_test_db
test_engine, _SessionLocal = async_test_db
# Mock the engine to use test engine
with patch('app.core.database.engine', test_engine):
with patch("app.core.database.engine", test_engine):
await init_async_db()
# If no exception, tables were created successfully
@@ -155,7 +158,6 @@ class TestCloseAsyncDb:
async def test_close_async_db_disposes_engine(self):
"""Test close_async_db disposes engine (covers lines 185-186)."""
# Create a fresh engine to test closing
from app.core.database import engine
# Close connections
await close_async_db()

View File

@@ -2,14 +2,16 @@
"""
Comprehensive tests for CRUDBase class covering all error paths and edge cases.
"""
from datetime import UTC
from unittest.mock import patch
from uuid import uuid4
import pytest
from uuid import uuid4, UUID
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
from sqlalchemy.exc import DataError, IntegrityError, OperationalError
from sqlalchemy.orm import joinedload
from unittest.mock import AsyncMock, patch, MagicMock
from app.crud.user import user as user_crud
from app.models.user import User
from app.schemas.users import UserCreate, UserUpdate
@@ -19,7 +21,7 @@ class TestCRUDBaseGet:
@pytest.mark.asyncio
async def test_get_with_invalid_uuid_string(self, async_test_db):
"""Test get with invalid UUID string returns None."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.get(session, id="invalid-uuid")
@@ -28,7 +30,7 @@ class TestCRUDBaseGet:
@pytest.mark.asyncio
async def test_get_with_invalid_uuid_type(self, async_test_db):
"""Test get with invalid UUID type returns None."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.get(session, id=12345) # int instead of UUID
@@ -37,7 +39,7 @@ class TestCRUDBaseGet:
@pytest.mark.asyncio
async def test_get_with_uuid_object(self, async_test_db, async_test_user):
"""Test get with UUID object instead of string."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
# Pass UUID object directly
@@ -48,26 +50,24 @@ class TestCRUDBaseGet:
@pytest.mark.asyncio
async def test_get_with_options(self, async_test_db, async_test_user):
"""Test get with eager loading options (tests lines 76-78)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
# Test that options parameter is accepted and doesn't error
# We pass an empty list which still tests the code path
result = await user_crud.get(
session,
id=str(async_test_user.id),
options=[]
session, id=str(async_test_user.id), options=[]
)
assert result is not None
@pytest.mark.asyncio
async def test_get_database_error(self, async_test_db):
"""Test get handles database errors properly."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
# Mock execute to raise an exception
with patch.object(session, 'execute', side_effect=Exception("DB error")):
with patch.object(session, "execute", side_effect=Exception("DB error")):
with pytest.raises(Exception, match="DB error"):
await user_crud.get(session, id=str(uuid4()))
@@ -78,7 +78,7 @@ class TestCRUDBaseGetMulti:
@pytest.mark.asyncio
async def test_get_multi_negative_skip(self, async_test_db):
"""Test get_multi with negative skip raises ValueError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="skip must be non-negative"):
@@ -87,7 +87,7 @@ class TestCRUDBaseGetMulti:
@pytest.mark.asyncio
async def test_get_multi_negative_limit(self, async_test_db):
"""Test get_multi with negative limit raises ValueError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="limit must be non-negative"):
@@ -96,7 +96,7 @@ class TestCRUDBaseGetMulti:
@pytest.mark.asyncio
async def test_get_multi_limit_too_large(self, async_test_db):
"""Test get_multi with limit > 1000 raises ValueError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="Maximum limit is 1000"):
@@ -105,25 +105,20 @@ class TestCRUDBaseGetMulti:
@pytest.mark.asyncio
async def test_get_multi_with_options(self, async_test_db, async_test_user):
"""Test get_multi with eager loading options (tests lines 118-120)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
# Test that options parameter is accepted
results = await user_crud.get_multi(
session,
skip=0,
limit=10,
options=[]
)
results = await user_crud.get_multi(session, skip=0, limit=10, options=[])
assert isinstance(results, list)
@pytest.mark.asyncio
async def test_get_multi_database_error(self, async_test_db):
"""Test get_multi handles database errors."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with patch.object(session, 'execute', side_effect=Exception("DB error")):
with patch.object(session, "execute", side_effect=Exception("DB error")):
with pytest.raises(Exception, match="DB error"):
await user_crud.get_multi(session)
@@ -134,7 +129,7 @@ class TestCRUDBaseCreate:
@pytest.mark.asyncio
async def test_create_duplicate_unique_field(self, async_test_db, async_test_user):
"""Test create with duplicate unique field raises ValueError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
# Try to create user with duplicate email
@@ -142,7 +137,7 @@ class TestCRUDBaseCreate:
email=async_test_user.email, # Duplicate!
password="TestPassword123!",
first_name="Test",
last_name="Duplicate"
last_name="Duplicate",
)
with pytest.raises(ValueError, match="already exists"):
@@ -151,22 +146,23 @@ class TestCRUDBaseCreate:
@pytest.mark.asyncio
async def test_create_integrity_error_non_duplicate(self, async_test_db):
"""Test create with non-duplicate IntegrityError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
# Mock commit to raise IntegrityError without "unique" in message
original_commit = session.commit
async def mock_commit():
error = IntegrityError("statement", {}, Exception("foreign key violation"))
error = IntegrityError(
"statement", {}, Exception("foreign key violation")
)
raise error
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, "commit", side_effect=mock_commit):
user_data = UserCreate(
email="test@example.com",
password="TestPassword123!",
first_name="Test",
last_name="User"
last_name="User",
)
with pytest.raises(ValueError, match="Database integrity error"):
@@ -175,15 +171,21 @@ class TestCRUDBaseCreate:
@pytest.mark.asyncio
async def test_create_operational_error(self, async_test_db):
"""Test create with OperationalError (user CRUD catches as generic Exception)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with patch.object(session, 'commit', side_effect=OperationalError("statement", {}, Exception("connection lost"))):
with patch.object(
session,
"commit",
side_effect=OperationalError(
"statement", {}, Exception("connection lost")
),
):
user_data = UserCreate(
email="test@example.com",
password="TestPassword123!",
first_name="Test",
last_name="User"
last_name="User",
)
# User CRUD catches this as generic Exception and re-raises
@@ -193,15 +195,19 @@ class TestCRUDBaseCreate:
@pytest.mark.asyncio
async def test_create_data_error(self, async_test_db):
"""Test create with DataError (user CRUD catches as generic Exception)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with patch.object(session, 'commit', side_effect=DataError("statement", {}, Exception("invalid data"))):
with patch.object(
session,
"commit",
side_effect=DataError("statement", {}, Exception("invalid data")),
):
user_data = UserCreate(
email="test@example.com",
password="TestPassword123!",
first_name="Test",
last_name="User"
last_name="User",
)
# User CRUD catches this as generic Exception and re-raises
@@ -211,15 +217,17 @@ class TestCRUDBaseCreate:
@pytest.mark.asyncio
async def test_create_unexpected_error(self, async_test_db):
"""Test create with unexpected exception."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with patch.object(session, 'commit', side_effect=RuntimeError("Unexpected error")):
with patch.object(
session, "commit", side_effect=RuntimeError("Unexpected error")
):
user_data = UserCreate(
email="test@example.com",
password="TestPassword123!",
first_name="Test",
last_name="User"
last_name="User",
)
with pytest.raises(RuntimeError, match="Unexpected error"):
@@ -232,16 +240,17 @@ class TestCRUDBaseUpdate:
@pytest.mark.asyncio
async def test_update_duplicate_unique_field(self, async_test_db, async_test_user):
"""Test update with duplicate unique field raises ValueError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create another user
async with SessionLocal() as session:
from app.crud.user import user as user_crud
user2_data = UserCreate(
email="user2@example.com",
password="TestPassword123!",
first_name="User",
last_name="Two"
last_name="Two",
)
user2 = await user_crud.create(session, obj_in=user2_data)
await session.commit()
@@ -250,63 +259,89 @@ class TestCRUDBaseUpdate:
async with SessionLocal() as session:
user2_obj = await user_crud.get(session, id=str(user2.id))
with patch.object(session, 'commit', side_effect=IntegrityError("statement", {}, Exception("UNIQUE constraint failed"))):
with patch.object(
session,
"commit",
side_effect=IntegrityError(
"statement", {}, Exception("UNIQUE constraint failed")
),
):
update_data = UserUpdate(email=async_test_user.email)
with pytest.raises(ValueError, match="already exists"):
await user_crud.update(session, db_obj=user2_obj, obj_in=update_data)
await user_crud.update(
session, db_obj=user2_obj, obj_in=update_data
)
@pytest.mark.asyncio
async def test_update_with_dict(self, async_test_db, async_test_user):
"""Test update with dict instead of schema."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
# Update with dict (tests lines 164-165)
updated = await user_crud.update(
session,
db_obj=user,
obj_in={"first_name": "UpdatedName"}
session, db_obj=user, obj_in={"first_name": "UpdatedName"}
)
assert updated.first_name == "UpdatedName"
@pytest.mark.asyncio
async def test_update_integrity_error(self, async_test_db, async_test_user):
"""Test update with IntegrityError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
with patch.object(session, 'commit', side_effect=IntegrityError("statement", {}, Exception("constraint failed"))):
with patch.object(
session,
"commit",
side_effect=IntegrityError(
"statement", {}, Exception("constraint failed")
),
):
with pytest.raises(ValueError, match="Database integrity error"):
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Test"})
await user_crud.update(
session, db_obj=user, obj_in={"first_name": "Test"}
)
@pytest.mark.asyncio
async def test_update_operational_error(self, async_test_db, async_test_user):
"""Test update with OperationalError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
with patch.object(session, 'commit', side_effect=OperationalError("statement", {}, Exception("connection error"))):
with patch.object(
session,
"commit",
side_effect=OperationalError(
"statement", {}, Exception("connection error")
),
):
with pytest.raises(ValueError, match="Database operation failed"):
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Test"})
await user_crud.update(
session, db_obj=user, obj_in={"first_name": "Test"}
)
@pytest.mark.asyncio
async def test_update_unexpected_error(self, async_test_db, async_test_user):
"""Test update with unexpected error."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
with patch.object(session, 'commit', side_effect=RuntimeError("Unexpected")):
with patch.object(
session, "commit", side_effect=RuntimeError("Unexpected")
):
with pytest.raises(RuntimeError):
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Test"})
await user_crud.update(
session, db_obj=user, obj_in={"first_name": "Test"}
)
class TestCRUDBaseRemove:
@@ -315,7 +350,7 @@ class TestCRUDBaseRemove:
@pytest.mark.asyncio
async def test_remove_invalid_uuid(self, async_test_db):
"""Test remove with invalid UUID returns None."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.remove(session, id="invalid-uuid")
@@ -324,7 +359,7 @@ class TestCRUDBaseRemove:
@pytest.mark.asyncio
async def test_remove_with_uuid_object(self, async_test_db, async_test_user):
"""Test remove with UUID object."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create a user to delete
async with SessionLocal() as session:
@@ -332,7 +367,7 @@ class TestCRUDBaseRemove:
email="todelete@example.com",
password="TestPassword123!",
first_name="To",
last_name="Delete"
last_name="Delete",
)
user = await user_crud.create(session, obj_in=user_data)
user_id = user.id
@@ -347,7 +382,7 @@ class TestCRUDBaseRemove:
@pytest.mark.asyncio
async def test_remove_nonexistent(self, async_test_db):
"""Test remove of nonexistent record returns None."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.remove(session, id=str(uuid4()))
@@ -356,21 +391,31 @@ class TestCRUDBaseRemove:
@pytest.mark.asyncio
async def test_remove_integrity_error(self, async_test_db, async_test_user):
"""Test remove with IntegrityError (foreign key constraint)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
# Mock delete to raise IntegrityError
with patch.object(session, 'commit', side_effect=IntegrityError("statement", {}, Exception("FOREIGN KEY constraint"))):
with pytest.raises(ValueError, match="Cannot delete.*referenced by other records"):
with patch.object(
session,
"commit",
side_effect=IntegrityError(
"statement", {}, Exception("FOREIGN KEY constraint")
),
):
with pytest.raises(
ValueError, match="Cannot delete.*referenced by other records"
):
await user_crud.remove(session, id=str(async_test_user.id))
@pytest.mark.asyncio
async def test_remove_unexpected_error(self, async_test_db, async_test_user):
"""Test remove with unexpected error."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with patch.object(session, 'commit', side_effect=RuntimeError("Unexpected")):
with patch.object(
session, "commit", side_effect=RuntimeError("Unexpected")
):
with pytest.raises(RuntimeError):
await user_crud.remove(session, id=str(async_test_user.id))
@@ -381,10 +426,12 @@ class TestCRUDBaseGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_basic(self, async_test_db, async_test_user):
"""Test get_multi_with_total basic functionality."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
items, total = await user_crud.get_multi_with_total(session, skip=0, limit=10)
items, total = await user_crud.get_multi_with_total(
session, skip=0, limit=10
)
assert isinstance(items, list)
assert isinstance(total, int)
assert total >= 1 # At least the test user
@@ -392,7 +439,7 @@ class TestCRUDBaseGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_negative_skip(self, async_test_db):
"""Test get_multi_with_total with negative skip raises ValueError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="skip must be non-negative"):
@@ -401,7 +448,7 @@ class TestCRUDBaseGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_negative_limit(self, async_test_db):
"""Test get_multi_with_total with negative limit raises ValueError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="limit must be non-negative"):
@@ -410,28 +457,34 @@ class TestCRUDBaseGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_limit_too_large(self, async_test_db):
"""Test get_multi_with_total with limit > 1000 raises ValueError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="Maximum limit is 1000"):
await user_crud.get_multi_with_total(session, limit=1001)
@pytest.mark.asyncio
async def test_get_multi_with_total_with_filters(self, async_test_db, async_test_user):
async def test_get_multi_with_total_with_filters(
self, async_test_db, async_test_user
):
"""Test get_multi_with_total with filters."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
filters = {"email": async_test_user.email}
items, total = await user_crud.get_multi_with_total(session, filters=filters)
items, total = await user_crud.get_multi_with_total(
session, filters=filters
)
assert total == 1
assert len(items) == 1
assert items[0].email == async_test_user.email
@pytest.mark.asyncio
async def test_get_multi_with_total_with_sorting_asc(self, async_test_db, async_test_user):
async def test_get_multi_with_total_with_sorting_asc(
self, async_test_db, async_test_user
):
"""Test get_multi_with_total with ascending sort."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create additional users
async with SessionLocal() as session:
@@ -439,13 +492,13 @@ class TestCRUDBaseGetMultiWithTotal:
email="aaa@example.com",
password="TestPassword123!",
first_name="AAA",
last_name="User"
last_name="User",
)
user_data2 = UserCreate(
email="zzz@example.com",
password="TestPassword123!",
first_name="ZZZ",
last_name="User"
last_name="User",
)
await user_crud.create(session, obj_in=user_data1)
await user_crud.create(session, obj_in=user_data2)
@@ -460,9 +513,11 @@ class TestCRUDBaseGetMultiWithTotal:
assert items[0].email == "aaa@example.com"
@pytest.mark.asyncio
async def test_get_multi_with_total_with_sorting_desc(self, async_test_db, async_test_user):
async def test_get_multi_with_total_with_sorting_desc(
self, async_test_db, async_test_user
):
"""Test get_multi_with_total with descending sort."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create additional users
async with SessionLocal() as session:
@@ -470,20 +525,20 @@ class TestCRUDBaseGetMultiWithTotal:
email="bbb@example.com",
password="TestPassword123!",
first_name="BBB",
last_name="User"
last_name="User",
)
user_data2 = UserCreate(
email="ccc@example.com",
password="TestPassword123!",
first_name="CCC",
last_name="User"
last_name="User",
)
await user_crud.create(session, obj_in=user_data1)
await user_crud.create(session, obj_in=user_data2)
await session.commit()
async with SessionLocal() as session:
items, total = await user_crud.get_multi_with_total(
items, _total = await user_crud.get_multi_with_total(
session, sort_by="email", sort_order="desc", limit=1
)
assert len(items) == 1
@@ -492,7 +547,7 @@ class TestCRUDBaseGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_with_pagination(self, async_test_db):
"""Test get_multi_with_total pagination works correctly."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create minimal users for pagination test (3 instead of 5)
async with SessionLocal() as session:
@@ -501,19 +556,23 @@ class TestCRUDBaseGetMultiWithTotal:
email=f"user{i}@example.com",
password="TestPassword123!",
first_name=f"User{i}",
last_name="Test"
last_name="Test",
)
await user_crud.create(session, obj_in=user_data)
await session.commit()
async with SessionLocal() as session:
# Get first page
items1, total = await user_crud.get_multi_with_total(session, skip=0, limit=2)
items1, total = await user_crud.get_multi_with_total(
session, skip=0, limit=2
)
assert len(items1) == 2
assert total >= 3
# Get second page
items2, total2 = await user_crud.get_multi_with_total(session, skip=2, limit=2)
items2, total2 = await user_crud.get_multi_with_total(
session, skip=2, limit=2
)
assert len(items2) >= 1
assert total2 == total
@@ -529,7 +588,7 @@ class TestCRUDBaseCount:
@pytest.mark.asyncio
async def test_count_basic(self, async_test_db, async_test_user):
"""Test count returns correct number."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
count = await user_crud.count(session)
@@ -539,7 +598,7 @@ class TestCRUDBaseCount:
@pytest.mark.asyncio
async def test_count_multiple_users(self, async_test_db, async_test_user):
"""Test count with multiple users."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create additional users
async with SessionLocal() as session:
@@ -549,13 +608,13 @@ class TestCRUDBaseCount:
email="count1@example.com",
password="TestPassword123!",
first_name="Count",
last_name="One"
last_name="One",
)
user_data2 = UserCreate(
email="count2@example.com",
password="TestPassword123!",
first_name="Count",
last_name="Two"
last_name="Two",
)
await user_crud.create(session, obj_in=user_data1)
await user_crud.create(session, obj_in=user_data2)
@@ -568,10 +627,10 @@ class TestCRUDBaseCount:
@pytest.mark.asyncio
async def test_count_database_error(self, async_test_db):
"""Test count handles database errors."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with patch.object(session, 'execute', side_effect=Exception("DB error")):
with patch.object(session, "execute", side_effect=Exception("DB error")):
with pytest.raises(Exception, match="DB error"):
await user_crud.count(session)
@@ -582,7 +641,7 @@ class TestCRUDBaseExists:
@pytest.mark.asyncio
async def test_exists_true(self, async_test_db, async_test_user):
"""Test exists returns True for existing record."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.exists(session, id=str(async_test_user.id))
@@ -591,7 +650,7 @@ class TestCRUDBaseExists:
@pytest.mark.asyncio
async def test_exists_false(self, async_test_db):
"""Test exists returns False for non-existent record."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.exists(session, id=str(uuid4()))
@@ -600,7 +659,7 @@ class TestCRUDBaseExists:
@pytest.mark.asyncio
async def test_exists_invalid_uuid(self, async_test_db):
"""Test exists returns False for invalid UUID."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.exists(session, id="invalid-uuid")
@@ -613,7 +672,7 @@ class TestCRUDBaseSoftDelete:
@pytest.mark.asyncio
async def test_soft_delete_success(self, async_test_db):
"""Test soft delete sets deleted_at timestamp."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create a user to soft delete
async with SessionLocal() as session:
@@ -621,7 +680,7 @@ class TestCRUDBaseSoftDelete:
email="softdelete@example.com",
password="TestPassword123!",
first_name="Soft",
last_name="Delete"
last_name="Delete",
)
user = await user_crud.create(session, obj_in=user_data)
user_id = user.id
@@ -636,7 +695,7 @@ class TestCRUDBaseSoftDelete:
@pytest.mark.asyncio
async def test_soft_delete_invalid_uuid(self, async_test_db):
"""Test soft delete with invalid UUID returns None."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.soft_delete(session, id="invalid-uuid")
@@ -645,7 +704,7 @@ class TestCRUDBaseSoftDelete:
@pytest.mark.asyncio
async def test_soft_delete_nonexistent(self, async_test_db):
"""Test soft delete of nonexistent record returns None."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.soft_delete(session, id=str(uuid4()))
@@ -654,7 +713,7 @@ class TestCRUDBaseSoftDelete:
@pytest.mark.asyncio
async def test_soft_delete_with_uuid_object(self, async_test_db):
"""Test soft delete with UUID object."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create a user to soft delete
async with SessionLocal() as session:
@@ -662,7 +721,7 @@ class TestCRUDBaseSoftDelete:
email="softdelete2@example.com",
password="TestPassword123!",
first_name="Soft",
last_name="Delete2"
last_name="Delete2",
)
user = await user_crud.create(session, obj_in=user_data)
user_id = user.id
@@ -681,7 +740,7 @@ class TestCRUDBaseRestore:
@pytest.mark.asyncio
async def test_restore_success(self, async_test_db):
"""Test restore clears deleted_at timestamp."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create and soft delete a user
async with SessionLocal() as session:
@@ -689,7 +748,7 @@ class TestCRUDBaseRestore:
email="restore@example.com",
password="TestPassword123!",
first_name="Restore",
last_name="Test"
last_name="Test",
)
user = await user_crud.create(session, obj_in=user_data)
user_id = user.id
@@ -707,7 +766,7 @@ class TestCRUDBaseRestore:
@pytest.mark.asyncio
async def test_restore_invalid_uuid(self, async_test_db):
"""Test restore with invalid UUID returns None."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.restore(session, id="invalid-uuid")
@@ -716,7 +775,7 @@ class TestCRUDBaseRestore:
@pytest.mark.asyncio
async def test_restore_nonexistent(self, async_test_db):
"""Test restore of nonexistent record returns None."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.restore(session, id=str(uuid4()))
@@ -725,7 +784,7 @@ class TestCRUDBaseRestore:
@pytest.mark.asyncio
async def test_restore_not_deleted(self, async_test_db, async_test_user):
"""Test restore of non-deleted record returns None."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
# Try to restore a user that's not deleted
@@ -735,7 +794,7 @@ class TestCRUDBaseRestore:
@pytest.mark.asyncio
async def test_restore_with_uuid_object(self, async_test_db):
"""Test restore with UUID object."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create and soft delete a user
async with SessionLocal() as session:
@@ -743,7 +802,7 @@ class TestCRUDBaseRestore:
email="restore2@example.com",
password="TestPassword123!",
first_name="Restore",
last_name="Test2"
last_name="Test2",
)
user = await user_crud.create(session, obj_in=user_data)
user_id = user.id
@@ -765,7 +824,7 @@ class TestCRUDBasePaginationValidation:
@pytest.mark.asyncio
async def test_get_multi_with_total_negative_skip(self, async_test_db):
"""Test that negative skip raises ValueError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="skip must be non-negative"):
@@ -774,7 +833,7 @@ class TestCRUDBasePaginationValidation:
@pytest.mark.asyncio
async def test_get_multi_with_total_negative_limit(self, async_test_db):
"""Test that negative limit raises ValueError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="limit must be non-negative"):
@@ -783,23 +842,22 @@ class TestCRUDBasePaginationValidation:
@pytest.mark.asyncio
async def test_get_multi_with_total_limit_too_large(self, async_test_db):
"""Test that limit > 1000 raises ValueError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="Maximum limit is 1000"):
await user_crud.get_multi_with_total(session, skip=0, limit=1001)
@pytest.mark.asyncio
async def test_get_multi_with_total_with_filters(self, async_test_db, async_test_user):
async def test_get_multi_with_total_with_filters(
self, async_test_db, async_test_user
):
"""Test pagination with filters (covers lines 270-273)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
users, total = await user_crud.get_multi_with_total(
session,
skip=0,
limit=10,
filters={"is_active": True}
session, skip=0, limit=10, filters={"is_active": True}
)
assert isinstance(users, list)
assert total >= 0
@@ -807,30 +865,22 @@ class TestCRUDBasePaginationValidation:
@pytest.mark.asyncio
async def test_get_multi_with_total_with_sorting_desc(self, async_test_db):
"""Test pagination with descending sort (covers lines 283-284)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
users, total = await user_crud.get_multi_with_total(
session,
skip=0,
limit=10,
sort_by="created_at",
sort_order="desc"
users, _total = await user_crud.get_multi_with_total(
session, skip=0, limit=10, sort_by="created_at", sort_order="desc"
)
assert isinstance(users, list)
@pytest.mark.asyncio
async def test_get_multi_with_total_with_sorting_asc(self, async_test_db):
"""Test pagination with ascending sort (covers lines 285-286)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
users, total = await user_crud.get_multi_with_total(
session,
skip=0,
limit=10,
sort_by="created_at",
sort_order="asc"
users, _total = await user_crud.get_multi_with_total(
session, skip=0, limit=10, sort_by="created_at", sort_order="asc"
)
assert isinstance(users, list)
@@ -842,13 +892,15 @@ class TestCRUDBaseModelsWithoutSoftDelete:
"""
@pytest.mark.asyncio
async def test_soft_delete_model_without_deleted_at(self, async_test_db, async_test_user):
async def test_soft_delete_model_without_deleted_at(
self, async_test_db, async_test_user
):
"""Test soft_delete on Organization model (no deleted_at) raises ValueError (covers lines 342-343)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create an organization (which doesn't have deleted_at)
from app.models.organization import Organization
from app.crud.organization import organization as org_crud
from app.models.organization import Organization
async with SessionLocal() as session:
org = Organization(name="Test Org", slug="test-org")
@@ -864,11 +916,11 @@ class TestCRUDBaseModelsWithoutSoftDelete:
@pytest.mark.asyncio
async def test_restore_model_without_deleted_at(self, async_test_db):
"""Test restore on Organization model (no deleted_at) raises ValueError (covers lines 383-384)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create an organization (which doesn't have deleted_at)
from app.models.organization import Organization
from app.crud.organization import organization as org_crud
from app.models.organization import Organization
async with SessionLocal() as session:
org = Organization(name="Restore Test", slug="restore-test")
@@ -889,14 +941,17 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
"""
@pytest.mark.asyncio
async def test_get_with_real_eager_loading_options(self, async_test_db, async_test_user):
async def test_get_with_real_eager_loading_options(
self, async_test_db, async_test_user
):
"""Test get() with actual eager loading options (covers lines 77-78)."""
from datetime import datetime, timedelta, timezone
test_engine, SessionLocal = async_test_db
from datetime import datetime, timedelta
_test_engine, SessionLocal = async_test_db
# Create a session for the user
from app.models.user_session import UserSession
from app.crud.session import session as session_crud
from app.models.user_session import UserSession
async with SessionLocal() as session:
user_session = UserSession(
@@ -905,8 +960,8 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
device_id="test-device",
ip_address="192.168.1.1",
user_agent="Test Agent",
last_used_at=datetime.now(timezone.utc),
expires_at=datetime.now(timezone.utc) + timedelta(days=60)
last_used_at=datetime.now(UTC),
expires_at=datetime.now(UTC) + timedelta(days=60),
)
session.add(user_session)
await session.commit()
@@ -917,7 +972,7 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
result = await session_crud.get(
session,
id=str(session_id),
options=[joinedload(UserSession.user)] # Real option, not empty list
options=[joinedload(UserSession.user)], # Real option, not empty list
)
assert result is not None
assert result.id == session_id
@@ -925,14 +980,17 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
assert result.user.email == async_test_user.email
@pytest.mark.asyncio
async def test_get_multi_with_real_eager_loading_options(self, async_test_db, async_test_user):
async def test_get_multi_with_real_eager_loading_options(
self, async_test_db, async_test_user
):
"""Test get_multi() with actual eager loading options (covers lines 119-120)."""
from datetime import datetime, timedelta, timezone
test_engine, SessionLocal = async_test_db
from datetime import datetime, timedelta
_test_engine, SessionLocal = async_test_db
# Create multiple sessions for the user
from app.models.user_session import UserSession
from app.crud.session import session as session_crud
from app.models.user_session import UserSession
async with SessionLocal() as session:
for i in range(3):
@@ -942,8 +1000,8 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
device_id=f"device-{i}",
ip_address=f"192.168.1.{i}",
user_agent=f"Agent {i}",
last_used_at=datetime.now(timezone.utc),
expires_at=datetime.now(timezone.utc) + timedelta(days=60)
last_used_at=datetime.now(UTC),
expires_at=datetime.now(UTC) + timedelta(days=60),
)
session.add(user_session)
await session.commit()
@@ -954,7 +1012,7 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
session,
skip=0,
limit=10,
options=[joinedload(UserSession.user)] # Real option, not empty list
options=[joinedload(UserSession.user)], # Real option, not empty list
)
assert len(results) >= 3
# Verify we can access user without additional queries

View File

@@ -3,13 +3,15 @@
Comprehensive tests for base CRUD database failure scenarios.
Tests exception handling, rollbacks, and error messages.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
from unittest.mock import AsyncMock, patch
from uuid import uuid4
import pytest
from sqlalchemy.exc import DataError, OperationalError
from app.crud.user import user as user_crud
from app.schemas.users import UserCreate, UserUpdate
from app.schemas.users import UserCreate
class TestBaseCRUDCreateFailures:
@@ -18,19 +20,24 @@ class TestBaseCRUDCreateFailures:
@pytest.mark.asyncio
async def test_create_operational_error_triggers_rollback(self, async_test_db):
"""Test that OperationalError triggers rollback (User CRUD catches as Exception)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise OperationalError("Connection lost", {}, Exception("DB connection failed"))
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
async def mock_commit():
raise OperationalError(
"Connection lost", {}, Exception("DB connection failed")
)
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
user_data = UserCreate(
email="operror@example.com",
password="TestPassword123!",
first_name="Test",
last_name="User"
last_name="User",
)
# User CRUD catches this as generic Exception and re-raises
@@ -43,19 +50,22 @@ class TestBaseCRUDCreateFailures:
@pytest.mark.asyncio
async def test_create_data_error_triggers_rollback(self, async_test_db):
"""Test that DataError triggers rollback (User CRUD catches as Exception)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise DataError("Invalid data type", {}, Exception("Data overflow"))
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
user_data = UserCreate(
email="dataerror@example.com",
password="TestPassword123!",
first_name="Test",
last_name="User"
last_name="User",
)
# User CRUD catches this as generic Exception and re-raises
@@ -67,19 +77,22 @@ class TestBaseCRUDCreateFailures:
@pytest.mark.asyncio
async def test_create_unexpected_exception_triggers_rollback(self, async_test_db):
"""Test that unexpected exceptions trigger rollback and re-raise."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise RuntimeError("Unexpected database error")
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
user_data = UserCreate(
email="unexpected@example.com",
password="TestPassword123!",
first_name="Test",
last_name="User"
last_name="User",
)
with pytest.raises(RuntimeError, match="Unexpected database error"):
@@ -94,7 +107,7 @@ class TestBaseCRUDUpdateFailures:
@pytest.mark.asyncio
async def test_update_operational_error(self, async_test_db, async_test_user):
"""Test update with OperationalError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
@@ -102,17 +115,21 @@ class TestBaseCRUDUpdateFailures:
async def mock_commit():
raise OperationalError("Connection timeout", {}, Exception("Timeout"))
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(ValueError, match="Database operation failed"):
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Updated"})
await user_crud.update(
session, db_obj=user, obj_in={"first_name": "Updated"}
)
mock_rollback.assert_called_once()
@pytest.mark.asyncio
async def test_update_data_error(self, async_test_db, async_test_user):
"""Test update with DataError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
@@ -120,17 +137,21 @@ class TestBaseCRUDUpdateFailures:
async def mock_commit():
raise DataError("Invalid data", {}, Exception("Data type mismatch"))
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(ValueError, match="Database operation failed"):
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Updated"})
await user_crud.update(
session, db_obj=user, obj_in={"first_name": "Updated"}
)
mock_rollback.assert_called_once()
@pytest.mark.asyncio
async def test_update_unexpected_error(self, async_test_db, async_test_user):
"""Test update with unexpected error."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
@@ -138,10 +159,14 @@ class TestBaseCRUDUpdateFailures:
async def mock_commit():
raise KeyError("Unexpected error")
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(KeyError):
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Updated"})
await user_crud.update(
session, db_obj=user, obj_in={"first_name": "Updated"}
)
mock_rollback.assert_called_once()
@@ -150,16 +175,21 @@ class TestBaseCRUDRemoveFailures:
"""Test base CRUD remove method exception handling."""
@pytest.mark.asyncio
async def test_remove_unexpected_error_triggers_rollback(self, async_test_db, async_test_user):
async def test_remove_unexpected_error_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test that unexpected errors in remove trigger rollback."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise RuntimeError("Database write failed")
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(RuntimeError, match="Database write failed"):
await user_crud.remove(session, id=str(async_test_user.id))
@@ -172,16 +202,15 @@ class TestBaseCRUDGetMultiWithTotalFailures:
@pytest.mark.asyncio
async def test_get_multi_with_total_database_error(self, async_test_db):
"""Test get_multi_with_total handles database errors."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
# Mock execute to raise an error
original_execute = session.execute
async def mock_execute(*args, **kwargs):
raise OperationalError("Query failed", {}, Exception("Database error"))
with patch.object(session, 'execute', side_effect=mock_execute):
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await user_crud.get_multi_with_total(session, skip=0, limit=10)
@@ -192,13 +221,14 @@ class TestBaseCRUDCountFailures:
@pytest.mark.asyncio
async def test_count_database_error_propagates(self, async_test_db):
"""Test count propagates database errors."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_execute(*args, **kwargs):
raise OperationalError("Count failed", {}, Exception("DB error"))
with patch.object(session, 'execute', side_effect=mock_execute):
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await user_crud.count(session)
@@ -207,16 +237,21 @@ class TestBaseCRUDSoftDeleteFailures:
"""Test soft_delete method exception handling."""
@pytest.mark.asyncio
async def test_soft_delete_unexpected_error_triggers_rollback(self, async_test_db, async_test_user):
async def test_soft_delete_unexpected_error_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test soft_delete handles unexpected errors with rollback."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise RuntimeError("Soft delete failed")
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(RuntimeError, match="Soft delete failed"):
await user_crud.soft_delete(session, id=str(async_test_user.id))
@@ -229,7 +264,7 @@ class TestBaseCRUDRestoreFailures:
@pytest.mark.asyncio
async def test_restore_unexpected_error_triggers_rollback(self, async_test_db):
"""Test restore handles unexpected errors with rollback."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# First create and soft delete a user
async with SessionLocal() as session:
@@ -237,7 +272,7 @@ class TestBaseCRUDRestoreFailures:
email="restore_test@example.com",
password="TestPassword123!",
first_name="Restore",
last_name="Test"
last_name="Test",
)
user = await user_crud.create(session, obj_in=user_data)
user_id = user.id
@@ -248,11 +283,14 @@ class TestBaseCRUDRestoreFailures:
# Now test restore failure
async with SessionLocal() as session:
async def mock_commit():
raise RuntimeError("Restore failed")
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(RuntimeError, match="Restore failed"):
await user_crud.restore(session, id=str(user_id))
@@ -265,13 +303,14 @@ class TestBaseCRUDGetFailures:
@pytest.mark.asyncio
async def test_get_database_error_propagates(self, async_test_db):
"""Test get propagates database errors."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_execute(*args, **kwargs):
raise OperationalError("Get failed", {}, Exception("DB error"))
with patch.object(session, 'execute', side_effect=mock_execute):
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await user_crud.get(session, id=str(uuid4()))
@@ -282,12 +321,13 @@ class TestBaseCRUDGetMultiFailures:
@pytest.mark.asyncio
async def test_get_multi_database_error_propagates(self, async_test_db):
"""Test get_multi propagates database errors."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_execute(*args, **kwargs):
raise OperationalError("Query failed", {}, Exception("DB error"))
with patch.object(session, 'execute', side_effect=mock_execute):
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await user_crud.get_multi(session, skip=0, limit=10)

File diff suppressed because it is too large Load Diff

View File

@@ -2,10 +2,12 @@
"""
Comprehensive tests for async session CRUD operations.
"""
import pytest
from datetime import datetime, timedelta, timezone
from datetime import UTC, datetime, timedelta
from uuid import uuid4
import pytest
from app.crud.session import session as session_crud
from app.models.user_session import UserSession
from app.schemas.sessions import SessionCreate
@@ -17,7 +19,7 @@ class TestGetByJti:
@pytest.mark.asyncio
async def test_get_by_jti_success(self, async_test_db, async_test_user):
"""Test getting session by JTI."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_session = UserSession(
@@ -27,8 +29,8 @@ class TestGetByJti:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(user_session)
await session.commit()
@@ -41,7 +43,7 @@ class TestGetByJti:
@pytest.mark.asyncio
async def test_get_by_jti_not_found(self, async_test_db):
"""Test getting non-existent JTI returns None."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await session_crud.get_by_jti(session, jti="nonexistent")
@@ -54,7 +56,7 @@ class TestGetActiveByJti:
@pytest.mark.asyncio
async def test_get_active_by_jti_success(self, async_test_db, async_test_user):
"""Test getting active session by JTI."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_session = UserSession(
@@ -64,8 +66,8 @@ class TestGetActiveByJti:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(user_session)
await session.commit()
@@ -78,7 +80,7 @@ class TestGetActiveByJti:
@pytest.mark.asyncio
async def test_get_active_by_jti_inactive(self, async_test_db, async_test_user):
"""Test getting inactive session by JTI returns None."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_session = UserSession(
@@ -88,8 +90,8 @@ class TestGetActiveByJti:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=False,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(user_session)
await session.commit()
@@ -105,7 +107,7 @@ class TestGetUserSessions:
@pytest.mark.asyncio
async def test_get_user_sessions_active_only(self, async_test_db, async_test_user):
"""Test getting only active user sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
active = UserSession(
@@ -115,8 +117,8 @@ class TestGetUserSessions:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
inactive = UserSession(
user_id=async_test_user.id,
@@ -125,17 +127,15 @@ class TestGetUserSessions:
ip_address="192.168.1.2",
user_agent="Mozilla/5.0",
is_active=False,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add_all([active, inactive])
await session.commit()
async with AsyncTestingSessionLocal() as session:
results = await session_crud.get_user_sessions(
session,
user_id=str(async_test_user.id),
active_only=True
session, user_id=str(async_test_user.id), active_only=True
)
assert len(results) == 1
assert results[0].is_active is True
@@ -143,7 +143,7 @@ class TestGetUserSessions:
@pytest.mark.asyncio
async def test_get_user_sessions_all(self, async_test_db, async_test_user):
"""Test getting all user sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
for i in range(3):
@@ -154,17 +154,15 @@ class TestGetUserSessions:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=i % 2 == 0,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(sess)
await session.commit()
async with AsyncTestingSessionLocal() as session:
results = await session_crud.get_user_sessions(
session,
user_id=str(async_test_user.id),
active_only=False
session, user_id=str(async_test_user.id), active_only=False
)
assert len(results) == 3
@@ -175,7 +173,7 @@ class TestCreateSession:
@pytest.mark.asyncio
async def test_create_session_success(self, async_test_db, async_test_user):
"""Test successfully creating a session_crud."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
session_data = SessionCreate(
@@ -185,10 +183,10 @@ class TestCreateSession:
device_id="device_123",
ip_address="192.168.1.100",
user_agent="Mozilla/5.0",
last_used_at=datetime.now(timezone.utc),
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(UTC),
expires_at=datetime.now(UTC) + timedelta(days=7),
location_city="San Francisco",
location_country="USA"
location_country="USA",
)
result = await session_crud.create_session(session, obj_in=session_data)
@@ -204,7 +202,7 @@ class TestDeactivate:
@pytest.mark.asyncio
async def test_deactivate_success(self, async_test_db, async_test_user):
"""Test successfully deactivating a session_crud."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_session = UserSession(
@@ -214,8 +212,8 @@ class TestDeactivate:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(user_session)
await session.commit()
@@ -229,7 +227,7 @@ class TestDeactivate:
@pytest.mark.asyncio
async def test_deactivate_not_found(self, async_test_db):
"""Test deactivating non-existent session returns None."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await session_crud.deactivate(session, session_id=str(uuid4()))
@@ -240,9 +238,11 @@ class TestDeactivateAllUserSessions:
"""Tests for deactivate_all_user_sessions method."""
@pytest.mark.asyncio
async def test_deactivate_all_user_sessions_success(self, async_test_db, async_test_user):
async def test_deactivate_all_user_sessions_success(
self, async_test_db, async_test_user
):
"""Test deactivating all user sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Create minimal sessions for test (2 instead of 5)
@@ -254,16 +254,15 @@ class TestDeactivateAllUserSessions:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(sess)
await session.commit()
async with AsyncTestingSessionLocal() as session:
count = await session_crud.deactivate_all_user_sessions(
session,
user_id=str(async_test_user.id)
session, user_id=str(async_test_user.id)
)
assert count == 2
@@ -274,7 +273,7 @@ class TestUpdateLastUsed:
@pytest.mark.asyncio
async def test_update_last_used_success(self, async_test_db, async_test_user):
"""Test updating last_used_at timestamp."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_session = UserSession(
@@ -284,8 +283,8 @@ class TestUpdateLastUsed:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) - timedelta(hours=1)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC) - timedelta(hours=1),
)
session.add(user_session)
await session.commit()
@@ -303,7 +302,7 @@ class TestGetUserSessionCount:
@pytest.mark.asyncio
async def test_get_user_session_count_success(self, async_test_db, async_test_user):
"""Test getting user session count."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
for i in range(3):
@@ -314,28 +313,26 @@ class TestGetUserSessionCount:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(sess)
await session.commit()
async with AsyncTestingSessionLocal() as session:
count = await session_crud.get_user_session_count(
session,
user_id=str(async_test_user.id)
session, user_id=str(async_test_user.id)
)
assert count == 3
@pytest.mark.asyncio
async def test_get_user_session_count_empty(self, async_test_db):
"""Test getting session count for user with no sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
count = await session_crud.get_user_session_count(
session,
user_id=str(uuid4())
session, user_id=str(uuid4())
)
assert count == 0
@@ -346,7 +343,7 @@ class TestUpdateRefreshToken:
@pytest.mark.asyncio
async def test_update_refresh_token_success(self, async_test_db, async_test_user):
"""Test updating refresh token JTI and expiration."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_session = UserSession(
@@ -356,26 +353,34 @@ class TestUpdateRefreshToken:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) - timedelta(hours=1)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC) - timedelta(hours=1),
)
session.add(user_session)
await session.commit()
await session.refresh(user_session)
new_jti = "new_jti_123"
new_expires = datetime.now(timezone.utc) + timedelta(days=14)
new_expires = datetime.now(UTC) + timedelta(days=14)
result = await session_crud.update_refresh_token(
session,
session=user_session,
new_jti=new_jti,
new_expires_at=new_expires
new_expires_at=new_expires,
)
assert result.refresh_token_jti == new_jti
# Compare timestamps ignoring timezone info
assert abs((result.expires_at.replace(tzinfo=None) - new_expires.replace(tzinfo=None)).total_seconds()) < 1
assert (
abs(
(
result.expires_at.replace(tzinfo=None)
- new_expires.replace(tzinfo=None)
).total_seconds()
)
< 1
)
class TestCleanupExpired:
@@ -384,7 +389,7 @@ class TestCleanupExpired:
@pytest.mark.asyncio
async def test_cleanup_expired_success(self, async_test_db, async_test_user):
"""Test cleaning up old expired inactive sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create old expired inactive session
async with AsyncTestingSessionLocal() as session:
@@ -395,9 +400,9 @@ class TestCleanupExpired:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=False,
expires_at=datetime.now(timezone.utc) - timedelta(days=5),
last_used_at=datetime.now(timezone.utc) - timedelta(days=35),
created_at=datetime.now(timezone.utc) - timedelta(days=35)
expires_at=datetime.now(UTC) - timedelta(days=5),
last_used_at=datetime.now(UTC) - timedelta(days=35),
created_at=datetime.now(UTC) - timedelta(days=35),
)
session.add(old_session)
await session.commit()
@@ -410,7 +415,7 @@ class TestCleanupExpired:
@pytest.mark.asyncio
async def test_cleanup_expired_keeps_recent(self, async_test_db, async_test_user):
"""Test that cleanup keeps recent expired sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create recent expired inactive session (less than keep_days old)
async with AsyncTestingSessionLocal() as session:
@@ -421,9 +426,9 @@ class TestCleanupExpired:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=False,
expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
last_used_at=datetime.now(timezone.utc) - timedelta(hours=2),
created_at=datetime.now(timezone.utc) - timedelta(days=1)
expires_at=datetime.now(UTC) - timedelta(hours=1),
last_used_at=datetime.now(UTC) - timedelta(hours=2),
created_at=datetime.now(UTC) - timedelta(days=1),
)
session.add(recent_session)
await session.commit()
@@ -436,7 +441,7 @@ class TestCleanupExpired:
@pytest.mark.asyncio
async def test_cleanup_expired_keeps_active(self, async_test_db, async_test_user):
"""Test that cleanup does not delete active sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create old expired but ACTIVE session
async with AsyncTestingSessionLocal() as session:
@@ -447,9 +452,9 @@ class TestCleanupExpired:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True, # Active
expires_at=datetime.now(timezone.utc) - timedelta(days=5),
last_used_at=datetime.now(timezone.utc) - timedelta(days=35),
created_at=datetime.now(timezone.utc) - timedelta(days=35)
expires_at=datetime.now(UTC) - timedelta(days=5),
last_used_at=datetime.now(UTC) - timedelta(days=35),
created_at=datetime.now(UTC) - timedelta(days=35),
)
session.add(active_session)
await session.commit()
@@ -464,9 +469,11 @@ class TestCleanupExpiredForUser:
"""Tests for cleanup_expired_for_user method."""
@pytest.mark.asyncio
async def test_cleanup_expired_for_user_success(self, async_test_db, async_test_user):
async def test_cleanup_expired_for_user_success(
self, async_test_db, async_test_user
):
"""Test cleaning up expired sessions for specific user."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create expired inactive session for user
async with AsyncTestingSessionLocal() as session:
@@ -477,8 +484,8 @@ class TestCleanupExpiredForUser:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=False,
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
last_used_at=datetime.now(timezone.utc) - timedelta(days=2)
expires_at=datetime.now(UTC) - timedelta(days=1),
last_used_at=datetime.now(UTC) - timedelta(days=2),
)
session.add(expired_session)
await session.commit()
@@ -486,27 +493,27 @@ class TestCleanupExpiredForUser:
# Cleanup for user
async with AsyncTestingSessionLocal() as session:
count = await session_crud.cleanup_expired_for_user(
session,
user_id=str(async_test_user.id)
session, user_id=str(async_test_user.id)
)
assert count == 1
@pytest.mark.asyncio
async def test_cleanup_expired_for_user_invalid_uuid(self, async_test_db):
"""Test cleanup with invalid user UUID."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
with pytest.raises(ValueError, match="Invalid user ID format"):
await session_crud.cleanup_expired_for_user(
session,
user_id="not-a-valid-uuid"
session, user_id="not-a-valid-uuid"
)
@pytest.mark.asyncio
async def test_cleanup_expired_for_user_keeps_active(self, async_test_db, async_test_user):
async def test_cleanup_expired_for_user_keeps_active(
self, async_test_db, async_test_user
):
"""Test that cleanup for user keeps active sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create expired but active session
async with AsyncTestingSessionLocal() as session:
@@ -517,8 +524,8 @@ class TestCleanupExpiredForUser:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True, # Active
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
last_used_at=datetime.now(timezone.utc) - timedelta(days=2)
expires_at=datetime.now(UTC) - timedelta(days=1),
last_used_at=datetime.now(UTC) - timedelta(days=2),
)
session.add(active_session)
await session.commit()
@@ -526,8 +533,7 @@ class TestCleanupExpiredForUser:
# Cleanup
async with AsyncTestingSessionLocal() as session:
count = await session_crud.cleanup_expired_for_user(
session,
user_id=str(async_test_user.id)
session, user_id=str(async_test_user.id)
)
assert count == 0 # Should not delete active sessions
@@ -536,9 +542,11 @@ class TestGetUserSessionsWithUser:
"""Tests for get_user_sessions with eager loading."""
@pytest.mark.asyncio
async def test_get_user_sessions_with_user_relationship(self, async_test_db, async_test_user):
async def test_get_user_sessions_with_user_relationship(
self, async_test_db, async_test_user
):
"""Test getting sessions with user relationship loaded."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_session = UserSession(
@@ -548,8 +556,8 @@ class TestGetUserSessionsWithUser:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(user_session)
await session.commit()
@@ -557,8 +565,6 @@ class TestGetUserSessionsWithUser:
# Get with user relationship
async with AsyncTestingSessionLocal() as session:
results = await session_crud.get_user_sessions(
session,
user_id=str(async_test_user.id),
with_user=True
session, user_id=str(async_test_user.id), with_user=True
)
assert len(results) >= 1

View File

@@ -2,12 +2,14 @@
"""
Comprehensive tests for session CRUD database failure scenarios.
"""
import pytest
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, patch
from sqlalchemy.exc import OperationalError, IntegrityError
from datetime import datetime, timedelta, timezone
from uuid import uuid4
import pytest
from sqlalchemy.exc import OperationalError
from app.crud.session import session as session_crud
from app.models.user_session import UserSession
from app.schemas.sessions import SessionCreate
@@ -19,13 +21,14 @@ class TestSessionCRUDGetByJtiFailures:
@pytest.mark.asyncio
async def test_get_by_jti_database_error(self, async_test_db):
"""Test get_by_jti handles database errors."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_execute(*args, **kwargs):
raise OperationalError("DB connection lost", {}, Exception())
with patch.object(session, 'execute', side_effect=mock_execute):
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await session_crud.get_by_jti(session, jti="test_jti")
@@ -36,13 +39,14 @@ class TestSessionCRUDGetActiveByJtiFailures:
@pytest.mark.asyncio
async def test_get_active_by_jti_database_error(self, async_test_db):
"""Test get_active_by_jti handles database errors."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_execute(*args, **kwargs):
raise OperationalError("Query timeout", {}, Exception())
with patch.object(session, 'execute', side_effect=mock_execute):
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await session_crud.get_active_by_jti(session, jti="test_jti")
@@ -51,19 +55,21 @@ class TestSessionCRUDGetUserSessionsFailures:
"""Test get_user_sessions exception handling."""
@pytest.mark.asyncio
async def test_get_user_sessions_database_error(self, async_test_db, async_test_user):
async def test_get_user_sessions_database_error(
self, async_test_db, async_test_user
):
"""Test get_user_sessions handles database errors."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_execute(*args, **kwargs):
raise OperationalError("Database error", {}, Exception())
with patch.object(session, 'execute', side_effect=mock_execute):
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await session_crud.get_user_sessions(
session,
user_id=str(async_test_user.id)
session, user_id=str(async_test_user.id)
)
@@ -71,24 +77,29 @@ class TestSessionCRUDCreateSessionFailures:
"""Test create_session exception handling."""
@pytest.mark.asyncio
async def test_create_session_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
async def test_create_session_commit_failure_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test create_session handles commit failures with rollback."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise OperationalError("Commit failed", {}, Exception())
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
session_data = SessionCreate(
user_id=async_test_user.id,
refresh_token_jti=str(uuid4()),
device_name="Test Device",
ip_address="127.0.0.1",
user_agent="Test Agent",
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
with pytest.raises(ValueError, match="Failed to create session"):
@@ -97,24 +108,29 @@ class TestSessionCRUDCreateSessionFailures:
mock_rollback.assert_called_once()
@pytest.mark.asyncio
async def test_create_session_unexpected_error_triggers_rollback(self, async_test_db, async_test_user):
async def test_create_session_unexpected_error_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test create_session handles unexpected errors."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise RuntimeError("Unexpected error")
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
session_data = SessionCreate(
user_id=async_test_user.id,
refresh_token_jti=str(uuid4()),
device_name="Test Device",
ip_address="127.0.0.1",
user_agent="Test Agent",
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
with pytest.raises(ValueError, match="Failed to create session"):
@@ -127,9 +143,11 @@ class TestSessionCRUDDeactivateFailures:
"""Test deactivate exception handling."""
@pytest.mark.asyncio
async def test_deactivate_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
async def test_deactivate_commit_failure_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test deactivate handles commit failures."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create a session first
async with SessionLocal() as session:
@@ -140,8 +158,8 @@ class TestSessionCRUDDeactivateFailures:
ip_address="127.0.0.1",
user_agent="Test Agent",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(user_session)
await session.commit()
@@ -150,13 +168,18 @@ class TestSessionCRUDDeactivateFailures:
# Test deactivate failure
async with SessionLocal() as session:
async def mock_commit():
raise OperationalError("Deactivate failed", {}, Exception())
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError):
await session_crud.deactivate(session, session_id=str(session_id))
await session_crud.deactivate(
session, session_id=str(session_id)
)
mock_rollback.assert_called_once()
@@ -165,20 +188,24 @@ class TestSessionCRUDDeactivateAllFailures:
"""Test deactivate_all_user_sessions exception handling."""
@pytest.mark.asyncio
async def test_deactivate_all_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
async def test_deactivate_all_commit_failure_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test deactivate_all handles commit failures."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise OperationalError("Bulk deactivate failed", {}, Exception())
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError):
await session_crud.deactivate_all_user_sessions(
session,
user_id=str(async_test_user.id)
session, user_id=str(async_test_user.id)
)
mock_rollback.assert_called_once()
@@ -188,9 +215,11 @@ class TestSessionCRUDUpdateLastUsedFailures:
"""Test update_last_used exception handling."""
@pytest.mark.asyncio
async def test_update_last_used_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
async def test_update_last_used_commit_failure_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test update_last_used handles commit failures."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create a session
async with SessionLocal() as session:
@@ -201,8 +230,8 @@ class TestSessionCRUDUpdateLastUsedFailures:
ip_address="127.0.0.1",
user_agent="Test Agent",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) - timedelta(hours=1)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC) - timedelta(hours=1),
)
session.add(user_session)
await session.commit()
@@ -211,15 +240,19 @@ class TestSessionCRUDUpdateLastUsedFailures:
# Test update failure
async with SessionLocal() as session:
from sqlalchemy import select
from app.models.user_session import UserSession as US
result = await session.execute(select(US).where(US.id == user_session.id))
sess = result.scalar_one()
async def mock_commit():
raise OperationalError("Update failed", {}, Exception())
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError):
await session_crud.update_last_used(session, session=sess)
@@ -230,9 +263,11 @@ class TestSessionCRUDUpdateRefreshTokenFailures:
"""Test update_refresh_token exception handling."""
@pytest.mark.asyncio
async def test_update_refresh_token_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
async def test_update_refresh_token_commit_failure_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test update_refresh_token handles commit failures."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create a session
async with SessionLocal() as session:
@@ -243,8 +278,8 @@ class TestSessionCRUDUpdateRefreshTokenFailures:
ip_address="127.0.0.1",
user_agent="Test Agent",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(user_session)
await session.commit()
@@ -253,21 +288,25 @@ class TestSessionCRUDUpdateRefreshTokenFailures:
# Test update failure
async with SessionLocal() as session:
from sqlalchemy import select
from app.models.user_session import UserSession as US
result = await session.execute(select(US).where(US.id == user_session.id))
sess = result.scalar_one()
async def mock_commit():
raise OperationalError("Token update failed", {}, Exception())
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError):
await session_crud.update_refresh_token(
session,
session=sess,
new_jti=str(uuid4()),
new_expires_at=datetime.now(timezone.utc) + timedelta(days=14)
new_expires_at=datetime.now(UTC) + timedelta(days=14),
)
mock_rollback.assert_called_once()
@@ -277,16 +316,21 @@ class TestSessionCRUDCleanupExpiredFailures:
"""Test cleanup_expired exception handling."""
@pytest.mark.asyncio
async def test_cleanup_expired_commit_failure_triggers_rollback(self, async_test_db):
async def test_cleanup_expired_commit_failure_triggers_rollback(
self, async_test_db
):
"""Test cleanup_expired handles commit failures."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise OperationalError("Cleanup failed", {}, Exception())
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError):
await session_crud.cleanup_expired(session, keep_days=30)
@@ -297,20 +341,24 @@ class TestSessionCRUDCleanupExpiredForUserFailures:
"""Test cleanup_expired_for_user exception handling."""
@pytest.mark.asyncio
async def test_cleanup_expired_for_user_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
async def test_cleanup_expired_for_user_commit_failure_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test cleanup_expired_for_user handles commit failures."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise OperationalError("User cleanup failed", {}, Exception())
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError):
await session_crud.cleanup_expired_for_user(
session,
user_id=str(async_test_user.id)
session, user_id=str(async_test_user.id)
)
mock_rollback.assert_called_once()
@@ -320,17 +368,19 @@ class TestSessionCRUDGetUserSessionCountFailures:
"""Test get_user_session_count exception handling."""
@pytest.mark.asyncio
async def test_get_user_session_count_database_error(self, async_test_db, async_test_user):
async def test_get_user_session_count_database_error(
self, async_test_db, async_test_user
):
"""Test get_user_session_count handles database errors."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_execute(*args, **kwargs):
raise OperationalError("Count query failed", {}, Exception())
with patch.object(session, 'execute', side_effect=mock_execute):
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await session_crud.get_user_session_count(
session,
user_id=str(async_test_user.id)
session, user_id=str(async_test_user.id)
)

View File

@@ -2,12 +2,10 @@
"""
Comprehensive tests for async user CRUD operations.
"""
import pytest
from datetime import datetime, timezone
from uuid import uuid4
from app.crud.user import user as user_crud
from app.models.user import User
from app.schemas.users import UserCreate, UserUpdate
@@ -17,7 +15,7 @@ class TestGetByEmail:
@pytest.mark.asyncio
async def test_get_by_email_success(self, async_test_db, async_test_user):
"""Test getting user by email."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await user_crud.get_by_email(session, email=async_test_user.email)
@@ -28,10 +26,12 @@ class TestGetByEmail:
@pytest.mark.asyncio
async def test_get_by_email_not_found(self, async_test_db):
"""Test getting non-existent email returns None."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await user_crud.get_by_email(session, email="nonexistent@example.com")
result = await user_crud.get_by_email(
session, email="nonexistent@example.com"
)
assert result is None
@@ -41,7 +41,7 @@ class TestCreate:
@pytest.mark.asyncio
async def test_create_user_success(self, async_test_db):
"""Test successfully creating a user_crud."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_data = UserCreate(
@@ -49,7 +49,7 @@ class TestCreate:
password="SecurePass123!",
first_name="New",
last_name="User",
phone_number="+1234567890"
phone_number="+1234567890",
)
result = await user_crud.create(session, obj_in=user_data)
@@ -65,7 +65,7 @@ class TestCreate:
@pytest.mark.asyncio
async def test_create_superuser_success(self, async_test_db):
"""Test creating a superuser."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_data = UserCreate(
@@ -73,7 +73,7 @@ class TestCreate:
password="SuperPass123!",
first_name="Super",
last_name="User",
is_superuser=True
is_superuser=True,
)
result = await user_crud.create(session, obj_in=user_data)
@@ -83,14 +83,14 @@ class TestCreate:
@pytest.mark.asyncio
async def test_create_duplicate_email_fails(self, async_test_db, async_test_user):
"""Test creating user with duplicate email raises ValueError."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_data = UserCreate(
email=async_test_user.email, # Duplicate email
password="AnotherPass123!",
first_name="Duplicate",
last_name="User"
last_name="User",
)
with pytest.raises(ValueError) as exc_info:
@@ -105,16 +105,14 @@ class TestUpdate:
@pytest.mark.asyncio
async def test_update_user_basic_fields(self, async_test_db, async_test_user):
"""Test updating basic user fields."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Get fresh copy of user
user = await user_crud.get(session, id=str(async_test_user.id))
update_data = UserUpdate(
first_name="Updated",
last_name="Name",
phone_number="+9876543210"
first_name="Updated", last_name="Name", phone_number="+9876543210"
)
result = await user_crud.update(session, db_obj=user, obj_in=update_data)
@@ -125,7 +123,7 @@ class TestUpdate:
@pytest.mark.asyncio
async def test_update_user_password(self, async_test_db):
"""Test updating user password."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create a fresh user for this test
async with AsyncTestingSessionLocal() as session:
@@ -133,7 +131,7 @@ class TestUpdate:
email="passwordtest@example.com",
password="OldPassword123!",
first_name="Pass",
last_name="Test"
last_name="Test",
)
user = await user_crud.create(session, obj_in=user_data)
user_id = user.id
@@ -149,12 +147,14 @@ class TestUpdate:
await session.refresh(result)
assert result.password_hash != old_password_hash
assert result.password_hash is not None
assert "NewDifferentPassword123!" not in result.password_hash # Should be hashed
assert (
"NewDifferentPassword123!" not in result.password_hash
) # Should be hashed
@pytest.mark.asyncio
async def test_update_user_with_dict(self, async_test_db, async_test_user):
"""Test updating user with dictionary."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
@@ -171,13 +171,11 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_basic(self, async_test_db, async_test_user):
"""Test basic pagination."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
users, total = await user_crud.get_multi_with_total(
session,
skip=0,
limit=10
session, skip=0, limit=10
)
assert total >= 1
assert len(users) >= 1
@@ -186,7 +184,7 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_sorting_asc(self, async_test_db):
"""Test sorting in ascending order."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users
async with AsyncTestingSessionLocal() as session:
@@ -195,17 +193,13 @@ class TestGetMultiWithTotal:
email=f"sort{i}@example.com",
password="SecurePass123!",
first_name=f"User{i}",
last_name="Test"
last_name="Test",
)
await user_crud.create(session, obj_in=user_data)
async with AsyncTestingSessionLocal() as session:
users, total = await user_crud.get_multi_with_total(
session,
skip=0,
limit=10,
sort_by="email",
sort_order="asc"
users, _total = await user_crud.get_multi_with_total(
session, skip=0, limit=10, sort_by="email", sort_order="asc"
)
# Check if sorted (at least the test users)
@@ -216,7 +210,7 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_sorting_desc(self, async_test_db):
"""Test sorting in descending order."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users
async with AsyncTestingSessionLocal() as session:
@@ -225,17 +219,13 @@ class TestGetMultiWithTotal:
email=f"desc{i}@example.com",
password="SecurePass123!",
first_name=f"User{i}",
last_name="Test"
last_name="Test",
)
await user_crud.create(session, obj_in=user_data)
async with AsyncTestingSessionLocal() as session:
users, total = await user_crud.get_multi_with_total(
session,
skip=0,
limit=10,
sort_by="email",
sort_order="desc"
users, _total = await user_crud.get_multi_with_total(
session, skip=0, limit=10, sort_by="email", sort_order="desc"
)
# Check if sorted descending (at least the test users)
@@ -246,7 +236,7 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_filtering(self, async_test_db):
"""Test filtering by field."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create active and inactive users
async with AsyncTestingSessionLocal() as session:
@@ -254,7 +244,7 @@ class TestGetMultiWithTotal:
email="active@example.com",
password="SecurePass123!",
first_name="Active",
last_name="User"
last_name="User",
)
await user_crud.create(session, obj_in=active_user)
@@ -262,23 +252,18 @@ class TestGetMultiWithTotal:
email="inactive@example.com",
password="SecurePass123!",
first_name="Inactive",
last_name="User"
last_name="User",
)
created_inactive = await user_crud.create(session, obj_in=inactive_user)
# Deactivate the user
await user_crud.update(
session,
db_obj=created_inactive,
obj_in={"is_active": False}
session, db_obj=created_inactive, obj_in={"is_active": False}
)
async with AsyncTestingSessionLocal() as session:
users, total = await user_crud.get_multi_with_total(
session,
skip=0,
limit=100,
filters={"is_active": True}
users, _total = await user_crud.get_multi_with_total(
session, skip=0, limit=100, filters={"is_active": True}
)
# All returned users should be active
@@ -287,7 +272,7 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_search(self, async_test_db):
"""Test search functionality."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create user with unique name
async with AsyncTestingSessionLocal() as session:
@@ -295,16 +280,13 @@ class TestGetMultiWithTotal:
email="searchable@example.com",
password="SecurePass123!",
first_name="Searchable",
last_name="UserName"
last_name="UserName",
)
await user_crud.create(session, obj_in=user_data)
async with AsyncTestingSessionLocal() as session:
users, total = await user_crud.get_multi_with_total(
session,
skip=0,
limit=100,
search="Searchable"
session, skip=0, limit=100, search="Searchable"
)
assert total >= 1
@@ -313,7 +295,7 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_pagination(self, async_test_db):
"""Test pagination with skip and limit."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users
async with AsyncTestingSessionLocal() as session:
@@ -322,23 +304,19 @@ class TestGetMultiWithTotal:
email=f"page{i}@example.com",
password="SecurePass123!",
first_name=f"Page{i}",
last_name="User"
last_name="User",
)
await user_crud.create(session, obj_in=user_data)
async with AsyncTestingSessionLocal() as session:
# Get first page
users_page1, total = await user_crud.get_multi_with_total(
session,
skip=0,
limit=2
session, skip=0, limit=2
)
# Get second page
users_page2, total2 = await user_crud.get_multi_with_total(
session,
skip=2,
limit=2
session, skip=2, limit=2
)
# Total should be same
@@ -349,7 +327,7 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_validation_negative_skip(self, async_test_db):
"""Test validation fails for negative skip."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
with pytest.raises(ValueError) as exc_info:
@@ -360,7 +338,7 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_validation_negative_limit(self, async_test_db):
"""Test validation fails for negative limit."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
with pytest.raises(ValueError) as exc_info:
@@ -371,7 +349,7 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_validation_max_limit(self, async_test_db):
"""Test validation fails for limit > 1000."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
with pytest.raises(ValueError) as exc_info:
@@ -386,7 +364,7 @@ class TestBulkUpdateStatus:
@pytest.mark.asyncio
async def test_bulk_update_status_success(self, async_test_db):
"""Test bulk updating user status."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users
user_ids = []
@@ -396,7 +374,7 @@ class TestBulkUpdateStatus:
email=f"bulk{i}@example.com",
password="SecurePass123!",
first_name=f"Bulk{i}",
last_name="User"
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
user_ids.append(user.id)
@@ -404,9 +382,7 @@ class TestBulkUpdateStatus:
# Bulk deactivate
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_update_status(
session,
user_ids=user_ids,
is_active=False
session, user_ids=user_ids, is_active=False
)
assert count == 3
@@ -419,20 +395,18 @@ class TestBulkUpdateStatus:
@pytest.mark.asyncio
async def test_bulk_update_status_empty_list(self, async_test_db):
"""Test bulk update with empty list returns 0."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_update_status(
session,
user_ids=[],
is_active=False
session, user_ids=[], is_active=False
)
assert count == 0
@pytest.mark.asyncio
async def test_bulk_update_status_reactivate(self, async_test_db):
"""Test bulk reactivating users."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create inactive user
async with AsyncTestingSessionLocal() as session:
@@ -440,7 +414,7 @@ class TestBulkUpdateStatus:
email="reactivate@example.com",
password="SecurePass123!",
first_name="Reactivate",
last_name="User"
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
# Deactivate
@@ -450,9 +424,7 @@ class TestBulkUpdateStatus:
# Reactivate
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_update_status(
session,
user_ids=[user_id],
is_active=True
session, user_ids=[user_id], is_active=True
)
assert count == 1
@@ -468,7 +440,7 @@ class TestBulkSoftDelete:
@pytest.mark.asyncio
async def test_bulk_soft_delete_success(self, async_test_db):
"""Test bulk soft deleting users."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users
user_ids = []
@@ -478,17 +450,14 @@ class TestBulkSoftDelete:
email=f"delete{i}@example.com",
password="SecurePass123!",
first_name=f"Delete{i}",
last_name="User"
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
user_ids.append(user.id)
# Bulk delete
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete(
session,
user_ids=user_ids
)
count = await user_crud.bulk_soft_delete(session, user_ids=user_ids)
assert count == 3
# Verify all are soft deleted
@@ -501,7 +470,7 @@ class TestBulkSoftDelete:
@pytest.mark.asyncio
async def test_bulk_soft_delete_with_exclusion(self, async_test_db):
"""Test bulk soft delete with excluded user_crud."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users
user_ids = []
@@ -511,7 +480,7 @@ class TestBulkSoftDelete:
email=f"exclude{i}@example.com",
password="SecurePass123!",
first_name=f"Exclude{i}",
last_name="User"
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
user_ids.append(user.id)
@@ -520,9 +489,7 @@ class TestBulkSoftDelete:
exclude_id = user_ids[0]
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete(
session,
user_ids=user_ids,
exclude_user_id=exclude_id
session, user_ids=user_ids, exclude_user_id=exclude_id
)
assert count == 2 # Only 2 deleted
@@ -534,19 +501,16 @@ class TestBulkSoftDelete:
@pytest.mark.asyncio
async def test_bulk_soft_delete_empty_list(self, async_test_db):
"""Test bulk delete with empty list returns 0."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete(
session,
user_ids=[]
)
count = await user_crud.bulk_soft_delete(session, user_ids=[])
assert count == 0
@pytest.mark.asyncio
async def test_bulk_soft_delete_all_excluded(self, async_test_db):
"""Test bulk delete where all users are excluded."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create user
async with AsyncTestingSessionLocal() as session:
@@ -554,7 +518,7 @@ class TestBulkSoftDelete:
email="onlyuser@example.com",
password="SecurePass123!",
first_name="Only",
last_name="User"
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
user_id = user.id
@@ -562,16 +526,14 @@ class TestBulkSoftDelete:
# Try to delete but exclude
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete(
session,
user_ids=[user_id],
exclude_user_id=user_id
session, user_ids=[user_id], exclude_user_id=user_id
)
assert count == 0
@pytest.mark.asyncio
async def test_bulk_soft_delete_already_deleted(self, async_test_db):
"""Test bulk delete doesn't re-delete already deleted users."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create and delete user
async with AsyncTestingSessionLocal() as session:
@@ -579,7 +541,7 @@ class TestBulkSoftDelete:
email="predeleted@example.com",
password="SecurePass123!",
first_name="PreDeleted",
last_name="User"
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
user_id = user.id
@@ -589,10 +551,7 @@ class TestBulkSoftDelete:
# Try to delete again
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete(
session,
user_ids=[user_id]
)
count = await user_crud.bulk_soft_delete(session, user_ids=[user_id])
assert count == 0 # Already deleted
@@ -602,7 +561,7 @@ class TestUtilityMethods:
@pytest.mark.asyncio
async def test_is_active_true(self, async_test_db, async_test_user):
"""Test is_active returns True for active user_crud."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
@@ -611,14 +570,14 @@ class TestUtilityMethods:
@pytest.mark.asyncio
async def test_is_active_false(self, async_test_db):
"""Test is_active returns False for inactive user_crud."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_data = UserCreate(
email="inactive2@example.com",
password="SecurePass123!",
first_name="Inactive",
last_name="User"
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
await user_crud.update(session, db_obj=user, obj_in={"is_active": False})
@@ -628,7 +587,7 @@ class TestUtilityMethods:
@pytest.mark.asyncio
async def test_is_superuser_true(self, async_test_db, async_test_superuser):
"""Test is_superuser returns True for superuser."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_superuser.id))
@@ -637,7 +596,7 @@ class TestUtilityMethods:
@pytest.mark.asyncio
async def test_is_superuser_false(self, async_test_db, async_test_user):
"""Test is_superuser returns False for regular user_crud."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
@@ -654,42 +613,52 @@ class TestUserExceptionHandlers:
async def test_get_by_email_database_error(self, async_test_db):
"""Test get_by_email handles database errors (covers lines 30-32)."""
from unittest.mock import patch
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
with patch.object(session, 'execute', side_effect=Exception("Database query failed")):
with patch.object(
session, "execute", side_effect=Exception("Database query failed")
):
with pytest.raises(Exception, match="Database query failed"):
await user_crud.get_by_email(session, email="test@example.com")
@pytest.mark.asyncio
async def test_bulk_update_status_database_error(self, async_test_db, async_test_user):
async def test_bulk_update_status_database_error(
self, async_test_db, async_test_user
):
"""Test bulk_update_status handles database errors (covers lines 205-208)."""
from unittest.mock import patch, AsyncMock
test_engine, AsyncTestingSessionLocal = async_test_db
from unittest.mock import AsyncMock, patch
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock execute to fail
with patch.object(session, 'execute', side_effect=Exception("Bulk update failed")):
with patch.object(session, 'rollback', new_callable=AsyncMock):
with patch.object(
session, "execute", side_effect=Exception("Bulk update failed")
):
with patch.object(session, "rollback", new_callable=AsyncMock):
with pytest.raises(Exception, match="Bulk update failed"):
await user_crud.bulk_update_status(
session,
user_ids=[async_test_user.id],
is_active=False
session, user_ids=[async_test_user.id], is_active=False
)
@pytest.mark.asyncio
async def test_bulk_soft_delete_database_error(self, async_test_db, async_test_user):
async def test_bulk_soft_delete_database_error(
self, async_test_db, async_test_user
):
"""Test bulk_soft_delete handles database errors (covers lines 257-260)."""
from unittest.mock import patch, AsyncMock
test_engine, AsyncTestingSessionLocal = async_test_db
from unittest.mock import AsyncMock, patch
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock execute to fail
with patch.object(session, 'execute', side_effect=Exception("Bulk delete failed")):
with patch.object(session, 'rollback', new_callable=AsyncMock):
with patch.object(
session, "execute", side_effect=Exception("Bulk delete failed")
):
with patch.object(session, "rollback", new_callable=AsyncMock):
with pytest.raises(Exception, match="Bulk delete failed"):
await user_crud.bulk_soft_delete(
session,
user_ids=[async_test_user.id]
session, user_ids=[async_test_user.id]
)

View File

@@ -1,8 +1,10 @@
# tests/models/test_user.py
import uuid
import pytest
from datetime import datetime
import pytest
from sqlalchemy.exc import IntegrityError
from app.models.user import User
@@ -166,7 +168,6 @@ def test_user_required_fields(db_session):
db_session.rollback()
def test_user_defaults(db_session):
"""Test that default values are correctly set."""
# Arrange - Create a minimal user with only required fields
@@ -210,22 +211,13 @@ def test_user_with_complex_json_preferences(db_session):
"""Test storing and retrieving complex JSON preferences."""
# Arrange - Create a user with nested JSON preferences
complex_preferences = {
"theme": {
"mode": "dark",
"colors": {
"primary": "#333",
"secondary": "#666"
}
},
"theme": {"mode": "dark", "colors": {"primary": "#333", "secondary": "#666"}},
"notifications": {
"email": True,
"sms": False,
"push": {
"enabled": True,
"quiet_hours": [22, 7]
}
"push": {"enabled": True, "quiet_hours": [22, 7]},
},
"tags": ["important", "family", "events"]
"tags": ["important", "family", "events"],
}
user = User(
@@ -234,16 +226,18 @@ def test_user_with_complex_json_preferences(db_session):
password_hash="hashedpassword",
first_name="Complex",
last_name="JSON",
preferences=complex_preferences
preferences=complex_preferences,
)
db_session.add(user)
db_session.commit()
# Act - Retrieve the user
retrieved_user = db_session.query(User).filter_by(email="complex@example.com").first()
retrieved_user = (
db_session.query(User).filter_by(email="complex@example.com").first()
)
# Assert - The complex JSON should be preserved
assert retrieved_user.preferences == complex_preferences
assert retrieved_user.preferences["theme"]["colors"]["primary"] == "#333"
assert retrieved_user.preferences["notifications"]["push"]["quiet_hours"] == [22, 7]
assert "important" in retrieved_user.preferences["tags"]
assert "important" in retrieved_user.preferences["tags"]

View File

@@ -5,6 +5,7 @@ Covers Pydantic validators for:
- Slug validation (lines 26, 28, 30, 32, 62-70)
- Name validation (lines 40, 77)
"""
import pytest
from pydantic import ValidationError
@@ -20,19 +21,13 @@ class TestOrganizationBaseValidators:
def test_valid_organization_base(self):
"""Test that valid data passes validation."""
org = OrganizationBase(
name="Test Organization",
slug="test-org"
)
org = OrganizationBase(name="Test Organization", slug="test-org")
assert org.name == "Test Organization"
assert org.slug == "test-org"
def test_slug_none_returns_none(self):
"""Test that None slug is allowed (covers line 26)."""
org = OrganizationBase(
name="Test Organization",
slug=None
)
org = OrganizationBase(name="Test Organization", slug=None)
assert org.slug is None
def test_slug_invalid_characters_rejected(self):
@@ -40,57 +35,46 @@ class TestOrganizationBaseValidators:
with pytest.raises(ValidationError) as exc_info:
OrganizationBase(
name="Test Organization",
slug="Test_Org!" # Uppercase and special chars
slug="Test_Org!", # Uppercase and special chars
)
errors = exc_info.value.errors()
assert any("lowercase letters, numbers, and hyphens" in str(e['msg']) for e in errors)
assert any(
"lowercase letters, numbers, and hyphens" in str(e["msg"]) for e in errors
)
def test_slug_starts_with_hyphen_rejected(self):
"""Test slug starting with hyphen is rejected (covers line 30)."""
with pytest.raises(ValidationError) as exc_info:
OrganizationBase(
name="Test Organization",
slug="-test-org"
)
OrganizationBase(name="Test Organization", slug="-test-org")
errors = exc_info.value.errors()
assert any("cannot start or end with a hyphen" in str(e['msg']) for e in errors)
assert any("cannot start or end with a hyphen" in str(e["msg"]) for e in errors)
def test_slug_ends_with_hyphen_rejected(self):
"""Test slug ending with hyphen is rejected (covers line 30)."""
with pytest.raises(ValidationError) as exc_info:
OrganizationBase(
name="Test Organization",
slug="test-org-"
)
OrganizationBase(name="Test Organization", slug="test-org-")
errors = exc_info.value.errors()
assert any("cannot start or end with a hyphen" in str(e['msg']) for e in errors)
assert any("cannot start or end with a hyphen" in str(e["msg"]) for e in errors)
def test_slug_consecutive_hyphens_rejected(self):
"""Test slug with consecutive hyphens is rejected (covers line 32)."""
with pytest.raises(ValidationError) as exc_info:
OrganizationBase(
name="Test Organization",
slug="test--org"
)
OrganizationBase(name="Test Organization", slug="test--org")
errors = exc_info.value.errors()
assert any("cannot contain consecutive hyphens" in str(e['msg']) for e in errors)
assert any(
"cannot contain consecutive hyphens" in str(e["msg"]) for e in errors
)
def test_name_whitespace_only_rejected(self):
"""Test whitespace-only name is rejected (covers line 40)."""
with pytest.raises(ValidationError) as exc_info:
OrganizationBase(
name=" ",
slug="test-org"
)
OrganizationBase(name=" ", slug="test-org")
errors = exc_info.value.errors()
assert any("name cannot be empty" in str(e['msg']) for e in errors)
assert any("name cannot be empty" in str(e["msg"]) for e in errors)
def test_name_trimmed(self):
"""Test that name is trimmed."""
org = OrganizationBase(
name=" Test Organization ",
slug="test-org"
)
org = OrganizationBase(name=" Test Organization ", slug="test-org")
assert org.name == "Test Organization"
@@ -99,22 +83,18 @@ class TestOrganizationCreateValidators:
def test_valid_organization_create(self):
"""Test that valid data passes validation."""
org = OrganizationCreate(
name="Test Organization",
slug="test-org"
)
org = OrganizationCreate(name="Test Organization", slug="test-org")
assert org.name == "Test Organization"
assert org.slug == "test-org"
def test_slug_validation_inherited(self):
"""Test that slug validation is inherited from base."""
with pytest.raises(ValidationError) as exc_info:
OrganizationCreate(
name="Test",
slug="Invalid_Slug!"
)
OrganizationCreate(name="Test", slug="Invalid_Slug!")
errors = exc_info.value.errors()
assert any("lowercase letters, numbers, and hyphens" in str(e['msg']) for e in errors)
assert any(
"lowercase letters, numbers, and hyphens" in str(e["msg"]) for e in errors
)
class TestOrganizationUpdateValidators:
@@ -122,10 +102,7 @@ class TestOrganizationUpdateValidators:
def test_valid_organization_update(self):
"""Test that valid update data passes validation."""
org = OrganizationUpdate(
name="Updated Name",
slug="updated-slug"
)
org = OrganizationUpdate(name="Updated Name", slug="updated-slug")
assert org.name == "Updated Name"
assert org.slug == "updated-slug"
@@ -139,35 +116,39 @@ class TestOrganizationUpdateValidators:
with pytest.raises(ValidationError) as exc_info:
OrganizationUpdate(slug="Test_Org!")
errors = exc_info.value.errors()
assert any("lowercase letters, numbers, and hyphens" in str(e['msg']) for e in errors)
assert any(
"lowercase letters, numbers, and hyphens" in str(e["msg"]) for e in errors
)
def test_update_slug_starts_with_hyphen_rejected(self):
"""Test update slug starting with hyphen is rejected (covers line 66)."""
with pytest.raises(ValidationError) as exc_info:
OrganizationUpdate(slug="-test-org")
errors = exc_info.value.errors()
assert any("cannot start or end with a hyphen" in str(e['msg']) for e in errors)
assert any("cannot start or end with a hyphen" in str(e["msg"]) for e in errors)
def test_update_slug_ends_with_hyphen_rejected(self):
"""Test update slug ending with hyphen is rejected (covers line 66)."""
with pytest.raises(ValidationError) as exc_info:
OrganizationUpdate(slug="test-org-")
errors = exc_info.value.errors()
assert any("cannot start or end with a hyphen" in str(e['msg']) for e in errors)
assert any("cannot start or end with a hyphen" in str(e["msg"]) for e in errors)
def test_update_slug_consecutive_hyphens_rejected(self):
"""Test update slug with consecutive hyphens is rejected (covers line 68)."""
with pytest.raises(ValidationError) as exc_info:
OrganizationUpdate(slug="test--org")
errors = exc_info.value.errors()
assert any("cannot contain consecutive hyphens" in str(e['msg']) for e in errors)
assert any(
"cannot contain consecutive hyphens" in str(e["msg"]) for e in errors
)
def test_update_name_whitespace_only_rejected(self):
"""Test whitespace-only name in update is rejected (covers line 77)."""
with pytest.raises(ValidationError) as exc_info:
OrganizationUpdate(name=" ")
errors = exc_info.value.errors()
assert any("name cannot be empty" in str(e['msg']) for e in errors)
assert any("name cannot be empty" in str(e["msg"]) for e in errors)
def test_update_name_none_allowed(self):
"""Test that None name is allowed in update."""

View File

@@ -1,80 +1,177 @@
# tests/schemas/test_user_schemas.py
import pytest
import re
import pytest
from pydantic import ValidationError
from app.schemas.users import UserBase, UserCreate
class TestPhoneNumberValidation:
"""Tests for phone number validation in user schemas"""
def test_valid_swiss_numbers(self):
"""Test valid Swiss phone numbers are accepted"""
# International format
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41791234567")
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="+41791234567",
)
assert user.phone_number == "+41791234567"
# Local format
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0791234567")
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="0791234567",
)
assert user.phone_number == "0791234567"
# With formatting characters
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41 79 123 45 67")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+41791234567"
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="+41 79 123 45 67",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+41791234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="079 123 45 67")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "0791234567"
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="079 123 45 67",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "0791234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41-79-123-45-67")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+41791234567"
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="+41-79-123-45-67",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+41791234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="079-123-45-67")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "0791234567"
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="079-123-45-67",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "0791234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41 (79) 123 45 67")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+41791234567"
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="+41 (79) 123 45 67",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+41791234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="079 (123) 45 67")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "0791234567"
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="079 (123) 45 67",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "0791234567"
def test_valid_italian_numbers(self):
"""Test valid Italian phone numbers are accepted"""
# International format
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+393451234567")
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="+393451234567",
)
assert user.phone_number == "+393451234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39345123456")
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="+39345123456",
)
assert user.phone_number == "+39345123456"
# Local format
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="03451234567")
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="03451234567",
)
assert user.phone_number == "03451234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345123456789")
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="0345123456789",
)
assert user.phone_number == "0345123456789"
# With formatting characters
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39 345 123 4567")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+393451234567"
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="+39 345 123 4567",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+393451234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345 123 4567")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "03451234567"
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="0345 123 4567",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "03451234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39-345-123-4567")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+393451234567"
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="+39-345-123-4567",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+393451234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345-123-4567")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "03451234567"
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="0345-123-4567",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "03451234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39 (345) 123 4567")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+393451234567"
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="+39 (345) 123 4567",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+393451234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345 (123) 4567")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "03451234567"
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="0345 (123) 4567",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "03451234567"
def test_none_phone_number(self):
"""Test that None is accepted as a valid value (optional phone number)"""
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number=None)
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number=None,
)
assert user.phone_number is None
def test_invalid_phone_numbers(self):
@@ -83,17 +180,14 @@ class TestPhoneNumberValidation:
# Too short
"+12",
"012",
# Invalid characters
"+41xyz123456",
"079abc4567",
"123-abc-7890",
"+1(800)CALL-NOW",
# Completely invalid formats
"++4412345678", # Double plus
# Note: "()+41123456" becomes "+41123456" after cleaning, which is valid
# Empty string
"",
# Spaces only
@@ -102,7 +196,12 @@ class TestPhoneNumberValidation:
for number in invalid_numbers:
with pytest.raises(ValidationError):
UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number=number)
UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number=number,
)
def test_phone_validation_in_user_create(self):
"""Test that phone validation also works in UserCreate schema"""
@@ -112,7 +211,7 @@ class TestPhoneNumberValidation:
first_name="Test",
last_name="User",
password="Password123!",
phone_number="+41791234567"
phone_number="+41791234567",
)
assert user.phone_number == "+41791234567"
@@ -123,5 +222,5 @@ class TestPhoneNumberValidation:
first_name="Test",
last_name="User",
password="Password123!",
phone_number="invalid-number"
)
phone_number="invalid-number",
)

View File

@@ -7,12 +7,13 @@ Covers all edge cases in validation functions:
- validate_email_format (line 148)
- validate_slug (lines 170-183)
"""
import pytest
from app.schemas.validators import (
validate_email_format,
validate_password_strength,
validate_phone_number,
validate_email_format,
validate_slug,
)
@@ -108,12 +109,14 @@ class TestPhoneNumberValidator:
validate_phone_number("+123456789012345") # 15 digits after +
def test_multiple_plus_symbols_rejected(self):
"""Test phone number with multiple + symbols.
r"""Test phone number with multiple + symbols.
Note: Line 115 is defensive code - the regex check at line 110 catches this first.
The regex ^(?:\+[0-9]{8,14}|0[0-9]{8,14})$ only allows + at the start.
"""
with pytest.raises(ValueError, match="must start with \\+ or 0 followed by 8-14 digits"):
with pytest.raises(
ValueError, match="must start with \\+ or 0 followed by 8-14 digits"
):
validate_phone_number("+1234+5678901")
def test_non_digit_after_prefix_rejected(self):

View File

@@ -1,14 +1,18 @@
# tests/services/test_auth_service.py
import uuid
import pytest
import pytest_asyncio
from unittest.mock import patch
import pytest
from sqlalchemy import select
from app.core.auth import get_password_hash, verify_password, TokenExpiredError, TokenInvalidError
from app.core.auth import (
TokenInvalidError,
get_password_hash,
verify_password,
)
from app.models.user import User
from app.schemas.users import UserCreate, Token
from app.services.auth_service import AuthService, AuthenticationError
from app.schemas.users import Token, UserCreate
from app.services.auth_service import AuthenticationError, AuthService
class TestAuthServiceAuthentication:
@@ -17,12 +21,14 @@ class TestAuthServiceAuthentication:
@pytest.mark.asyncio
async def test_authenticate_valid_user(self, async_test_db, async_test_user):
"""Test authenticating a user with valid credentials"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Set a known password for the mock user
password = "TestPassword123!"
async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id))
result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user = result.scalar_one_or_none()
user.password_hash = get_password_hash(password)
await session.commit()
@@ -30,9 +36,7 @@ class TestAuthServiceAuthentication:
# Authenticate with correct credentials
async with AsyncTestingSessionLocal() as session:
auth_user = await AuthService.authenticate_user(
db=session,
email=async_test_user.email,
password=password
db=session, email=async_test_user.email, password=password
)
assert auth_user is not None
@@ -42,26 +46,28 @@ class TestAuthServiceAuthentication:
@pytest.mark.asyncio
async def test_authenticate_nonexistent_user(self, async_test_db):
"""Test authenticating with an email that doesn't exist"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user = await AuthService.authenticate_user(
db=session,
email="nonexistent@example.com",
password="password"
db=session, email="nonexistent@example.com", password="password"
)
assert user is None
@pytest.mark.asyncio
async def test_authenticate_with_wrong_password(self, async_test_db, async_test_user):
async def test_authenticate_with_wrong_password(
self, async_test_db, async_test_user
):
"""Test authenticating with the wrong password"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Set a known password for the mock user
password = "TestPassword123!"
async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id))
result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user = result.scalar_one_or_none()
user.password_hash = get_password_hash(password)
await session.commit()
@@ -69,9 +75,7 @@ class TestAuthServiceAuthentication:
# Authenticate with wrong password
async with AsyncTestingSessionLocal() as session:
auth_user = await AuthService.authenticate_user(
db=session,
email=async_test_user.email,
password="WrongPassword123"
db=session, email=async_test_user.email, password="WrongPassword123"
)
assert auth_user is None
@@ -79,12 +83,14 @@ class TestAuthServiceAuthentication:
@pytest.mark.asyncio
async def test_authenticate_inactive_user(self, async_test_db, async_test_user):
"""Test authenticating an inactive user"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Set a known password and make user inactive
password = "TestPassword123!"
async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id))
result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user = result.scalar_one_or_none()
user.password_hash = get_password_hash(password)
user.is_active = False
@@ -94,9 +100,7 @@ class TestAuthServiceAuthentication:
async with AsyncTestingSessionLocal() as session:
with pytest.raises(AuthenticationError):
await AuthService.authenticate_user(
db=session,
email=async_test_user.email,
password=password
db=session, email=async_test_user.email, password=password
)
@@ -106,14 +110,14 @@ class TestAuthServiceUserCreation:
@pytest.mark.asyncio
async def test_create_new_user(self, async_test_db):
"""Test creating a new user"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
user_data = UserCreate(
email="newuser@example.com",
password="TestPassword123!",
first_name="New",
last_name="User",
phone_number="+1234567890"
phone_number="+1234567890",
)
async with AsyncTestingSessionLocal() as session:
@@ -135,15 +139,17 @@ class TestAuthServiceUserCreation:
assert user.is_superuser is False
@pytest.mark.asyncio
async def test_create_user_with_existing_email(self, async_test_db, async_test_user):
async def test_create_user_with_existing_email(
self, async_test_db, async_test_user
):
"""Test creating a user with an email that already exists"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
user_data = UserCreate(
email=async_test_user.email, # Use existing email
password="TestPassword123!",
first_name="Duplicate",
last_name="User"
last_name="User",
)
# Should raise AuthenticationError
@@ -169,7 +175,7 @@ class TestAuthServiceTokens:
@pytest.mark.asyncio
async def test_refresh_tokens(self, async_test_db, async_test_user):
"""Test refreshing tokens with a valid refresh token"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create initial tokens
initial_tokens = AuthService.create_tokens(async_test_user)
@@ -177,8 +183,7 @@ class TestAuthServiceTokens:
# Refresh tokens
async with AsyncTestingSessionLocal() as session:
new_tokens = await AuthService.refresh_tokens(
db=session,
refresh_token=initial_tokens.refresh_token
db=session, refresh_token=initial_tokens.refresh_token
)
# Verify new tokens are different from old ones
@@ -188,7 +193,7 @@ class TestAuthServiceTokens:
@pytest.mark.asyncio
async def test_refresh_tokens_with_invalid_token(self, async_test_db):
"""Test refreshing tokens with an invalid token"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create an invalid token
invalid_token = "invalid.token.string"
@@ -197,14 +202,15 @@ class TestAuthServiceTokens:
async with AsyncTestingSessionLocal() as session:
with pytest.raises(TokenInvalidError):
await AuthService.refresh_tokens(
db=session,
refresh_token=invalid_token
db=session, refresh_token=invalid_token
)
@pytest.mark.asyncio
async def test_refresh_tokens_with_access_token(self, async_test_db, async_test_user):
async def test_refresh_tokens_with_access_token(
self, async_test_db, async_test_user
):
"""Test refreshing tokens with an access token instead of refresh token"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create tokens
tokens = AuthService.create_tokens(async_test_user)
@@ -213,18 +219,20 @@ class TestAuthServiceTokens:
async with AsyncTestingSessionLocal() as session:
with pytest.raises(TokenInvalidError):
await AuthService.refresh_tokens(
db=session,
refresh_token=tokens.access_token
db=session, refresh_token=tokens.access_token
)
@pytest.mark.asyncio
async def test_refresh_tokens_with_nonexistent_user(self, async_test_db):
"""Test refreshing tokens for a user that doesn't exist in the database"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create a token for a non-existent user
non_existent_id = str(uuid.uuid4())
with patch('app.core.auth.decode_token'), patch('app.core.auth.get_token_data') as mock_get_data:
with (
patch("app.core.auth.decode_token"),
patch("app.core.auth.get_token_data") as mock_get_data,
):
# Mock the token data to return a non-existent user ID
mock_get_data.return_value.user_id = uuid.UUID(non_existent_id)
@@ -232,8 +240,7 @@ class TestAuthServiceTokens:
async with AsyncTestingSessionLocal() as session:
with pytest.raises(TokenInvalidError):
await AuthService.refresh_tokens(
db=session,
refresh_token="some.refresh.token"
db=session, refresh_token="some.refresh.token"
)
@@ -243,12 +250,14 @@ class TestAuthServicePasswordChange:
@pytest.mark.asyncio
async def test_change_password(self, async_test_db, async_test_user):
"""Test changing a user's password"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Set a known password for the mock user
current_password = "CurrentPassword123"
async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id))
result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user = result.scalar_one_or_none()
user.password_hash = get_password_hash(current_password)
await session.commit()
@@ -260,7 +269,7 @@ class TestAuthServicePasswordChange:
db=session,
user_id=async_test_user.id,
current_password=current_password,
new_password=new_password
new_password=new_password,
)
# Verify operation was successful
@@ -268,7 +277,9 @@ class TestAuthServicePasswordChange:
# Verify password was changed
async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id))
result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
updated_user = result.scalar_one_or_none()
# Verify old password no longer works
@@ -278,14 +289,18 @@ class TestAuthServicePasswordChange:
assert verify_password(new_password, updated_user.password_hash)
@pytest.mark.asyncio
async def test_change_password_wrong_current_password(self, async_test_db, async_test_user):
async def test_change_password_wrong_current_password(
self, async_test_db, async_test_user
):
"""Test changing password with incorrect current password"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Set a known password for the mock user
current_password = "CurrentPassword123"
async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id))
result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user = result.scalar_one_or_none()
user.password_hash = get_password_hash(current_password)
await session.commit()
@@ -298,19 +313,21 @@ class TestAuthServicePasswordChange:
db=session,
user_id=async_test_user.id,
current_password=wrong_password,
new_password="NewPassword456"
new_password="NewPassword456",
)
# Verify password was not changed
async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id))
result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user = result.scalar_one_or_none()
assert verify_password(current_password, user.password_hash)
@pytest.mark.asyncio
async def test_change_password_nonexistent_user(self, async_test_db):
"""Test changing password for a user that doesn't exist"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
non_existent_id = uuid.uuid4()
@@ -320,5 +337,5 @@ class TestAuthServicePasswordChange:
db=session,
user_id=non_existent_id,
current_password="CurrentPassword123",
new_password="NewPassword456"
new_password="NewPassword456",
)

View File

@@ -2,13 +2,15 @@
"""
Tests for email service functionality.
"""
from unittest.mock import AsyncMock
import pytest
from unittest.mock import patch, AsyncMock, MagicMock
from app.services.email_service import (
EmailService,
ConsoleEmailBackend,
SMTPEmailBackend
EmailService,
SMTPEmailBackend,
)
@@ -24,7 +26,7 @@ class TestConsoleEmailBackend:
to=["user@example.com"],
subject="Test Subject",
html_content="<p>Test HTML</p>",
text_content="Test Text"
text_content="Test Text",
)
assert result is True
@@ -37,7 +39,7 @@ class TestConsoleEmailBackend:
result = await backend.send_email(
to=["user@example.com"],
subject="Test Subject",
html_content="<p>Test HTML</p>"
html_content="<p>Test HTML</p>",
)
assert result is True
@@ -50,7 +52,7 @@ class TestConsoleEmailBackend:
result = await backend.send_email(
to=["user1@example.com", "user2@example.com"],
subject="Test Subject",
html_content="<p>Test HTML</p>"
html_content="<p>Test HTML</p>",
)
assert result is True
@@ -66,7 +68,7 @@ class TestSMTPEmailBackend:
host="smtp.example.com",
port=587,
username="test@example.com",
password="password"
password="password",
)
assert backend.host == "smtp.example.com"
@@ -81,14 +83,14 @@ class TestSMTPEmailBackend:
host="smtp.example.com",
port=587,
username="test@example.com",
password="password"
password="password",
)
# Should fall back to console backend since SMTP is not implemented
result = await backend.send_email(
to=["user@example.com"],
subject="Test Subject",
html_content="<p>Test HTML</p>"
html_content="<p>Test HTML</p>",
)
assert result is True
@@ -114,9 +116,7 @@ class TestEmailService:
service = EmailService()
result = await service.send_password_reset_email(
to_email="user@example.com",
reset_token="test_token_123",
user_name="John"
to_email="user@example.com", reset_token="test_token_123", user_name="John"
)
assert result is True
@@ -127,8 +127,7 @@ class TestEmailService:
service = EmailService()
result = await service.send_password_reset_email(
to_email="user@example.com",
reset_token="test_token_123"
to_email="user@example.com", reset_token="test_token_123"
)
assert result is True
@@ -142,8 +141,7 @@ class TestEmailService:
token = "test_reset_token_xyz"
await service.send_password_reset_email(
to_email="user@example.com",
reset_token=token
to_email="user@example.com", reset_token=token
)
# Verify send_email was called
@@ -151,7 +149,7 @@ class TestEmailService:
call_args = backend_mock.send_email.call_args
# Check that token is in the HTML content
html_content = call_args.kwargs['html_content']
html_content = call_args.kwargs["html_content"]
assert token in html_content
@pytest.mark.asyncio
@@ -162,8 +160,7 @@ class TestEmailService:
service = EmailService(backend=backend_mock)
result = await service.send_password_reset_email(
to_email="user@example.com",
reset_token="test_token"
to_email="user@example.com", reset_token="test_token"
)
assert result is False
@@ -176,7 +173,7 @@ class TestEmailService:
result = await service.send_email_verification(
to_email="user@example.com",
verification_token="verification_token_123",
user_name="Jane"
user_name="Jane",
)
assert result is True
@@ -187,8 +184,7 @@ class TestEmailService:
service = EmailService()
result = await service.send_email_verification(
to_email="user@example.com",
verification_token="verification_token_123"
to_email="user@example.com", verification_token="verification_token_123"
)
assert result is True
@@ -202,8 +198,7 @@ class TestEmailService:
token = "test_verification_token_xyz"
await service.send_email_verification(
to_email="user@example.com",
verification_token=token
to_email="user@example.com", verification_token=token
)
# Verify send_email was called
@@ -211,7 +206,7 @@ class TestEmailService:
call_args = backend_mock.send_email.call_args
# Check that token is in the HTML content
html_content = call_args.kwargs['html_content']
html_content = call_args.kwargs["html_content"]
assert token in html_content
@pytest.mark.asyncio
@@ -222,8 +217,7 @@ class TestEmailService:
service = EmailService(backend=backend_mock)
result = await service.send_email_verification(
to_email="user@example.com",
verification_token="test_token"
to_email="user@example.com", verification_token="test_token"
)
assert result is False
@@ -236,14 +230,12 @@ class TestEmailService:
service = EmailService(backend=backend_mock)
await service.send_password_reset_email(
to_email="user@example.com",
reset_token="token123",
user_name="Test User"
to_email="user@example.com", reset_token="token123", user_name="Test User"
)
call_args = backend_mock.send_email.call_args
html_content = call_args.kwargs['html_content']
text_content = call_args.kwargs['text_content']
html_content = call_args.kwargs["html_content"]
text_content = call_args.kwargs["text_content"]
# Check HTML content
assert "Password Reset" in html_content
@@ -251,7 +243,9 @@ class TestEmailService:
assert "Test User" in html_content
# Check text content
assert "Password Reset" in text_content or "password reset" in text_content.lower()
assert (
"Password Reset" in text_content or "password reset" in text_content.lower()
)
assert "token123" in text_content
@pytest.mark.asyncio
@@ -264,12 +258,12 @@ class TestEmailService:
await service.send_email_verification(
to_email="user@example.com",
verification_token="verify123",
user_name="Test User"
user_name="Test User",
)
call_args = backend_mock.send_email.call_args
html_content = call_args.kwargs['html_content']
text_content = call_args.kwargs['text_content']
html_content = call_args.kwargs["html_content"]
text_content = call_args.kwargs["text_content"]
# Check HTML content
assert "Verify" in html_content

View File

@@ -2,23 +2,27 @@
"""
Comprehensive tests for session cleanup service.
"""
import pytest
import asyncio
from datetime import datetime, timedelta, timezone
from unittest.mock import patch, MagicMock, AsyncMock
from contextlib import asynccontextmanager
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, patch
import pytest
from sqlalchemy import select
from app.models.user_session import UserSession
from sqlalchemy import select
class TestCleanupExpiredSessions:
"""Tests for cleanup_expired_sessions function."""
@pytest.mark.asyncio
async def test_cleanup_expired_sessions_success(self, async_test_db, async_test_user):
async def test_cleanup_expired_sessions_success(
self, async_test_db, async_test_user
):
"""Test successful cleanup of expired sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create mix of sessions
async with AsyncTestingSessionLocal() as session:
@@ -30,9 +34,9 @@ class TestCleanupExpiredSessions:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
created_at=datetime.now(timezone.utc) - timedelta(days=1),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
created_at=datetime.now(UTC) - timedelta(days=1),
last_used_at=datetime.now(UTC),
)
# 2. Inactive, expired, old (SHOULD be deleted)
@@ -43,9 +47,9 @@ class TestCleanupExpiredSessions:
ip_address="192.168.1.2",
user_agent="Mozilla/5.0",
is_active=False,
expires_at=datetime.now(timezone.utc) - timedelta(days=10),
created_at=datetime.now(timezone.utc) - timedelta(days=40),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) - timedelta(days=10),
created_at=datetime.now(UTC) - timedelta(days=40),
last_used_at=datetime.now(UTC),
)
# 3. Inactive, expired, recent (should NOT be deleted - within keep_days)
@@ -56,17 +60,23 @@ class TestCleanupExpiredSessions:
ip_address="192.168.1.3",
user_agent="Mozilla/5.0",
is_active=False,
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
created_at=datetime.now(timezone.utc) - timedelta(days=5),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) - timedelta(days=1),
created_at=datetime.now(UTC) - timedelta(days=5),
last_used_at=datetime.now(UTC),
)
session.add_all([active_session, old_expired_session, recent_expired_session])
session.add_all(
[active_session, old_expired_session, recent_expired_session]
)
await session.commit()
# Mock SessionLocal to return our test session
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
with patch(
"app.services.session_cleanup.SessionLocal",
return_value=AsyncTestingSessionLocal(),
):
from app.services.session_cleanup import cleanup_expired_sessions
deleted_count = await cleanup_expired_sessions(keep_days=30)
# Should only delete old_expired_session
@@ -85,7 +95,7 @@ class TestCleanupExpiredSessions:
@pytest.mark.asyncio
async def test_cleanup_no_sessions_to_delete(self, async_test_db, async_test_user):
"""Test cleanup when no sessions meet deletion criteria."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
active = UserSession(
@@ -95,15 +105,19 @@ class TestCleanupExpiredSessions:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
created_at=datetime.now(timezone.utc),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
created_at=datetime.now(UTC),
last_used_at=datetime.now(UTC),
)
session.add(active)
await session.commit()
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
with patch(
"app.services.session_cleanup.SessionLocal",
return_value=AsyncTestingSessionLocal(),
):
from app.services.session_cleanup import cleanup_expired_sessions
deleted_count = await cleanup_expired_sessions(keep_days=30)
assert deleted_count == 0
@@ -111,10 +125,14 @@ class TestCleanupExpiredSessions:
@pytest.mark.asyncio
async def test_cleanup_empty_database(self, async_test_db):
"""Test cleanup with no sessions in database."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
with patch(
"app.services.session_cleanup.SessionLocal",
return_value=AsyncTestingSessionLocal(),
):
from app.services.session_cleanup import cleanup_expired_sessions
deleted_count = await cleanup_expired_sessions(keep_days=30)
assert deleted_count == 0
@@ -122,7 +140,7 @@ class TestCleanupExpiredSessions:
@pytest.mark.asyncio
async def test_cleanup_with_keep_days_0(self, async_test_db, async_test_user):
"""Test cleanup with keep_days=0 deletes all inactive expired sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
today_expired = UserSession(
@@ -132,15 +150,19 @@ class TestCleanupExpiredSessions:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=False,
expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
created_at=datetime.now(timezone.utc) - timedelta(hours=2),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) - timedelta(hours=1),
created_at=datetime.now(UTC) - timedelta(hours=2),
last_used_at=datetime.now(UTC),
)
session.add(today_expired)
await session.commit()
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
with patch(
"app.services.session_cleanup.SessionLocal",
return_value=AsyncTestingSessionLocal(),
):
from app.services.session_cleanup import cleanup_expired_sessions
deleted_count = await cleanup_expired_sessions(keep_days=0)
assert deleted_count == 1
@@ -148,7 +170,7 @@ class TestCleanupExpiredSessions:
@pytest.mark.asyncio
async def test_cleanup_bulk_delete_efficiency(self, async_test_db, async_test_user):
"""Test that cleanup uses bulk DELETE for many sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create 50 expired sessions
async with AsyncTestingSessionLocal() as session:
@@ -161,16 +183,20 @@ class TestCleanupExpiredSessions:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=False,
expires_at=datetime.now(timezone.utc) - timedelta(days=10),
created_at=datetime.now(timezone.utc) - timedelta(days=40),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) - timedelta(days=10),
created_at=datetime.now(UTC) - timedelta(days=40),
last_used_at=datetime.now(UTC),
)
sessions_to_add.append(expired)
session.add_all(sessions_to_add)
await session.commit()
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
with patch(
"app.services.session_cleanup.SessionLocal",
return_value=AsyncTestingSessionLocal(),
):
from app.services.session_cleanup import cleanup_expired_sessions
deleted_count = await cleanup_expired_sessions(keep_days=30)
assert deleted_count == 50
@@ -178,14 +204,20 @@ class TestCleanupExpiredSessions:
@pytest.mark.asyncio
async def test_cleanup_database_error_returns_zero(self, async_test_db):
"""Test cleanup returns 0 on database errors (doesn't crash)."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Mock session_crud.cleanup_expired to raise error
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
with patch('app.services.session_cleanup.session_crud.cleanup_expired') as mock_cleanup:
with patch(
"app.services.session_cleanup.SessionLocal",
return_value=AsyncTestingSessionLocal(),
):
with patch(
"app.services.session_cleanup.session_crud.cleanup_expired"
) as mock_cleanup:
mock_cleanup.side_effect = Exception("Database connection lost")
from app.services.session_cleanup import cleanup_expired_sessions
# Should not crash, should return 0
deleted_count = await cleanup_expired_sessions(keep_days=30)
@@ -198,7 +230,7 @@ class TestGetSessionStatistics:
@pytest.mark.asyncio
async def test_get_statistics_with_sessions(self, async_test_db, async_test_user):
"""Test getting session statistics with various session types."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# 2 active, not expired
@@ -210,9 +242,9 @@ class TestGetSessionStatistics:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
created_at=datetime.now(timezone.utc),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
created_at=datetime.now(UTC),
last_used_at=datetime.now(UTC),
)
session.add(active)
@@ -225,9 +257,9 @@ class TestGetSessionStatistics:
ip_address="192.168.1.2",
user_agent="Mozilla/5.0",
is_active=False,
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
created_at=datetime.now(timezone.utc) - timedelta(days=2),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) - timedelta(days=1),
created_at=datetime.now(UTC) - timedelta(days=2),
last_used_at=datetime.now(UTC),
)
session.add(inactive)
@@ -239,16 +271,20 @@ class TestGetSessionStatistics:
ip_address="192.168.1.3",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
created_at=datetime.now(timezone.utc) - timedelta(days=1),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) - timedelta(hours=1),
created_at=datetime.now(UTC) - timedelta(days=1),
last_used_at=datetime.now(UTC),
)
session.add(expired_active)
await session.commit()
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
with patch(
"app.services.session_cleanup.SessionLocal",
return_value=AsyncTestingSessionLocal(),
):
from app.services.session_cleanup import get_session_statistics
stats = await get_session_statistics()
assert stats["total"] == 6
@@ -259,10 +295,14 @@ class TestGetSessionStatistics:
@pytest.mark.asyncio
async def test_get_statistics_empty_database(self, async_test_db):
"""Test getting statistics with no sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
with patch(
"app.services.session_cleanup.SessionLocal",
return_value=AsyncTestingSessionLocal(),
):
from app.services.session_cleanup import get_session_statistics
stats = await get_session_statistics()
assert stats["total"] == 0
@@ -271,9 +311,11 @@ class TestGetSessionStatistics:
assert stats["expired"] == 0
@pytest.mark.asyncio
async def test_get_statistics_database_error_returns_empty_dict(self, async_test_db):
async def test_get_statistics_database_error_returns_empty_dict(
self, async_test_db
):
"""Test statistics returns empty dict on database errors."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, _AsyncTestingSessionLocal = async_test_db
# Create a mock that raises on execute
mock_session = AsyncMock()
@@ -283,8 +325,12 @@ class TestGetSessionStatistics:
async def mock_session_local():
yield mock_session
with patch('app.services.session_cleanup.SessionLocal', return_value=mock_session_local()):
with patch(
"app.services.session_cleanup.SessionLocal",
return_value=mock_session_local(),
):
from app.services.session_cleanup import get_session_statistics
stats = await get_session_statistics()
assert stats == {}
@@ -294,9 +340,11 @@ class TestConcurrentCleanup:
"""Tests for concurrent cleanup scenarios."""
@pytest.mark.asyncio
async def test_concurrent_cleanup_no_duplicate_deletes(self, async_test_db, async_test_user):
async def test_concurrent_cleanup_no_duplicate_deletes(
self, async_test_db, async_test_user
):
"""Test concurrent cleanups don't cause race conditions."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create 10 expired sessions
async with AsyncTestingSessionLocal() as session:
@@ -308,20 +356,24 @@ class TestConcurrentCleanup:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=False,
expires_at=datetime.now(timezone.utc) - timedelta(days=10),
created_at=datetime.now(timezone.utc) - timedelta(days=40),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) - timedelta(days=10),
created_at=datetime.now(UTC) - timedelta(days=40),
last_used_at=datetime.now(UTC),
)
session.add(expired)
await session.commit()
# Run two cleanups concurrently
# Use side_effect to return fresh session instances for each call
with patch('app.services.session_cleanup.SessionLocal', side_effect=lambda: AsyncTestingSessionLocal()):
with patch(
"app.services.session_cleanup.SessionLocal",
side_effect=lambda: AsyncTestingSessionLocal(),
):
from app.services.session_cleanup import cleanup_expired_sessions
results = await asyncio.gather(
cleanup_expired_sessions(keep_days=30),
cleanup_expired_sessions(keep_days=30)
cleanup_expired_sessions(keep_days=30),
)
# Both should report deleting sessions (may overlap due to transaction timing)

View File

@@ -2,12 +2,13 @@
"""
Tests for database initialization script.
"""
import pytest
import pytest_asyncio
from unittest.mock import AsyncMock, patch
from app.init_db import init_db
from unittest.mock import patch
import pytest
from app.core.config import settings
from app.init_db import init_db
class TestInitDb:
@@ -16,69 +17,86 @@ class TestInitDb:
@pytest.mark.asyncio
async def test_init_db_creates_superuser_when_not_exists(self, async_test_db):
"""Test that init_db creates a superuser when one doesn't exist."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Mock the SessionLocal to use our test database
with patch('app.init_db.SessionLocal', SessionLocal):
with patch("app.init_db.SessionLocal", SessionLocal):
# Mock settings to provide test credentials
with patch.object(settings, 'FIRST_SUPERUSER_EMAIL', 'test_admin@example.com'):
with patch.object(settings, 'FIRST_SUPERUSER_PASSWORD', 'TestAdmin123!'):
with patch.object(
settings, "FIRST_SUPERUSER_EMAIL", "test_admin@example.com"
):
with patch.object(
settings, "FIRST_SUPERUSER_PASSWORD", "TestAdmin123!"
):
# Run init_db
user = await init_db()
# Verify superuser was created
assert user is not None
assert user.email == 'test_admin@example.com'
assert user.email == "test_admin@example.com"
assert user.is_superuser is True
assert user.first_name == 'Admin'
assert user.last_name == 'User'
assert user.first_name == "Admin"
assert user.last_name == "User"
@pytest.mark.asyncio
async def test_init_db_returns_existing_superuser(self, async_test_db, async_test_user):
async def test_init_db_returns_existing_superuser(
self, async_test_db, async_test_user
):
"""Test that init_db returns existing superuser instead of creating duplicate."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Mock the SessionLocal to use our test database
with patch('app.init_db.SessionLocal', SessionLocal):
with patch("app.init_db.SessionLocal", SessionLocal):
# Mock settings to match async_test_user's email
with patch.object(settings, 'FIRST_SUPERUSER_EMAIL', 'testuser@example.com'):
with patch.object(settings, 'FIRST_SUPERUSER_PASSWORD', 'TestPassword123!'):
with patch.object(
settings, "FIRST_SUPERUSER_EMAIL", "testuser@example.com"
):
with patch.object(
settings, "FIRST_SUPERUSER_PASSWORD", "TestPassword123!"
):
# Run init_db
user = await init_db()
# Verify it returns the existing user
assert user is not None
assert user.id == async_test_user.id
assert user.email == 'testuser@example.com'
assert user.email == "testuser@example.com"
@pytest.mark.asyncio
async def test_init_db_uses_default_credentials(self, async_test_db):
"""Test that init_db uses default credentials when env vars not set."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Mock the SessionLocal to use our test database
with patch('app.init_db.SessionLocal', SessionLocal):
with patch("app.init_db.SessionLocal", SessionLocal):
# Mock settings to have None values (not configured)
with patch.object(settings, 'FIRST_SUPERUSER_EMAIL', None):
with patch.object(settings, 'FIRST_SUPERUSER_PASSWORD', None):
with patch.object(settings, "FIRST_SUPERUSER_EMAIL", None):
with patch.object(settings, "FIRST_SUPERUSER_PASSWORD", None):
# Run init_db
user = await init_db()
# Verify superuser was created with defaults
assert user is not None
assert user.email == 'admin@example.com'
assert user.email == "admin@example.com"
assert user.is_superuser is True
@pytest.mark.asyncio
async def test_init_db_handles_database_errors(self, async_test_db):
"""Test that init_db handles database errors gracefully."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Mock user_crud.get_by_email to raise an exception
with patch('app.init_db.user_crud.get_by_email', side_effect=Exception("Database error")):
with patch('app.init_db.SessionLocal', SessionLocal):
with patch.object(settings, 'FIRST_SUPERUSER_EMAIL', 'test@example.com'):
with patch.object(settings, 'FIRST_SUPERUSER_PASSWORD', 'TestPassword123!'):
with patch(
"app.init_db.user_crud.get_by_email",
side_effect=Exception("Database error"),
):
with patch("app.init_db.SessionLocal", SessionLocal):
with patch.object(
settings, "FIRST_SUPERUSER_EMAIL", "test@example.com"
):
with patch.object(
settings, "FIRST_SUPERUSER_PASSWORD", "TestPassword123!"
):
# Run init_db and expect it to raise
with pytest.raises(Exception, match="Database error"):
await init_db()

View File

@@ -2,18 +2,18 @@
"""
Comprehensive tests for device utility functions.
"""
import pytest
from unittest.mock import Mock
from fastapi import Request
from app.utils.device import (
extract_device_info,
parse_device_name,
extract_browser,
extract_device_info,
get_client_ip,
get_device_type,
is_mobile_device,
get_device_type
parse_device_name,
)
@@ -138,7 +138,9 @@ class TestExtractBrowser:
def test_extract_browser_edge_legacy(self):
"""Test extracting legacy Edge browser."""
ua = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 Edge/18.19582"
ua = (
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 Edge/18.19582"
)
result = extract_browser(ua)
assert result == "Edge"
@@ -249,7 +251,7 @@ class TestGetClientIp:
request = Mock(spec=Request)
request.headers = {
"x-forwarded-for": "192.168.1.100",
"x-real-ip": "192.168.1.200"
"x-real-ip": "192.168.1.200",
}
request.client = Mock()
request.client.host = "192.168.1.50"
@@ -385,7 +387,7 @@ class TestExtractDeviceInfo:
request.headers = {
"user-agent": "Mozilla/5.0 (iPhone; CPU iPhone OS 15_0 like Mac OS X)",
"x-device-id": "device-123-456",
"x-forwarded-for": "192.168.1.100"
"x-forwarded-for": "192.168.1.100",
}
request.client = None

View File

@@ -2,19 +2,21 @@
"""
Tests for security utility functions.
"""
import time
import base64
import json
import time
from unittest.mock import MagicMock, patch
import pytest
from unittest.mock import patch, MagicMock
from app.utils.security import (
create_upload_token,
verify_upload_token,
create_password_reset_token,
verify_password_reset_token,
create_email_verification_token,
verify_email_verification_token
create_password_reset_token,
create_upload_token,
verify_email_verification_token,
verify_password_reset_token,
verify_upload_token,
)
@@ -31,7 +33,7 @@ class TestCreateUploadToken:
# Token should be base64 encoded
try:
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
decoded = base64.urlsafe_b64decode(token.encode("utf-8"))
token_data = json.loads(decoded)
assert "payload" in token_data
assert "signature" in token_data
@@ -46,7 +48,7 @@ class TestCreateUploadToken:
token = create_upload_token(file_path, content_type)
# Decode and verify payload
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
decoded = base64.urlsafe_b64decode(token.encode("utf-8"))
token_data = json.loads(decoded)
payload = token_data["payload"]
@@ -62,7 +64,7 @@ class TestCreateUploadToken:
after = int(time.time())
# Decode token
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
decoded = base64.urlsafe_b64decode(token.encode("utf-8"))
token_data = json.loads(decoded)
payload = token_data["payload"]
@@ -74,11 +76,13 @@ class TestCreateUploadToken:
"""Test token creation with custom expiration time."""
custom_exp = 600 # 10 minutes
before = int(time.time())
token = create_upload_token("/uploads/test.jpg", "image/jpeg", expires_in=custom_exp)
token = create_upload_token(
"/uploads/test.jpg", "image/jpeg", expires_in=custom_exp
)
after = int(time.time())
# Decode token
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
decoded = base64.urlsafe_b64decode(token.encode("utf-8"))
token_data = json.loads(decoded)
payload = token_data["payload"]
@@ -92,11 +96,11 @@ class TestCreateUploadToken:
token2 = create_upload_token("/uploads/test.jpg", "image/jpeg")
# Decode both tokens
decoded1 = base64.urlsafe_b64decode(token1.encode('utf-8'))
decoded1 = base64.urlsafe_b64decode(token1.encode("utf-8"))
token_data1 = json.loads(decoded1)
nonce1 = token_data1["payload"]["nonce"]
decoded2 = base64.urlsafe_b64decode(token2.encode('utf-8'))
decoded2 = base64.urlsafe_b64decode(token2.encode("utf-8"))
token_data2 = json.loads(decoded2)
nonce2 = token_data2["payload"]["nonce"]
@@ -133,7 +137,7 @@ class TestVerifyUploadToken:
current_time = 1000000
mock_time.time = MagicMock(return_value=current_time)
with patch('app.utils.security.time', mock_time):
with patch("app.utils.security.time", mock_time):
# Create token that "expires" at current_time + 1
token = create_upload_token("/uploads/test.jpg", "image/jpeg", expires_in=1)
@@ -149,13 +153,15 @@ class TestVerifyUploadToken:
token = create_upload_token("/uploads/test.jpg", "image/jpeg")
# Decode, modify, and re-encode
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
decoded = base64.urlsafe_b64decode(token.encode("utf-8"))
token_data = json.loads(decoded)
token_data["signature"] = "invalid_signature"
# Re-encode the tampered token
tampered_json = json.dumps(token_data)
tampered_token = base64.urlsafe_b64encode(tampered_json.encode('utf-8')).decode('utf-8')
tampered_token = base64.urlsafe_b64encode(tampered_json.encode("utf-8")).decode(
"utf-8"
)
payload = verify_upload_token(tampered_token)
assert payload is None
@@ -165,13 +171,15 @@ class TestVerifyUploadToken:
token = create_upload_token("/uploads/test.jpg", "image/jpeg")
# Decode, modify payload, and re-encode
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
decoded = base64.urlsafe_b64decode(token.encode("utf-8"))
token_data = json.loads(decoded)
token_data["payload"]["path"] = "/uploads/hacked.exe"
# Re-encode the tampered token (signature won't match)
tampered_json = json.dumps(token_data)
tampered_token = base64.urlsafe_b64encode(tampered_json.encode('utf-8')).decode('utf-8')
tampered_token = base64.urlsafe_b64encode(tampered_json.encode("utf-8")).decode(
"utf-8"
)
payload = verify_upload_token(tampered_token)
assert payload is None
@@ -194,7 +202,9 @@ class TestVerifyUploadToken:
"""Test that tokens with invalid JSON are rejected."""
# Create a base64 string that decodes to invalid JSON
invalid_json = "not valid json"
invalid_token = base64.urlsafe_b64encode(invalid_json.encode('utf-8')).decode('utf-8')
invalid_token = base64.urlsafe_b64encode(invalid_json.encode("utf-8")).decode(
"utf-8"
)
payload = verify_upload_token(invalid_token)
assert payload is None
@@ -207,11 +217,13 @@ class TestVerifyUploadToken:
"path": "/uploads/test.jpg"
# Missing content_type, exp, nonce
},
"signature": "some_signature"
"signature": "some_signature",
}
incomplete_json = json.dumps(incomplete_data)
incomplete_token = base64.urlsafe_b64encode(incomplete_json.encode('utf-8')).decode('utf-8')
incomplete_token = base64.urlsafe_b64encode(
incomplete_json.encode("utf-8")
).decode("utf-8")
payload = verify_upload_token(incomplete_token)
assert payload is None
@@ -266,7 +278,7 @@ class TestPasswordResetTokens:
email = "user@example.com"
# Create token that expires in 1 second
with patch('app.utils.security.time') as mock_time:
with patch("app.utils.security.time") as mock_time:
mock_time.time = MagicMock(return_value=1000000)
token = create_password_reset_token(email, expires_in=1)
@@ -287,12 +299,14 @@ class TestPasswordResetTokens:
token = create_password_reset_token(email)
# Decode and tamper
decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
decoded = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
token_data = json.loads(decoded)
token_data["payload"]["email"] = "hacker@example.com"
# Re-encode
tampered = base64.urlsafe_b64encode(json.dumps(token_data).encode('utf-8')).decode('utf-8')
tampered = base64.urlsafe_b64encode(
json.dumps(token_data).encode("utf-8")
).decode("utf-8")
verified_email = verify_password_reset_token(tampered)
assert verified_email is None
@@ -312,14 +326,14 @@ class TestPasswordResetTokens:
email = "user@example.com"
custom_exp = 7200 # 2 hours
with patch('app.utils.security.time') as mock_time:
with patch("app.utils.security.time") as mock_time:
current_time = 1000000
mock_time.time = MagicMock(return_value=current_time)
token = create_password_reset_token(email, expires_in=custom_exp)
# Decode to check expiration
decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
decoded = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
token_data = json.loads(decoded)
assert token_data["payload"]["exp"] == current_time + custom_exp
@@ -350,7 +364,7 @@ class TestEmailVerificationTokens:
"""Test that expired verification tokens are rejected."""
email = "user@example.com"
with patch('app.utils.security.time') as mock_time:
with patch("app.utils.security.time") as mock_time:
mock_time.time = MagicMock(return_value=1000000)
token = create_email_verification_token(email, expires_in=1)
@@ -371,12 +385,14 @@ class TestEmailVerificationTokens:
token = create_email_verification_token(email)
# Decode and tamper
decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
decoded = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
token_data = json.loads(decoded)
token_data["payload"]["email"] = "hacker@example.com"
# Re-encode
tampered = base64.urlsafe_b64encode(json.dumps(token_data).encode('utf-8')).decode('utf-8')
tampered = base64.urlsafe_b64encode(
json.dumps(token_data).encode("utf-8")
).decode("utf-8")
verified_email = verify_email_verification_token(tampered)
assert verified_email is None
@@ -395,14 +411,14 @@ class TestEmailVerificationTokens:
"""Test email verification token with default 24-hour expiration."""
email = "user@example.com"
with patch('app.utils.security.time') as mock_time:
with patch("app.utils.security.time") as mock_time:
current_time = 1000000
mock_time.time = MagicMock(return_value=current_time)
token = create_email_verification_token(email)
# Decode to check expiration (should be 86400 seconds = 24 hours)
decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
decoded = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
token_data = json.loads(decoded)
assert token_data["payload"]["exp"] == current_time + 86400