forked from cardosofelipe/fast-next-template
Add pyproject.toml for consolidated project configuration and replace Black, isort, and Flake8 with Ruff
- Introduced `pyproject.toml` to centralize backend tool configurations (e.g., Ruff, mypy, coverage, pytest). - Replaced Black, isort, and Flake8 with Ruff for linting, formatting, and import sorting. - Updated `requirements.txt` to include Ruff and remove replaced tools. - Added `Makefile` to streamline development workflows with commands for linting, formatting, type-checking, testing, and cleanup.
This commit is contained in:
@@ -1,15 +1,16 @@
|
||||
# tests/api/dependencies/test_auth_dependencies.py
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.api.dependencies.auth import (
|
||||
get_current_user,
|
||||
get_current_active_user,
|
||||
get_current_superuser,
|
||||
get_optional_current_user
|
||||
get_current_user,
|
||||
get_optional_current_user,
|
||||
)
|
||||
from app.core.auth import TokenExpiredError, TokenInvalidError, get_password_hash
|
||||
from app.models.user import User
|
||||
@@ -24,7 +25,7 @@ def mock_token():
|
||||
@pytest_asyncio.fixture
|
||||
async def async_mock_user(async_test_db):
|
||||
"""Async fixture to create and return a mock User instance."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
mock_user = User(
|
||||
id=uuid.uuid4(),
|
||||
@@ -47,12 +48,14 @@ class TestGetCurrentUser:
|
||||
"""Tests for get_current_user dependency"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_success(self, async_test_db, async_mock_user, mock_token):
|
||||
async def test_get_current_user_success(
|
||||
self, async_test_db, async_mock_user, mock_token
|
||||
):
|
||||
"""Test successfully getting the current user"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Mock get_token_data to return user_id that matches our mock_user
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.return_value.user_id = async_mock_user.id
|
||||
|
||||
# Call the dependency
|
||||
@@ -65,12 +68,12 @@ class TestGetCurrentUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_nonexistent(self, async_test_db, mock_token):
|
||||
"""Test when the token contains a user ID that doesn't exist"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Mock get_token_data to return a non-existent user ID
|
||||
nonexistent_id = uuid.UUID("11111111-1111-1111-1111-111111111111")
|
||||
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.return_value.user_id = nonexistent_id
|
||||
|
||||
# Should raise HTTPException with 404 status
|
||||
@@ -81,19 +84,24 @@ class TestGetCurrentUser:
|
||||
assert "User not found" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_inactive(self, async_test_db, async_mock_user, mock_token):
|
||||
async def test_get_current_user_inactive(
|
||||
self, async_test_db, async_mock_user, mock_token
|
||||
):
|
||||
"""Test when the user is inactive"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Get the user in this session and make it inactive
|
||||
from sqlalchemy import select
|
||||
result = await session.execute(select(User).where(User.id == async_mock_user.id))
|
||||
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_mock_user.id)
|
||||
)
|
||||
user_in_session = result.scalar_one_or_none()
|
||||
user_in_session.is_active = False
|
||||
await session.commit()
|
||||
|
||||
# Mock get_token_data
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.return_value.user_id = async_mock_user.id
|
||||
|
||||
# Should raise HTTPException with 403 status
|
||||
@@ -106,10 +114,10 @@ class TestGetCurrentUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_expired_token(self, async_test_db, mock_token):
|
||||
"""Test with an expired token"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Mock get_token_data to raise TokenExpiredError
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.side_effect = TokenExpiredError("Token expired")
|
||||
|
||||
# Should raise HTTPException with 401 status
|
||||
@@ -122,10 +130,10 @@ class TestGetCurrentUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_invalid_token(self, async_test_db, mock_token):
|
||||
"""Test with an invalid token"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Mock get_token_data to raise TokenInvalidError
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.side_effect = TokenInvalidError("Invalid token")
|
||||
|
||||
# Should raise HTTPException with 401 status
|
||||
@@ -194,12 +202,14 @@ class TestGetOptionalCurrentUser:
|
||||
"""Tests for get_optional_current_user dependency"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_optional_current_user_with_token(self, async_test_db, async_mock_user, mock_token):
|
||||
async def test_get_optional_current_user_with_token(
|
||||
self, async_test_db, async_mock_user, mock_token
|
||||
):
|
||||
"""Test getting optional user with a valid token"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Mock get_token_data
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.return_value.user_id = async_mock_user.id
|
||||
|
||||
# Call the dependency
|
||||
@@ -212,7 +222,7 @@ class TestGetOptionalCurrentUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_optional_current_user_no_token(self, async_test_db):
|
||||
"""Test getting optional user with no token"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Call the dependency with no token
|
||||
user = await get_optional_current_user(db=session, token=None)
|
||||
@@ -221,12 +231,14 @@ class TestGetOptionalCurrentUser:
|
||||
assert user is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_optional_current_user_invalid_token(self, async_test_db, mock_token):
|
||||
async def test_get_optional_current_user_invalid_token(
|
||||
self, async_test_db, mock_token
|
||||
):
|
||||
"""Test getting optional user with an invalid token"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Mock get_token_data to raise TokenInvalidError
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.side_effect = TokenInvalidError("Invalid token")
|
||||
|
||||
# Call the dependency
|
||||
@@ -236,12 +248,14 @@ class TestGetOptionalCurrentUser:
|
||||
assert user is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_optional_current_user_expired_token(self, async_test_db, mock_token):
|
||||
async def test_get_optional_current_user_expired_token(
|
||||
self, async_test_db, mock_token
|
||||
):
|
||||
"""Test getting optional user with an expired token"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Mock get_token_data to raise TokenExpiredError
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.side_effect = TokenExpiredError("Token expired")
|
||||
|
||||
# Call the dependency
|
||||
@@ -251,19 +265,24 @@ class TestGetOptionalCurrentUser:
|
||||
assert user is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_optional_current_user_inactive(self, async_test_db, async_mock_user, mock_token):
|
||||
async def test_get_optional_current_user_inactive(
|
||||
self, async_test_db, async_mock_user, mock_token
|
||||
):
|
||||
"""Test getting optional user when user is inactive"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Get the user in this session and make it inactive
|
||||
from sqlalchemy import select
|
||||
result = await session.execute(select(User).where(User.id == async_mock_user.id))
|
||||
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_mock_user.id)
|
||||
)
|
||||
user_in_session = result.scalar_one_or_none()
|
||||
user_in_session.is_active = False
|
||||
await session.commit()
|
||||
|
||||
# Mock get_token_data
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.return_value.user_id = async_mock_user.id
|
||||
|
||||
# Call the dependency
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
# tests/api/routes/test_health.py
|
||||
from datetime import datetime
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from fastapi import status
|
||||
from fastapi.testclient import TestClient
|
||||
from datetime import datetime
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
from app.main import app
|
||||
from app.core.database import get_db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -121,7 +120,10 @@ class TestHealthEndpoint:
|
||||
response = client.get("/health")
|
||||
|
||||
# Should succeed without authentication
|
||||
assert response.status_code in [status.HTTP_200_OK, status.HTTP_503_SERVICE_UNAVAILABLE]
|
||||
assert response.status_code in [
|
||||
status.HTTP_200_OK,
|
||||
status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
]
|
||||
|
||||
def test_health_check_idempotent(self, client):
|
||||
"""Test that multiple health checks return consistent results"""
|
||||
@@ -142,7 +144,10 @@ class TestHealthEndpoint:
|
||||
assert data1["environment"] == data2["environment"]
|
||||
|
||||
# Same database check status
|
||||
assert data1["checks"]["database"]["status"] == data2["checks"]["database"]["status"]
|
||||
assert (
|
||||
data1["checks"]["database"]["status"]
|
||||
== data2["checks"]["database"]["status"]
|
||||
)
|
||||
|
||||
def test_health_check_content_type(self, client):
|
||||
"""Test that health check returns JSON content type"""
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -2,6 +2,7 @@
|
||||
"""
|
||||
Tests for authentication endpoints.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from fastapi import status
|
||||
@@ -19,8 +20,8 @@ class TestRegisterEndpoint:
|
||||
"email": "newuser@example.com",
|
||||
"password": "NewPassword123!",
|
||||
"first_name": "New",
|
||||
"last_name": "User"
|
||||
}
|
||||
"last_name": "User",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
@@ -36,8 +37,8 @@ class TestRegisterEndpoint:
|
||||
"email": async_test_user.email,
|
||||
"password": "TestPassword123!",
|
||||
"first_name": "Test",
|
||||
"last_name": "User"
|
||||
}
|
||||
"last_name": "User",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
@@ -51,8 +52,8 @@ class TestRegisterEndpoint:
|
||||
"email": "test@example.com",
|
||||
"password": "weak",
|
||||
"first_name": "Test",
|
||||
"last_name": "User"
|
||||
}
|
||||
"last_name": "User",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
@@ -66,10 +67,7 @@ class TestLoginEndpoint:
|
||||
"""Test successful login."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -82,10 +80,7 @@ class TestLoginEndpoint:
|
||||
"""Test login with invalid password."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "WrongPassword123!"
|
||||
}
|
||||
json={"email": "testuser@example.com", "password": "WrongPassword123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
@@ -95,10 +90,7 @@ class TestLoginEndpoint:
|
||||
"""Test login with non-existent user."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "nonexistent@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": "nonexistent@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
@@ -106,27 +98,25 @@ class TestLoginEndpoint:
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_inactive_user(self, client, async_test_db):
|
||||
"""Test login with inactive user."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
from app.models.user import User
|
||||
from app.core.auth import get_password_hash
|
||||
from app.models.user import User
|
||||
|
||||
inactive_user = User(
|
||||
email="inactive@example.com",
|
||||
password_hash=get_password_hash("TestPassword123!"),
|
||||
first_name="Inactive",
|
||||
last_name="User",
|
||||
is_active=False
|
||||
is_active=False,
|
||||
)
|
||||
session.add(inactive_user)
|
||||
await session.commit()
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "inactive@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": "inactive@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
@@ -140,10 +130,7 @@ class TestRefreshTokenEndpoint:
|
||||
"""Get a refresh token for testing."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
return response.json()["refresh_token"]
|
||||
|
||||
@@ -151,8 +138,7 @@ class TestRefreshTokenEndpoint:
|
||||
async def test_refresh_token_success(self, client, refresh_token):
|
||||
"""Test successful token refresh."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": refresh_token}
|
||||
"/api/v1/auth/refresh", json={"refresh_token": refresh_token}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -164,8 +150,7 @@ class TestRefreshTokenEndpoint:
|
||||
async def test_refresh_token_invalid(self, client):
|
||||
"""Test refresh with invalid token."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": "invalid.token.here"}
|
||||
"/api/v1/auth/refresh", json={"refresh_token": "invalid.token.here"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
@@ -179,13 +164,13 @@ class TestLogoutEndpoint:
|
||||
"""Get tokens for testing."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
data = response.json()
|
||||
return {"access_token": data["access_token"], "refresh_token": data["refresh_token"]}
|
||||
return {
|
||||
"access_token": data["access_token"],
|
||||
"refresh_token": data["refresh_token"],
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logout_success(self, client, tokens):
|
||||
@@ -193,7 +178,7 @@ class TestLogoutEndpoint:
|
||||
response = await client.post(
|
||||
"/api/v1/auth/logout",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
json={"refresh_token": tokens["refresh_token"]}
|
||||
json={"refresh_token": tokens["refresh_token"]},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -202,8 +187,7 @@ class TestLogoutEndpoint:
|
||||
async def test_logout_without_auth(self, client):
|
||||
"""Test logout without authentication."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/logout",
|
||||
json={"refresh_token": "some.token"}
|
||||
"/api/v1/auth/logout", json={"refresh_token": "some.token"}
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
@@ -215,8 +199,7 @@ class TestPasswordResetRequest:
|
||||
async def test_password_reset_request_success(self, client, async_test_user):
|
||||
"""Test password reset request with existing user."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/request",
|
||||
json={"email": async_test_user.email}
|
||||
"/api/v1/auth/password-reset/request", json={"email": async_test_user.email}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -228,7 +211,7 @@ class TestPasswordResetRequest:
|
||||
"""Test password reset request with non-existent email."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/request",
|
||||
json={"email": "nonexistent@example.com"}
|
||||
json={"email": "nonexistent@example.com"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -244,10 +227,7 @@ class TestPasswordResetConfirm:
|
||||
"""Test password reset with invalid token."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={
|
||||
"token": "invalid.token.here",
|
||||
"new_password": "NewPassword123!"
|
||||
}
|
||||
json={"token": "invalid.token.here", "new_password": "NewPassword123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
@@ -261,20 +241,20 @@ class TestLogoutAll:
|
||||
"""Get tokens for testing."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
data = response.json()
|
||||
return {"access_token": data["access_token"], "refresh_token": data["refresh_token"]}
|
||||
return {
|
||||
"access_token": data["access_token"],
|
||||
"refresh_token": data["refresh_token"],
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logout_all_success(self, client, tokens):
|
||||
"""Test logout from all devices."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/logout-all",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"}
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -298,10 +278,7 @@ class TestOAuthLogin:
|
||||
"""Test successful OAuth login."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login/oauth",
|
||||
data={
|
||||
"username": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
data={"username": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -315,10 +292,7 @@ class TestOAuthLogin:
|
||||
"""Test OAuth login with invalid credentials."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login/oauth",
|
||||
data={
|
||||
"username": "testuser@example.com",
|
||||
"password": "WrongPassword"
|
||||
}
|
||||
data={"username": "testuser@example.com", "password": "WrongPassword"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
# tests/api/dependencies/test_auth_dependencies.py
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.api.dependencies.auth import (
|
||||
get_current_user,
|
||||
get_current_active_user,
|
||||
get_current_superuser,
|
||||
get_optional_current_user
|
||||
get_current_user,
|
||||
get_optional_current_user,
|
||||
)
|
||||
from app.core.auth import TokenExpiredError, TokenInvalidError, get_password_hash
|
||||
from app.models.user import User
|
||||
@@ -24,7 +25,7 @@ def mock_token():
|
||||
@pytest_asyncio.fixture
|
||||
async def async_mock_user(async_test_db):
|
||||
"""Async fixture to create and return a mock User instance."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
mock_user = User(
|
||||
id=uuid.uuid4(),
|
||||
@@ -47,12 +48,14 @@ class TestGetCurrentUser:
|
||||
"""Tests for get_current_user dependency"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_success(self, async_test_db, async_mock_user, mock_token):
|
||||
async def test_get_current_user_success(
|
||||
self, async_test_db, async_mock_user, mock_token
|
||||
):
|
||||
"""Test successfully getting the current user"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Mock get_token_data to return user_id that matches our mock_user
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.return_value.user_id = async_mock_user.id
|
||||
|
||||
# Call the dependency
|
||||
@@ -65,12 +68,12 @@ class TestGetCurrentUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_nonexistent(self, async_test_db, mock_token):
|
||||
"""Test when the token contains a user ID that doesn't exist"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Mock get_token_data to return a non-existent user ID
|
||||
nonexistent_id = uuid.UUID("11111111-1111-1111-1111-111111111111")
|
||||
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.return_value.user_id = nonexistent_id
|
||||
|
||||
# Should raise HTTPException with 404 status
|
||||
@@ -81,19 +84,24 @@ class TestGetCurrentUser:
|
||||
assert "User not found" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_inactive(self, async_test_db, async_mock_user, mock_token):
|
||||
async def test_get_current_user_inactive(
|
||||
self, async_test_db, async_mock_user, mock_token
|
||||
):
|
||||
"""Test when the user is inactive"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Get the user in this session and make it inactive
|
||||
from sqlalchemy import select
|
||||
result = await session.execute(select(User).where(User.id == async_mock_user.id))
|
||||
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_mock_user.id)
|
||||
)
|
||||
user_in_session = result.scalar_one_or_none()
|
||||
user_in_session.is_active = False
|
||||
await session.commit()
|
||||
|
||||
# Mock get_token_data
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.return_value.user_id = async_mock_user.id
|
||||
|
||||
# Should raise HTTPException with 403 status
|
||||
@@ -106,10 +114,10 @@ class TestGetCurrentUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_expired_token(self, async_test_db, mock_token):
|
||||
"""Test with an expired token"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Mock get_token_data to raise TokenExpiredError
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.side_effect = TokenExpiredError("Token expired")
|
||||
|
||||
# Should raise HTTPException with 401 status
|
||||
@@ -122,10 +130,10 @@ class TestGetCurrentUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_invalid_token(self, async_test_db, mock_token):
|
||||
"""Test with an invalid token"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Mock get_token_data to raise TokenInvalidError
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.side_effect = TokenInvalidError("Invalid token")
|
||||
|
||||
# Should raise HTTPException with 401 status
|
||||
@@ -194,12 +202,14 @@ class TestGetOptionalCurrentUser:
|
||||
"""Tests for get_optional_current_user dependency"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_optional_current_user_with_token(self, async_test_db, async_mock_user, mock_token):
|
||||
async def test_get_optional_current_user_with_token(
|
||||
self, async_test_db, async_mock_user, mock_token
|
||||
):
|
||||
"""Test getting optional user with a valid token"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Mock get_token_data
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.return_value.user_id = async_mock_user.id
|
||||
|
||||
# Call the dependency
|
||||
@@ -212,7 +222,7 @@ class TestGetOptionalCurrentUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_optional_current_user_no_token(self, async_test_db):
|
||||
"""Test getting optional user with no token"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Call the dependency with no token
|
||||
user = await get_optional_current_user(db=session, token=None)
|
||||
@@ -221,12 +231,14 @@ class TestGetOptionalCurrentUser:
|
||||
assert user is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_optional_current_user_invalid_token(self, async_test_db, mock_token):
|
||||
async def test_get_optional_current_user_invalid_token(
|
||||
self, async_test_db, mock_token
|
||||
):
|
||||
"""Test getting optional user with an invalid token"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Mock get_token_data to raise TokenInvalidError
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.side_effect = TokenInvalidError("Invalid token")
|
||||
|
||||
# Call the dependency
|
||||
@@ -236,12 +248,14 @@ class TestGetOptionalCurrentUser:
|
||||
assert user is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_optional_current_user_expired_token(self, async_test_db, mock_token):
|
||||
async def test_get_optional_current_user_expired_token(
|
||||
self, async_test_db, mock_token
|
||||
):
|
||||
"""Test getting optional user with an expired token"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Mock get_token_data to raise TokenExpiredError
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.side_effect = TokenExpiredError("Token expired")
|
||||
|
||||
# Call the dependency
|
||||
@@ -251,19 +265,24 @@ class TestGetOptionalCurrentUser:
|
||||
assert user is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_optional_current_user_inactive(self, async_test_db, async_mock_user, mock_token):
|
||||
async def test_get_optional_current_user_inactive(
|
||||
self, async_test_db, async_mock_user, mock_token
|
||||
):
|
||||
"""Test getting optional user when user is inactive"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Get the user in this session and make it inactive
|
||||
from sqlalchemy import select
|
||||
result = await session.execute(select(User).where(User.id == async_mock_user.id))
|
||||
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_mock_user.id)
|
||||
)
|
||||
user_in_session = result.scalar_one_or_none()
|
||||
user_in_session.is_active = False
|
||||
await session.commit()
|
||||
|
||||
# Mock get_token_data
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.return_value.user_id = async_mock_user.id
|
||||
|
||||
# Call the dependency
|
||||
|
||||
@@ -2,21 +2,21 @@
|
||||
"""
|
||||
Tests for authentication endpoints.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from unittest.mock import patch, MagicMock
|
||||
from fastapi import status
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserCreate
|
||||
|
||||
|
||||
# Disable rate limiting for tests
|
||||
@pytest.fixture(autouse=True)
|
||||
def disable_rate_limit():
|
||||
"""Disable rate limiting for all tests in this module."""
|
||||
with patch('app.api.routes.auth.limiter.enabled', False):
|
||||
with patch("app.api.routes.auth.limiter.enabled", False):
|
||||
yield
|
||||
|
||||
|
||||
@@ -32,8 +32,8 @@ class TestRegisterEndpoint:
|
||||
"email": "newuser@example.com",
|
||||
"password": "SecurePassword123!",
|
||||
"first_name": "New",
|
||||
"last_name": "User"
|
||||
}
|
||||
"last_name": "User",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
@@ -54,8 +54,8 @@ class TestRegisterEndpoint:
|
||||
"email": async_test_user.email,
|
||||
"password": "SecurePassword123!",
|
||||
"first_name": "Duplicate",
|
||||
"last_name": "User"
|
||||
}
|
||||
"last_name": "User",
|
||||
},
|
||||
)
|
||||
|
||||
# Security: Returns 400 with generic message to prevent email enumeration
|
||||
@@ -73,8 +73,8 @@ class TestRegisterEndpoint:
|
||||
"email": "weakpass@example.com",
|
||||
"password": "weak",
|
||||
"first_name": "Weak",
|
||||
"last_name": "Pass"
|
||||
}
|
||||
"last_name": "Pass",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
@@ -82,7 +82,7 @@ class TestRegisterEndpoint:
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_unexpected_error(self, client):
|
||||
"""Test registration with unexpected error."""
|
||||
with patch('app.services.auth_service.AuthService.create_user') as mock_create:
|
||||
with patch("app.services.auth_service.AuthService.create_user") as mock_create:
|
||||
mock_create.side_effect = Exception("Unexpected error")
|
||||
|
||||
response = await client.post(
|
||||
@@ -91,8 +91,8 @@ class TestRegisterEndpoint:
|
||||
"email": "error@example.com",
|
||||
"password": "SecurePassword123!",
|
||||
"first_name": "Error",
|
||||
"last_name": "User"
|
||||
}
|
||||
"last_name": "User",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
@@ -106,10 +106,7 @@ class TestLoginEndpoint:
|
||||
"""Test successful login."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": async_test_user.email,
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": async_test_user.email, "password": "TestPassword123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -123,10 +120,7 @@ class TestLoginEndpoint:
|
||||
"""Test login with wrong password."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": async_test_user.email,
|
||||
"password": "WrongPassword123"
|
||||
}
|
||||
json={"email": async_test_user.email, "password": "WrongPassword123"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
@@ -136,10 +130,7 @@ class TestLoginEndpoint:
|
||||
"""Test login with non-existent email."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "nonexistent@example.com",
|
||||
"password": "Password123!"
|
||||
}
|
||||
json={"email": "nonexistent@example.com", "password": "Password123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
@@ -147,20 +138,19 @@ class TestLoginEndpoint:
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_inactive_user(self, client, async_test_user, async_test_db):
|
||||
"""Test login with inactive user."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Get the user in this session and make it inactive
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
user_in_session = result.scalar_one_or_none()
|
||||
user_in_session.is_active = False
|
||||
await session.commit()
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": async_test_user.email,
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": async_test_user.email, "password": "TestPassword123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
@@ -168,15 +158,14 @@ class TestLoginEndpoint:
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_unexpected_error(self, client, async_test_user):
|
||||
"""Test login with unexpected error."""
|
||||
with patch('app.services.auth_service.AuthService.authenticate_user') as mock_auth:
|
||||
with patch(
|
||||
"app.services.auth_service.AuthService.authenticate_user"
|
||||
) as mock_auth:
|
||||
mock_auth.side_effect = Exception("Database error")
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": async_test_user.email,
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": async_test_user.email, "password": "TestPassword123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
@@ -190,10 +179,7 @@ class TestOAuthLoginEndpoint:
|
||||
"""Test successful OAuth login."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login/oauth",
|
||||
data={
|
||||
"username": async_test_user.email,
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
data={"username": async_test_user.email, "password": "TestPassword123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -206,31 +192,29 @@ class TestOAuthLoginEndpoint:
|
||||
"""Test OAuth login with wrong credentials."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login/oauth",
|
||||
data={
|
||||
"username": async_test_user.email,
|
||||
"password": "WrongPassword"
|
||||
}
|
||||
data={"username": async_test_user.email, "password": "WrongPassword"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_login_inactive_user(self, client, async_test_user, async_test_db):
|
||||
async def test_oauth_login_inactive_user(
|
||||
self, client, async_test_user, async_test_db
|
||||
):
|
||||
"""Test OAuth login with inactive user."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Get the user in this session and make it inactive
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
user_in_session = result.scalar_one_or_none()
|
||||
user_in_session.is_active = False
|
||||
await session.commit()
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login/oauth",
|
||||
data={
|
||||
"username": async_test_user.email,
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
data={"username": async_test_user.email, "password": "TestPassword123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
@@ -238,15 +222,17 @@ class TestOAuthLoginEndpoint:
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_login_unexpected_error(self, client, async_test_user):
|
||||
"""Test OAuth login with unexpected error."""
|
||||
with patch('app.services.auth_service.AuthService.authenticate_user') as mock_auth:
|
||||
with patch(
|
||||
"app.services.auth_service.AuthService.authenticate_user"
|
||||
) as mock_auth:
|
||||
mock_auth.side_effect = Exception("Unexpected error")
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login/oauth",
|
||||
data={
|
||||
"username": async_test_user.email,
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
"password": "TestPassword123!",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
@@ -261,17 +247,13 @@ class TestRefreshTokenEndpoint:
|
||||
# First, login to get a refresh token
|
||||
login_response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": async_test_user.email,
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": async_test_user.email, "password": "TestPassword123!"},
|
||||
)
|
||||
refresh_token = login_response.json()["refresh_token"]
|
||||
|
||||
# Now refresh the token
|
||||
response = await client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": refresh_token}
|
||||
"/api/v1/auth/refresh", json={"refresh_token": refresh_token}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -284,12 +266,13 @@ class TestRefreshTokenEndpoint:
|
||||
"""Test refresh with expired token."""
|
||||
from app.core.auth import TokenExpiredError
|
||||
|
||||
with patch('app.services.auth_service.AuthService.refresh_tokens') as mock_refresh:
|
||||
with patch(
|
||||
"app.services.auth_service.AuthService.refresh_tokens"
|
||||
) as mock_refresh:
|
||||
mock_refresh.side_effect = TokenExpiredError("Token expired")
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": "some_token"}
|
||||
"/api/v1/auth/refresh", json={"refresh_token": "some_token"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
@@ -298,8 +281,7 @@ class TestRefreshTokenEndpoint:
|
||||
async def test_refresh_token_invalid(self, client):
|
||||
"""Test refresh with invalid token."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": "invalid_token"}
|
||||
"/api/v1/auth/refresh", json={"refresh_token": "invalid_token"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
@@ -310,19 +292,17 @@ class TestRefreshTokenEndpoint:
|
||||
# Get a valid refresh token first
|
||||
login_response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": async_test_user.email,
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": async_test_user.email, "password": "TestPassword123!"},
|
||||
)
|
||||
refresh_token = login_response.json()["refresh_token"]
|
||||
|
||||
with patch('app.services.auth_service.AuthService.refresh_tokens') as mock_refresh:
|
||||
with patch(
|
||||
"app.services.auth_service.AuthService.refresh_tokens"
|
||||
) as mock_refresh:
|
||||
mock_refresh.side_effect = Exception("Unexpected error")
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": refresh_token}
|
||||
"/api/v1/auth/refresh", json={"refresh_token": refresh_token}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
|
||||
@@ -2,8 +2,10 @@
|
||||
"""
|
||||
Tests for auth route exception handlers and error paths.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, AsyncMock
|
||||
from fastapi import status
|
||||
|
||||
|
||||
@@ -11,16 +13,18 @@ class TestLoginSessionCreationFailure:
|
||||
"""Test login when session creation fails."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_succeeds_despite_session_creation_failure(self, client, async_test_user):
|
||||
async def test_login_succeeds_despite_session_creation_failure(
|
||||
self, client, async_test_user
|
||||
):
|
||||
"""Test that login succeeds even if session creation fails."""
|
||||
# Mock session creation to fail
|
||||
with patch('app.api.routes.auth.session_crud.create_session', side_effect=Exception("Session creation failed")):
|
||||
with patch(
|
||||
"app.api.routes.auth.session_crud.create_session",
|
||||
side_effect=Exception("Session creation failed"),
|
||||
):
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
|
||||
# Login should still succeed, just without session record
|
||||
@@ -34,15 +38,20 @@ class TestOAuthLoginSessionCreationFailure:
|
||||
"""Test OAuth login when session creation fails."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_login_succeeds_despite_session_failure(self, client, async_test_user):
|
||||
async def test_oauth_login_succeeds_despite_session_failure(
|
||||
self, client, async_test_user
|
||||
):
|
||||
"""Test OAuth login succeeds even if session creation fails."""
|
||||
with patch('app.api.routes.auth.session_crud.create_session', side_effect=Exception("Session failed")):
|
||||
with patch(
|
||||
"app.api.routes.auth.session_crud.create_session",
|
||||
side_effect=Exception("Session failed"),
|
||||
):
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login/oauth",
|
||||
data={
|
||||
"username": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
"password": "TestPassword123!",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -54,23 +63,24 @@ class TestRefreshTokenSessionUpdateFailure:
|
||||
"""Test refresh token when session update fails."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_token_succeeds_despite_session_update_failure(self, client, async_test_user):
|
||||
async def test_refresh_token_succeeds_despite_session_update_failure(
|
||||
self, client, async_test_user
|
||||
):
|
||||
"""Test that token refresh succeeds even if session update fails."""
|
||||
# First login to get tokens
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
tokens = response.json()
|
||||
|
||||
# Mock session update to fail
|
||||
with patch('app.api.routes.auth.session_crud.update_refresh_token', side_effect=Exception("Update failed")):
|
||||
with patch(
|
||||
"app.api.routes.auth.session_crud.update_refresh_token",
|
||||
side_effect=Exception("Update failed"),
|
||||
):
|
||||
response = await client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": tokens["refresh_token"]}
|
||||
"/api/v1/auth/refresh", json={"refresh_token": tokens["refresh_token"]}
|
||||
)
|
||||
|
||||
# Should still succeed - tokens are issued before update
|
||||
@@ -83,15 +93,14 @@ class TestLogoutWithExpiredToken:
|
||||
"""Test logout with expired/invalid token."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logout_with_invalid_token_still_succeeds(self, client, async_test_user):
|
||||
async def test_logout_with_invalid_token_still_succeeds(
|
||||
self, client, async_test_user
|
||||
):
|
||||
"""Test logout succeeds even with invalid refresh token."""
|
||||
# Login first
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
access_token = response.json()["access_token"]
|
||||
|
||||
@@ -99,7 +108,7 @@ class TestLogoutWithExpiredToken:
|
||||
response = await client.post(
|
||||
"/api/v1/auth/logout",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
json={"refresh_token": "invalid.token.here"}
|
||||
json={"refresh_token": "invalid.token.here"},
|
||||
)
|
||||
|
||||
# Should succeed (idempotent)
|
||||
@@ -116,19 +125,16 @@ class TestLogoutWithNonExistentSession:
|
||||
"""Test logout succeeds even if session not found."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
tokens = response.json()
|
||||
|
||||
# Mock session lookup to return None
|
||||
with patch('app.api.routes.auth.session_crud.get_by_jti', return_value=None):
|
||||
with patch("app.api.routes.auth.session_crud.get_by_jti", return_value=None):
|
||||
response = await client.post(
|
||||
"/api/v1/auth/logout",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
json={"refresh_token": tokens["refresh_token"]}
|
||||
json={"refresh_token": tokens["refresh_token"]},
|
||||
)
|
||||
|
||||
# Should succeed (idempotent)
|
||||
@@ -139,23 +145,25 @@ class TestLogoutUnexpectedError:
|
||||
"""Test logout with unexpected errors."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logout_with_unexpected_error_returns_success(self, client, async_test_user):
|
||||
async def test_logout_with_unexpected_error_returns_success(
|
||||
self, client, async_test_user
|
||||
):
|
||||
"""Test logout returns success even on unexpected errors."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
tokens = response.json()
|
||||
|
||||
# Mock to raise unexpected error
|
||||
with patch('app.api.routes.auth.session_crud.get_by_jti', side_effect=Exception("Unexpected error")):
|
||||
with patch(
|
||||
"app.api.routes.auth.session_crud.get_by_jti",
|
||||
side_effect=Exception("Unexpected error"),
|
||||
):
|
||||
response = await client.post(
|
||||
"/api/v1/auth/logout",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
json={"refresh_token": tokens["refresh_token"]}
|
||||
json={"refresh_token": tokens["refresh_token"]},
|
||||
)
|
||||
|
||||
# Should still return success (don't expose errors)
|
||||
@@ -172,18 +180,18 @@ class TestLogoutAllUnexpectedError:
|
||||
"""Test logout-all handles database errors."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
access_token = response.json()["access_token"]
|
||||
|
||||
# Mock to raise database error
|
||||
with patch('app.api.routes.auth.session_crud.deactivate_all_user_sessions', side_effect=Exception("DB error")):
|
||||
with patch(
|
||||
"app.api.routes.auth.session_crud.deactivate_all_user_sessions",
|
||||
side_effect=Exception("DB error"),
|
||||
):
|
||||
response = await client.post(
|
||||
"/api/v1/auth/logout-all",
|
||||
headers={"Authorization": f"Bearer {access_token}"}
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
@@ -193,7 +201,9 @@ class TestPasswordResetConfirmSessionInvalidation:
|
||||
"""Test password reset invalidates sessions."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_password_reset_continues_despite_session_invalidation_failure(self, client, async_test_user):
|
||||
async def test_password_reset_continues_despite_session_invalidation_failure(
|
||||
self, client, async_test_user
|
||||
):
|
||||
"""Test password reset succeeds even if session invalidation fails."""
|
||||
# Create a valid password reset token
|
||||
from app.utils.security import create_password_reset_token
|
||||
@@ -201,13 +211,13 @@ class TestPasswordResetConfirmSessionInvalidation:
|
||||
token = create_password_reset_token(async_test_user.email)
|
||||
|
||||
# Mock session invalidation to fail
|
||||
with patch('app.api.routes.auth.session_crud.deactivate_all_user_sessions', side_effect=Exception("Invalidation failed")):
|
||||
with patch(
|
||||
"app.api.routes.auth.session_crud.deactivate_all_user_sessions",
|
||||
side_effect=Exception("Invalidation failed"),
|
||||
):
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={
|
||||
"token": token,
|
||||
"new_password": "NewPassword123!"
|
||||
}
|
||||
json={"token": token, "new_password": "NewPassword123!"},
|
||||
)
|
||||
|
||||
# Should still succeed - password was reset
|
||||
|
||||
@@ -2,22 +2,22 @@
|
||||
"""
|
||||
Tests for password reset endpoints.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from unittest.mock import patch, AsyncMock, MagicMock
|
||||
from fastapi import status
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.schemas.users import PasswordResetRequest, PasswordResetConfirm
|
||||
from app.utils.security import create_password_reset_token
|
||||
from app.models.user import User
|
||||
from app.utils.security import create_password_reset_token
|
||||
|
||||
|
||||
# Disable rate limiting for tests
|
||||
@pytest.fixture(autouse=True)
|
||||
def disable_rate_limit():
|
||||
"""Disable rate limiting for all tests in this module."""
|
||||
with patch('app.api.routes.auth.limiter.enabled', False):
|
||||
with patch("app.api.routes.auth.limiter.enabled", False):
|
||||
yield
|
||||
|
||||
|
||||
@@ -27,12 +27,14 @@ class TestPasswordResetRequest:
|
||||
@pytest.mark.asyncio
|
||||
async def test_password_reset_request_valid_email(self, client, async_test_user):
|
||||
"""Test password reset request with valid email."""
|
||||
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send:
|
||||
with patch(
|
||||
"app.api.routes.auth.email_service.send_password_reset_email"
|
||||
) as mock_send:
|
||||
mock_send.return_value = True
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/request",
|
||||
json={"email": async_test_user.email}
|
||||
json={"email": async_test_user.email},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -50,10 +52,12 @@ class TestPasswordResetRequest:
|
||||
@pytest.mark.asyncio
|
||||
async def test_password_reset_request_nonexistent_email(self, client):
|
||||
"""Test password reset request with non-existent email."""
|
||||
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send:
|
||||
with patch(
|
||||
"app.api.routes.auth.email_service.send_password_reset_email"
|
||||
) as mock_send:
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/request",
|
||||
json={"email": "nonexistent@example.com"}
|
||||
json={"email": "nonexistent@example.com"},
|
||||
)
|
||||
|
||||
# Should still return success to prevent email enumeration
|
||||
@@ -65,20 +69,26 @@ class TestPasswordResetRequest:
|
||||
mock_send.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_password_reset_request_inactive_user(self, client, async_test_db, async_test_user):
|
||||
async def test_password_reset_request_inactive_user(
|
||||
self, client, async_test_db, async_test_user
|
||||
):
|
||||
"""Test password reset request with inactive user."""
|
||||
# Deactivate user
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
user_in_session = result.scalar_one_or_none()
|
||||
user_in_session.is_active = False
|
||||
await session.commit()
|
||||
|
||||
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send:
|
||||
with patch(
|
||||
"app.api.routes.auth.email_service.send_password_reset_email"
|
||||
) as mock_send:
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/request",
|
||||
json={"email": async_test_user.email}
|
||||
json={"email": async_test_user.email},
|
||||
)
|
||||
|
||||
# Should still return success to prevent email enumeration
|
||||
@@ -93,8 +103,7 @@ class TestPasswordResetRequest:
|
||||
async def test_password_reset_request_invalid_email_format(self, client):
|
||||
"""Test password reset request with invalid email format."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/request",
|
||||
json={"email": "not-an-email"}
|
||||
"/api/v1/auth/password-reset/request", json={"email": "not-an-email"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
@@ -102,22 +111,23 @@ class TestPasswordResetRequest:
|
||||
@pytest.mark.asyncio
|
||||
async def test_password_reset_request_missing_email(self, client):
|
||||
"""Test password reset request without email."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/request",
|
||||
json={}
|
||||
)
|
||||
response = await client.post("/api/v1/auth/password-reset/request", json={})
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_password_reset_request_email_service_error(self, client, async_test_user):
|
||||
async def test_password_reset_request_email_service_error(
|
||||
self, client, async_test_user
|
||||
):
|
||||
"""Test password reset when email service fails."""
|
||||
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send:
|
||||
with patch(
|
||||
"app.api.routes.auth.email_service.send_password_reset_email"
|
||||
) as mock_send:
|
||||
mock_send.side_effect = Exception("SMTP Error")
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/request",
|
||||
json={"email": async_test_user.email}
|
||||
json={"email": async_test_user.email},
|
||||
)
|
||||
|
||||
# Should still return success even if email fails
|
||||
@@ -128,14 +138,16 @@ class TestPasswordResetRequest:
|
||||
@pytest.mark.asyncio
|
||||
async def test_password_reset_request_rate_limiting(self, client, async_test_user):
|
||||
"""Test that password reset requests are rate limited."""
|
||||
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send:
|
||||
with patch(
|
||||
"app.api.routes.auth.email_service.send_password_reset_email"
|
||||
) as mock_send:
|
||||
mock_send.return_value = True
|
||||
|
||||
# Make multiple requests quickly (3/minute limit)
|
||||
for _ in range(3):
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/request",
|
||||
json={"email": async_test_user.email}
|
||||
json={"email": async_test_user.email},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
@@ -144,7 +156,9 @@ class TestPasswordResetConfirm:
|
||||
"""Tests for POST /auth/password-reset/confirm endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_password_reset_confirm_valid_token(self, client, async_test_user, async_test_db):
|
||||
async def test_password_reset_confirm_valid_token(
|
||||
self, client, async_test_user, async_test_db
|
||||
):
|
||||
"""Test password reset confirmation with valid token."""
|
||||
# Generate valid token
|
||||
token = create_password_reset_token(async_test_user.email)
|
||||
@@ -152,10 +166,7 @@ class TestPasswordResetConfirm:
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={
|
||||
"token": token,
|
||||
"new_password": new_password
|
||||
}
|
||||
json={"token": token, "new_password": new_password},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -164,11 +175,14 @@ class TestPasswordResetConfirm:
|
||||
assert "successfully" in data["message"].lower()
|
||||
|
||||
# Verify user can login with new password
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
updated_user = result.scalar_one_or_none()
|
||||
from app.core.auth import verify_password
|
||||
|
||||
assert verify_password(new_password, updated_user.password_hash) is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -184,10 +198,7 @@ class TestPasswordResetConfirm:
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={
|
||||
"token": token,
|
||||
"new_password": "NewSecure123!"
|
||||
}
|
||||
json={"token": token, "new_password": "NewSecure123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
@@ -202,10 +213,7 @@ class TestPasswordResetConfirm:
|
||||
"""Test password reset confirmation with invalid token."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={
|
||||
"token": "invalid_token_xyz",
|
||||
"new_password": "NewSecure123!"
|
||||
}
|
||||
json={"token": "invalid_token_xyz", "new_password": "NewSecure123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
@@ -222,19 +230,18 @@ class TestPasswordResetConfirm:
|
||||
|
||||
# Create valid token and tamper with it
|
||||
token = create_password_reset_token(async_test_user.email)
|
||||
decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
|
||||
decoded = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
|
||||
token_data = json.loads(decoded)
|
||||
token_data["payload"]["email"] = "hacker@example.com"
|
||||
|
||||
# Re-encode tampered token
|
||||
tampered = base64.urlsafe_b64encode(json.dumps(token_data).encode('utf-8')).decode('utf-8')
|
||||
tampered = base64.urlsafe_b64encode(
|
||||
json.dumps(token_data).encode("utf-8")
|
||||
).decode("utf-8")
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={
|
||||
"token": tampered,
|
||||
"new_password": "NewSecure123!"
|
||||
}
|
||||
json={"token": tampered, "new_password": "NewSecure123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
@@ -247,10 +254,7 @@ class TestPasswordResetConfirm:
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={
|
||||
"token": token,
|
||||
"new_password": "NewSecure123!"
|
||||
}
|
||||
json={"token": token, "new_password": "NewSecure123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
@@ -260,12 +264,16 @@ class TestPasswordResetConfirm:
|
||||
assert "not found" in error_msg
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_password_reset_confirm_inactive_user(self, client, async_test_user, async_test_db):
|
||||
async def test_password_reset_confirm_inactive_user(
|
||||
self, client, async_test_user, async_test_db
|
||||
):
|
||||
"""Test password reset confirmation for inactive user."""
|
||||
# Deactivate user
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
user_in_session = result.scalar_one_or_none()
|
||||
user_in_session.is_active = False
|
||||
await session.commit()
|
||||
@@ -274,10 +282,7 @@ class TestPasswordResetConfirm:
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={
|
||||
"token": token,
|
||||
"new_password": "NewSecure123!"
|
||||
}
|
||||
json={"token": token, "new_password": "NewSecure123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
@@ -301,10 +306,7 @@ class TestPasswordResetConfirm:
|
||||
for weak_password in weak_passwords:
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={
|
||||
"token": token,
|
||||
"new_password": weak_password
|
||||
}
|
||||
json={"token": token, "new_password": weak_password},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
@@ -315,15 +317,14 @@ class TestPasswordResetConfirm:
|
||||
# Missing token
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={"new_password": "NewSecure123!"}
|
||||
json={"new_password": "NewSecure123!"},
|
||||
)
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
# Missing password
|
||||
token = create_password_reset_token("test@example.com")
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={"token": token}
|
||||
"/api/v1/auth/password-reset/confirm", json={"token": token}
|
||||
)
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
@@ -333,15 +334,12 @@ class TestPasswordResetConfirm:
|
||||
token = create_password_reset_token(async_test_user.email)
|
||||
|
||||
# Mock the database commit to raise an exception
|
||||
with patch('app.api.routes.auth.user_crud.get_by_email') as mock_get:
|
||||
with patch("app.api.routes.auth.user_crud.get_by_email") as mock_get:
|
||||
mock_get.side_effect = Exception("Database error")
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={
|
||||
"token": token,
|
||||
"new_password": "NewSecure123!"
|
||||
}
|
||||
json={"token": token, "new_password": "NewSecure123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
@@ -351,18 +349,22 @@ class TestPasswordResetConfirm:
|
||||
assert "error" in error_msg or "resetting" in error_msg
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_password_reset_full_flow(self, client, async_test_user, async_test_db):
|
||||
async def test_password_reset_full_flow(
|
||||
self, client, async_test_user, async_test_db
|
||||
):
|
||||
"""Test complete password reset flow."""
|
||||
original_password = async_test_user.password_hash
|
||||
new_password = "BrandNew123!"
|
||||
|
||||
# Step 1: Request password reset
|
||||
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send:
|
||||
with patch(
|
||||
"app.api.routes.auth.email_service.send_password_reset_email"
|
||||
) as mock_send:
|
||||
mock_send.return_value = True
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/request",
|
||||
json={"email": async_test_user.email}
|
||||
json={"email": async_test_user.email},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -374,29 +376,24 @@ class TestPasswordResetConfirm:
|
||||
# Step 2: Confirm password reset
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={
|
||||
"token": reset_token,
|
||||
"new_password": new_password
|
||||
}
|
||||
json={"token": reset_token, "new_password": new_password},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
# Step 3: Verify old password doesn't work
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
updated_user = result.scalar_one_or_none()
|
||||
from app.core.auth import verify_password
|
||||
assert updated_user.password_hash != original_password
|
||||
|
||||
# Step 4: Verify new password works
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": async_test_user.email,
|
||||
"password": new_password
|
||||
}
|
||||
json={"email": async_test_user.email, "password": new_password},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
@@ -8,11 +8,10 @@ Critical security tests covering:
|
||||
|
||||
These tests prevent real-world attack scenarios.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.auth import create_refresh_token
|
||||
from app.crud.session import session as session_crud
|
||||
from app.models.user import User
|
||||
|
||||
@@ -30,10 +29,7 @@ class TestRevokedSessionSecurity:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_token_rejected_after_logout(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
async_test_db,
|
||||
async_test_user: User
|
||||
self, client: AsyncClient, async_test_db, async_test_user: User
|
||||
):
|
||||
"""
|
||||
Test that refresh tokens are rejected after session is deactivated.
|
||||
@@ -45,10 +41,10 @@ class TestRevokedSessionSecurity:
|
||||
4. Attacker tries to use stolen refresh token
|
||||
5. System MUST reject it (session revoked)
|
||||
"""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Step 1: Create a session and refresh token for the user
|
||||
async with SessionLocal() as session:
|
||||
async with SessionLocal():
|
||||
# Login to get tokens
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
@@ -64,8 +60,7 @@ class TestRevokedSessionSecurity:
|
||||
|
||||
# Step 2: Verify refresh token works before logout
|
||||
response = await client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": refresh_token}
|
||||
"/api/v1/auth/refresh", json={"refresh_token": refresh_token}
|
||||
)
|
||||
assert response.status_code == 200, "Refresh should work before logout"
|
||||
|
||||
@@ -73,14 +68,13 @@ class TestRevokedSessionSecurity:
|
||||
response = await client.post(
|
||||
"/api/v1/auth/logout",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
json={"refresh_token": refresh_token}
|
||||
json={"refresh_token": refresh_token},
|
||||
)
|
||||
assert response.status_code == 200, "Logout should succeed"
|
||||
|
||||
# Step 4: Attacker tries to use stolen refresh token
|
||||
response = await client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": refresh_token}
|
||||
"/api/v1/auth/refresh", json={"refresh_token": refresh_token}
|
||||
)
|
||||
|
||||
# Step 5: System MUST reject (covers lines 261-262)
|
||||
@@ -93,10 +87,7 @@ class TestRevokedSessionSecurity:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_token_rejected_for_deleted_session(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
async_test_db,
|
||||
async_test_user: User
|
||||
self, client: AsyncClient, async_test_db, async_test_user: User
|
||||
):
|
||||
"""
|
||||
Test that tokens for deleted sessions are rejected.
|
||||
@@ -104,7 +95,7 @@ class TestRevokedSessionSecurity:
|
||||
Attack Scenario:
|
||||
Admin deletes a session from database, but attacker has the token.
|
||||
"""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Step 1: Login to create a session
|
||||
response = await client.post(
|
||||
@@ -120,6 +111,7 @@ class TestRevokedSessionSecurity:
|
||||
|
||||
# Step 2: Manually delete the session from database (simulating admin action)
|
||||
from app.core.auth import decode_token
|
||||
|
||||
token_data = decode_token(refresh_token, verify_type="refresh")
|
||||
jti = token_data.jti
|
||||
|
||||
@@ -132,15 +124,17 @@ class TestRevokedSessionSecurity:
|
||||
|
||||
# Step 3: Try to use the refresh token
|
||||
response = await client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": refresh_token}
|
||||
"/api/v1/auth/refresh", json={"refresh_token": refresh_token}
|
||||
)
|
||||
|
||||
# Should reject (session doesn't exist)
|
||||
assert response.status_code == 401
|
||||
data = response.json()
|
||||
if "errors" in data:
|
||||
assert "revoked" in data["errors"][0]["message"].lower() or "session" in data["errors"][0]["message"].lower()
|
||||
assert (
|
||||
"revoked" in data["errors"][0]["message"].lower()
|
||||
or "session" in data["errors"][0]["message"].lower()
|
||||
)
|
||||
else:
|
||||
assert "revoked" in data.get("detail", "").lower()
|
||||
|
||||
@@ -162,7 +156,7 @@ class TestSessionHijackingSecurity:
|
||||
client: AsyncClient,
|
||||
async_test_db,
|
||||
async_test_user: User,
|
||||
async_test_superuser: User
|
||||
async_test_superuser: User,
|
||||
):
|
||||
"""
|
||||
Test that users cannot logout other users' sessions.
|
||||
@@ -173,7 +167,7 @@ class TestSessionHijackingSecurity:
|
||||
3. User A tries to logout User B's session
|
||||
4. System MUST reject (cross-user attack)
|
||||
"""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, _SessionLocal = async_test_db
|
||||
|
||||
# Step 1: User A logs in
|
||||
response = await client.post(
|
||||
@@ -202,8 +196,10 @@ class TestSessionHijackingSecurity:
|
||||
# Step 3: User A tries to logout User B's session using User B's refresh token
|
||||
response = await client.post(
|
||||
"/api/v1/auth/logout",
|
||||
headers={"Authorization": f"Bearer {user_a_access}"}, # User A's access token
|
||||
json={"refresh_token": user_b_refresh} # But User B's refresh token
|
||||
headers={
|
||||
"Authorization": f"Bearer {user_a_access}"
|
||||
}, # User A's access token
|
||||
json={"refresh_token": user_b_refresh}, # But User B's refresh token
|
||||
)
|
||||
|
||||
# Step 4: System MUST reject (covers lines 509-513)
|
||||
@@ -217,9 +213,7 @@ class TestSessionHijackingSecurity:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_users_can_logout_their_own_sessions(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
async_test_user: User
|
||||
self, client: AsyncClient, async_test_user: User
|
||||
):
|
||||
"""
|
||||
Sanity check: Users CAN logout their own sessions.
|
||||
@@ -241,6 +235,8 @@ class TestSessionHijackingSecurity:
|
||||
response = await client.post(
|
||||
"/api/v1/auth/logout",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
json={"refresh_token": tokens["refresh_token"]}
|
||||
json={"refresh_token": tokens["refresh_token"]},
|
||||
)
|
||||
assert response.status_code == 200, (
|
||||
"Users should be able to logout their own sessions"
|
||||
)
|
||||
assert response.status_code == 200, "Users should be able to logout their own sessions"
|
||||
|
||||
@@ -5,16 +5,18 @@ Tests for organization routes (user endpoints).
|
||||
These test the routes in app/api/routes/organizations.py which allow
|
||||
users to view and manage organizations they belong to.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from fastapi import status
|
||||
from uuid import uuid4
|
||||
from unittest.mock import patch, AsyncMock
|
||||
|
||||
from app.core.auth import get_password_hash
|
||||
from app.models.organization import Organization
|
||||
from app.models.user import User
|
||||
from app.models.user_organization import UserOrganization, OrganizationRole
|
||||
from app.core.auth import get_password_hash
|
||||
from app.models.user_organization import OrganizationRole, UserOrganization
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
@@ -22,10 +24,7 @@ async def user_token(client, async_test_user):
|
||||
"""Get access token for regular user."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
return response.json()["access_token"]
|
||||
@@ -34,7 +33,7 @@ async def user_token(client, async_test_user):
|
||||
@pytest_asyncio.fixture
|
||||
async def second_user(async_test_db):
|
||||
"""Create a second test user."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = User(
|
||||
id=uuid4(),
|
||||
@@ -56,12 +55,12 @@ async def second_user(async_test_db):
|
||||
@pytest_asyncio.fixture
|
||||
async def test_org_with_user_member(async_test_db, async_test_user):
|
||||
"""Create a test organization with async_test_user as a member."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(
|
||||
name="Member Org",
|
||||
slug="member-org",
|
||||
description="Test organization where user is a member"
|
||||
description="Test organization where user is a member",
|
||||
)
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
@@ -72,7 +71,7 @@ async def test_org_with_user_member(async_test_db, async_test_user):
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org.id,
|
||||
role=OrganizationRole.MEMBER,
|
||||
is_active=True
|
||||
is_active=True,
|
||||
)
|
||||
session.add(membership)
|
||||
await session.commit()
|
||||
@@ -83,12 +82,12 @@ async def test_org_with_user_member(async_test_db, async_test_user):
|
||||
@pytest_asyncio.fixture
|
||||
async def test_org_with_user_admin(async_test_db, async_test_user):
|
||||
"""Create a test organization with async_test_user as an admin."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(
|
||||
name="Admin Org",
|
||||
slug="admin-org",
|
||||
description="Test organization where user is an admin"
|
||||
description="Test organization where user is an admin",
|
||||
)
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
@@ -99,7 +98,7 @@ async def test_org_with_user_admin(async_test_db, async_test_user):
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org.id,
|
||||
role=OrganizationRole.ADMIN,
|
||||
is_active=True
|
||||
is_active=True,
|
||||
)
|
||||
session.add(membership)
|
||||
await session.commit()
|
||||
@@ -110,12 +109,12 @@ async def test_org_with_user_admin(async_test_db, async_test_user):
|
||||
@pytest_asyncio.fixture
|
||||
async def test_org_with_user_owner(async_test_db, async_test_user):
|
||||
"""Create a test organization with async_test_user as owner."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(
|
||||
name="Owner Org",
|
||||
slug="owner-org",
|
||||
description="Test organization where user is owner"
|
||||
description="Test organization where user is owner",
|
||||
)
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
@@ -126,7 +125,7 @@ async def test_org_with_user_owner(async_test_db, async_test_user):
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org.id,
|
||||
role=OrganizationRole.OWNER,
|
||||
is_active=True
|
||||
is_active=True,
|
||||
)
|
||||
session.add(membership)
|
||||
await session.commit()
|
||||
@@ -136,21 +135,18 @@ async def test_org_with_user_owner(async_test_db, async_test_user):
|
||||
|
||||
# ===== GET /api/v1/organizations/me =====
|
||||
|
||||
|
||||
class TestGetMyOrganizations:
|
||||
"""Tests for GET /api/v1/organizations/me endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_my_organizations_success(
|
||||
self,
|
||||
client,
|
||||
user_token,
|
||||
test_org_with_user_member,
|
||||
test_org_with_user_admin
|
||||
self, client, user_token, test_org_with_user_member, test_org_with_user_admin
|
||||
):
|
||||
"""Test successfully getting user's organizations (covers lines 54-79)."""
|
||||
response = await client.get(
|
||||
"/api/v1/organizations/me",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -167,21 +163,15 @@ class TestGetMyOrganizations:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_my_organizations_filter_active(
|
||||
self,
|
||||
client,
|
||||
async_test_db,
|
||||
async_test_user,
|
||||
user_token
|
||||
self, client, async_test_db, async_test_user, user_token
|
||||
):
|
||||
"""Test filtering organizations by active status."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create active org
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
active_org = Organization(
|
||||
name="Active Org",
|
||||
slug="active-org-filter",
|
||||
is_active=True
|
||||
name="Active Org", slug="active-org-filter", is_active=True
|
||||
)
|
||||
session.add(active_org)
|
||||
await session.commit()
|
||||
@@ -192,14 +182,14 @@ class TestGetMyOrganizations:
|
||||
user_id=async_test_user.id,
|
||||
organization_id=active_org.id,
|
||||
role=OrganizationRole.MEMBER,
|
||||
is_active=True
|
||||
is_active=True,
|
||||
)
|
||||
session.add(membership)
|
||||
await session.commit()
|
||||
|
||||
response = await client.get(
|
||||
"/api/v1/organizations/me?is_active=true",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -209,7 +199,7 @@ class TestGetMyOrganizations:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_my_organizations_empty(self, client, async_test_db):
|
||||
"""Test getting organizations when user has none."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create user with no org memberships
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -219,7 +209,7 @@ class TestGetMyOrganizations:
|
||||
password_hash=get_password_hash("TestPassword123!"),
|
||||
first_name="No",
|
||||
last_name="Org",
|
||||
is_active=True
|
||||
is_active=True,
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
@@ -227,13 +217,12 @@ class TestGetMyOrganizations:
|
||||
# Login to get token
|
||||
login_response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "noorg@example.com", "password": "TestPassword123!"}
|
||||
json={"email": "noorg@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
token = login_response.json()["access_token"]
|
||||
|
||||
response = await client.get(
|
||||
"/api/v1/organizations/me",
|
||||
headers={"Authorization": f"Bearer {token}"}
|
||||
"/api/v1/organizations/me", headers={"Authorization": f"Bearer {token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -243,20 +232,18 @@ class TestGetMyOrganizations:
|
||||
|
||||
# ===== GET /api/v1/organizations/{organization_id} =====
|
||||
|
||||
|
||||
class TestGetOrganization:
|
||||
"""Tests for GET /api/v1/organizations/{organization_id} endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_organization_success(
|
||||
self,
|
||||
client,
|
||||
user_token,
|
||||
test_org_with_user_member
|
||||
self, client, user_token, test_org_with_user_member
|
||||
):
|
||||
"""Test successfully getting organization details (covers lines 103-122)."""
|
||||
response = await client.get(
|
||||
f"/api/v1/organizations/{test_org_with_user_member.id}",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -272,7 +259,7 @@ class TestGetOrganization:
|
||||
fake_org_id = uuid4()
|
||||
response = await client.get(
|
||||
f"/api/v1/organizations/{fake_org_id}",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
# Permission dependency checks membership before endpoint logic
|
||||
@@ -283,20 +270,14 @@ class TestGetOrganization:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_organization_not_member(
|
||||
self,
|
||||
client,
|
||||
async_test_db,
|
||||
async_test_user
|
||||
self, client, async_test_db, async_test_user
|
||||
):
|
||||
"""Test getting organization where user is not a member fails."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create org without adding user
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(
|
||||
name="Not Member Org",
|
||||
slug="not-member-org"
|
||||
)
|
||||
org = Organization(name="Not Member Org", slug="not-member-org")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
await session.refresh(org)
|
||||
@@ -305,13 +286,13 @@ class TestGetOrganization:
|
||||
# Login as user
|
||||
login_response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
token = login_response.json()["access_token"]
|
||||
|
||||
response = await client.get(
|
||||
f"/api/v1/organizations/{org_id}",
|
||||
headers={"Authorization": f"Bearer {token}"}
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
|
||||
# Should fail permission check
|
||||
@@ -320,6 +301,7 @@ class TestGetOrganization:
|
||||
|
||||
# ===== GET /api/v1/organizations/{organization_id}/members =====
|
||||
|
||||
|
||||
class TestGetOrganizationMembers:
|
||||
"""Tests for GET /api/v1/organizations/{organization_id}/members endpoint."""
|
||||
|
||||
@@ -331,10 +313,10 @@ class TestGetOrganizationMembers:
|
||||
async_test_user,
|
||||
second_user,
|
||||
user_token,
|
||||
test_org_with_user_member
|
||||
test_org_with_user_member,
|
||||
):
|
||||
"""Test successfully getting organization members (covers lines 150-168)."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Add second user to org
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -342,14 +324,14 @@ class TestGetOrganizationMembers:
|
||||
user_id=second_user.id,
|
||||
organization_id=test_org_with_user_member.id,
|
||||
role=OrganizationRole.MEMBER,
|
||||
is_active=True
|
||||
is_active=True,
|
||||
)
|
||||
session.add(membership)
|
||||
await session.commit()
|
||||
|
||||
response = await client.get(
|
||||
f"/api/v1/organizations/{test_org_with_user_member.id}/members",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -360,15 +342,12 @@ class TestGetOrganizationMembers:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_organization_members_with_pagination(
|
||||
self,
|
||||
client,
|
||||
user_token,
|
||||
test_org_with_user_member
|
||||
self, client, user_token, test_org_with_user_member
|
||||
):
|
||||
"""Test pagination parameters."""
|
||||
response = await client.get(
|
||||
f"/api/v1/organizations/{test_org_with_user_member.id}/members?page=1&limit=10",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -385,10 +364,10 @@ class TestGetOrganizationMembers:
|
||||
async_test_user,
|
||||
second_user,
|
||||
user_token,
|
||||
test_org_with_user_member
|
||||
test_org_with_user_member,
|
||||
):
|
||||
"""Test filtering members by active status."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Add second user as inactive member
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -396,7 +375,7 @@ class TestGetOrganizationMembers:
|
||||
user_id=second_user.id,
|
||||
organization_id=test_org_with_user_member.id,
|
||||
role=OrganizationRole.MEMBER,
|
||||
is_active=False
|
||||
is_active=False,
|
||||
)
|
||||
session.add(membership)
|
||||
await session.commit()
|
||||
@@ -404,7 +383,7 @@ class TestGetOrganizationMembers:
|
||||
# Filter for active only
|
||||
response = await client.get(
|
||||
f"/api/v1/organizations/{test_org_with_user_member.id}/members?is_active=true",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -416,31 +395,26 @@ class TestGetOrganizationMembers:
|
||||
|
||||
# ===== PUT /api/v1/organizations/{organization_id} =====
|
||||
|
||||
|
||||
class TestUpdateOrganization:
|
||||
"""Tests for PUT /api/v1/organizations/{organization_id} endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_organization_as_admin_success(
|
||||
self,
|
||||
client,
|
||||
async_test_user,
|
||||
test_org_with_user_admin
|
||||
self, client, async_test_user, test_org_with_user_admin
|
||||
):
|
||||
"""Test successfully updating organization as admin (covers lines 193-215)."""
|
||||
# Login as admin user
|
||||
login_response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
admin_token = login_response.json()["access_token"]
|
||||
|
||||
response = await client.put(
|
||||
f"/api/v1/organizations/{test_org_with_user_admin.id}",
|
||||
json={
|
||||
"name": "Updated Admin Org",
|
||||
"description": "Updated description"
|
||||
},
|
||||
headers={"Authorization": f"Bearer {admin_token}"}
|
||||
json={"name": "Updated Admin Org", "description": "Updated description"},
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -450,23 +424,20 @@ class TestUpdateOrganization:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_organization_as_owner_success(
|
||||
self,
|
||||
client,
|
||||
async_test_user,
|
||||
test_org_with_user_owner
|
||||
self, client, async_test_user, test_org_with_user_owner
|
||||
):
|
||||
"""Test successfully updating organization as owner."""
|
||||
# Login as owner user
|
||||
login_response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
owner_token = login_response.json()["access_token"]
|
||||
|
||||
response = await client.put(
|
||||
f"/api/v1/organizations/{test_org_with_user_owner.id}",
|
||||
json={"name": "Updated Owner Org"},
|
||||
headers={"Authorization": f"Bearer {owner_token}"}
|
||||
headers={"Authorization": f"Bearer {owner_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -475,16 +446,13 @@ class TestUpdateOrganization:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_organization_as_member_fails(
|
||||
self,
|
||||
client,
|
||||
user_token,
|
||||
test_org_with_user_member
|
||||
self, client, user_token, test_org_with_user_member
|
||||
):
|
||||
"""Test updating organization as regular member fails."""
|
||||
response = await client.put(
|
||||
f"/api/v1/organizations/{test_org_with_user_member.id}",
|
||||
json={"name": "Should Fail"},
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
# Should fail permission check (need admin or owner)
|
||||
@@ -492,15 +460,13 @@ class TestUpdateOrganization:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_organization_not_found(
|
||||
self,
|
||||
client,
|
||||
test_org_with_user_admin
|
||||
self, client, test_org_with_user_admin
|
||||
):
|
||||
"""Test updating nonexistent organization returns 403 (permission check first)."""
|
||||
# Login as admin
|
||||
login_response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
admin_token = login_response.json()["access_token"]
|
||||
|
||||
@@ -508,7 +474,7 @@ class TestUpdateOrganization:
|
||||
response = await client.put(
|
||||
f"/api/v1/organizations/{fake_org_id}",
|
||||
json={"name": "Updated"},
|
||||
headers={"Authorization": f"Bearer {admin_token}"}
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
# Permission dependency checks admin role before endpoint logic
|
||||
@@ -520,6 +486,7 @@ class TestUpdateOrganization:
|
||||
|
||||
# ===== Authentication Tests =====
|
||||
|
||||
|
||||
class TestOrganizationAuthentication:
|
||||
"""Test authentication requirements for organization endpoints."""
|
||||
|
||||
@@ -548,14 +515,14 @@ class TestOrganizationAuthentication:
|
||||
"""Test unauthenticated access to update fails."""
|
||||
fake_id = uuid4()
|
||||
response = await client.put(
|
||||
f"/api/v1/organizations/{fake_id}",
|
||||
json={"name": "Test"}
|
||||
f"/api/v1/organizations/{fake_id}", json={"name": "Test"}
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
# ===== Exception Handler Tests (Database Error Scenarios) =====
|
||||
|
||||
|
||||
class TestOrganizationExceptionHandlers:
|
||||
"""
|
||||
Test exception handlers in organization endpoints.
|
||||
@@ -566,86 +533,74 @@ class TestOrganizationExceptionHandlers:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_my_organizations_database_error(
|
||||
self,
|
||||
client,
|
||||
user_token,
|
||||
test_org_with_user_member
|
||||
self, client, user_token, test_org_with_user_member
|
||||
):
|
||||
"""Test generic exception handler in get_my_organizations (covers lines 81-83)."""
|
||||
with patch(
|
||||
"app.crud.organization.organization.get_user_organizations_with_details",
|
||||
side_effect=Exception("Database connection lost")
|
||||
side_effect=Exception("Database connection lost"),
|
||||
):
|
||||
# The exception handler logs and re-raises, so we expect the exception
|
||||
# to propagate (which proves the handler executed)
|
||||
with pytest.raises(Exception, match="Database connection lost"):
|
||||
await client.get(
|
||||
"/api/v1/organizations/me",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_organization_database_error(
|
||||
self,
|
||||
client,
|
||||
user_token,
|
||||
test_org_with_user_member
|
||||
self, client, user_token, test_org_with_user_member
|
||||
):
|
||||
"""Test generic exception handler in get_organization (covers lines 124-128)."""
|
||||
with patch(
|
||||
"app.crud.organization.organization.get",
|
||||
side_effect=Exception("Database timeout")
|
||||
side_effect=Exception("Database timeout"),
|
||||
):
|
||||
with pytest.raises(Exception, match="Database timeout"):
|
||||
await client.get(
|
||||
f"/api/v1/organizations/{test_org_with_user_member.id}",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_organization_members_database_error(
|
||||
self,
|
||||
client,
|
||||
user_token,
|
||||
test_org_with_user_member
|
||||
self, client, user_token, test_org_with_user_member
|
||||
):
|
||||
"""Test generic exception handler in get_organization_members (covers lines 170-172)."""
|
||||
with patch(
|
||||
"app.crud.organization.organization.get_organization_members",
|
||||
side_effect=Exception("Connection pool exhausted")
|
||||
side_effect=Exception("Connection pool exhausted"),
|
||||
):
|
||||
with pytest.raises(Exception, match="Connection pool exhausted"):
|
||||
await client.get(
|
||||
f"/api/v1/organizations/{test_org_with_user_member.id}/members",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_organization_database_error(
|
||||
self,
|
||||
client,
|
||||
async_test_user,
|
||||
test_org_with_user_admin
|
||||
self, client, async_test_user, test_org_with_user_admin
|
||||
):
|
||||
"""Test generic exception handler in update_organization (covers lines 217-221)."""
|
||||
# Login as admin user
|
||||
login_response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
admin_token = login_response.json()["access_token"]
|
||||
|
||||
with patch(
|
||||
"app.crud.organization.organization.get",
|
||||
return_value=test_org_with_user_admin
|
||||
return_value=test_org_with_user_admin,
|
||||
):
|
||||
with patch(
|
||||
"app.crud.organization.organization.update",
|
||||
side_effect=Exception("Write lock timeout")
|
||||
side_effect=Exception("Write lock timeout"),
|
||||
):
|
||||
with pytest.raises(Exception, match="Write lock timeout"):
|
||||
await client.put(
|
||||
f"/api/v1/organizations/{test_org_with_user_admin.id}",
|
||||
json={"name": "Should Fail"},
|
||||
headers={"Authorization": f"Bearer {admin_token}"}
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
@@ -5,15 +5,17 @@ Tests for permission dependencies - CRITICAL SECURITY PATHS.
|
||||
These tests ensure superusers can bypass organization checks correctly,
|
||||
and that regular users are properly blocked.
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from fastapi import status
|
||||
from uuid import uuid4
|
||||
|
||||
from app.core.auth import get_password_hash
|
||||
from app.models.organization import Organization
|
||||
from app.models.user import User
|
||||
from app.models.user_organization import UserOrganization, OrganizationRole
|
||||
from app.core.auth import get_password_hash
|
||||
from app.models.user_organization import OrganizationRole, UserOrganization
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
@@ -21,10 +23,7 @@ async def superuser_token(client, async_test_superuser):
|
||||
"""Get access token for superuser."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "superuser@example.com",
|
||||
"password": "SuperPassword123!"
|
||||
}
|
||||
json={"email": "superuser@example.com", "password": "SuperPassword123!"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
return response.json()["access_token"]
|
||||
@@ -35,10 +34,7 @@ async def regular_user_token(client, async_test_user):
|
||||
"""Get access token for regular user."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
return response.json()["access_token"]
|
||||
@@ -47,12 +43,12 @@ async def regular_user_token(client, async_test_user):
|
||||
@pytest_asyncio.fixture
|
||||
async def test_org_no_members(async_test_db):
|
||||
"""Create a test organization with NO members."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(
|
||||
name="No Members Org",
|
||||
slug="no-members-org",
|
||||
description="Test org with no members"
|
||||
description="Test org with no members",
|
||||
)
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
@@ -63,12 +59,12 @@ async def test_org_no_members(async_test_db):
|
||||
@pytest_asyncio.fixture
|
||||
async def test_org_with_member(async_test_db, async_test_user):
|
||||
"""Create a test organization with async_test_user as member (not admin)."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(
|
||||
name="Member Only Org",
|
||||
slug="member-only-org",
|
||||
description="Test org where user is just a member"
|
||||
description="Test org where user is just a member",
|
||||
)
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
@@ -79,7 +75,7 @@ async def test_org_with_member(async_test_db, async_test_user):
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org.id,
|
||||
role=OrganizationRole.MEMBER,
|
||||
is_active=True
|
||||
is_active=True,
|
||||
)
|
||||
session.add(membership)
|
||||
await session.commit()
|
||||
@@ -89,6 +85,7 @@ async def test_org_with_member(async_test_db, async_test_user):
|
||||
|
||||
# ===== CRITICAL SECURITY TESTS: Superuser Bypass =====
|
||||
|
||||
|
||||
class TestSuperuserBypass:
|
||||
"""
|
||||
CRITICAL: Test that superusers can bypass organization checks.
|
||||
@@ -99,10 +96,7 @@ class TestSuperuserBypass:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_superuser_can_access_org_not_member_of(
|
||||
self,
|
||||
client,
|
||||
superuser_token,
|
||||
test_org_no_members
|
||||
self, client, superuser_token, test_org_no_members
|
||||
):
|
||||
"""
|
||||
CRITICAL: Superuser should bypass membership check (covers line 175).
|
||||
@@ -111,7 +105,7 @@ class TestSuperuserBypass:
|
||||
"""
|
||||
response = await client.get(
|
||||
f"/api/v1/organizations/{test_org_no_members.id}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
# Superuser should succeed even though they're not a member
|
||||
@@ -121,15 +115,12 @@ class TestSuperuserBypass:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_cannot_access_org_not_member_of(
|
||||
self,
|
||||
client,
|
||||
regular_user_token,
|
||||
test_org_no_members
|
||||
self, client, regular_user_token, test_org_no_members
|
||||
):
|
||||
"""Regular user should be blocked from org they're not a member of."""
|
||||
response = await client.get(
|
||||
f"/api/v1/organizations/{test_org_no_members.id}",
|
||||
headers={"Authorization": f"Bearer {regular_user_token}"}
|
||||
headers={"Authorization": f"Bearer {regular_user_token}"},
|
||||
)
|
||||
|
||||
# Regular user should fail permission check
|
||||
@@ -137,10 +128,7 @@ class TestSuperuserBypass:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_superuser_can_update_org_not_admin_of(
|
||||
self,
|
||||
client,
|
||||
superuser_token,
|
||||
test_org_no_members
|
||||
self, client, superuser_token, test_org_no_members
|
||||
):
|
||||
"""
|
||||
CRITICAL: Superuser should bypass admin check (covers line 99).
|
||||
@@ -150,7 +138,7 @@ class TestSuperuserBypass:
|
||||
response = await client.put(
|
||||
f"/api/v1/organizations/{test_org_no_members.id}",
|
||||
json={"name": "Updated by Superuser"},
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
# Superuser should succeed in updating org
|
||||
@@ -160,16 +148,13 @@ class TestSuperuserBypass:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_member_cannot_update_org(
|
||||
self,
|
||||
client,
|
||||
regular_user_token,
|
||||
test_org_with_member
|
||||
self, client, regular_user_token, test_org_with_member
|
||||
):
|
||||
"""Regular member (not admin) should NOT be able to update org."""
|
||||
response = await client.put(
|
||||
f"/api/v1/organizations/{test_org_with_member.id}",
|
||||
json={"name": "Should Fail"},
|
||||
headers={"Authorization": f"Bearer {regular_user_token}"}
|
||||
headers={"Authorization": f"Bearer {regular_user_token}"},
|
||||
)
|
||||
|
||||
# Member should fail - need admin or owner role
|
||||
@@ -177,15 +162,12 @@ class TestSuperuserBypass:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_superuser_can_list_org_members_not_member_of(
|
||||
self,
|
||||
client,
|
||||
superuser_token,
|
||||
test_org_no_members
|
||||
self, client, superuser_token, test_org_no_members
|
||||
):
|
||||
"""CRITICAL: Superuser should bypass membership check to list members."""
|
||||
response = await client.get(
|
||||
f"/api/v1/organizations/{test_org_no_members.id}/members",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
# Superuser should succeed
|
||||
@@ -197,13 +179,14 @@ class TestSuperuserBypass:
|
||||
|
||||
# ===== Edge Cases and Security Tests =====
|
||||
|
||||
|
||||
class TestPermissionEdgeCases:
|
||||
"""Test edge cases in permission system."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inactive_user_blocked(self, client, async_test_db):
|
||||
"""Test that inactive users are blocked."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create inactive user
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -213,7 +196,7 @@ class TestPermissionEdgeCases:
|
||||
password_hash=get_password_hash("TestPassword123!"),
|
||||
first_name="Inactive",
|
||||
last_name="User",
|
||||
is_active=False # INACTIVE
|
||||
is_active=False, # INACTIVE
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
@@ -222,7 +205,7 @@ class TestPermissionEdgeCases:
|
||||
# But accessing protected endpoints should fail
|
||||
login_response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "inactive@example.com", "password": "TestPassword123!"}
|
||||
json={"email": "inactive@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
|
||||
# Login might fail for inactive users depending on auth implementation
|
||||
@@ -231,18 +214,18 @@ class TestPermissionEdgeCases:
|
||||
|
||||
# Try to access protected endpoint
|
||||
response = await client.get(
|
||||
"/api/v1/users/me",
|
||||
headers={"Authorization": f"Bearer {token}"}
|
||||
"/api/v1/users/me", headers={"Authorization": f"Bearer {token}"}
|
||||
)
|
||||
|
||||
# Should be blocked
|
||||
assert response.status_code in [status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN]
|
||||
assert response.status_code in [
|
||||
status.HTTP_401_UNAUTHORIZED,
|
||||
status.HTTP_403_FORBIDDEN,
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nonexistent_organization_returns_403_not_404(
|
||||
self,
|
||||
client,
|
||||
regular_user_token
|
||||
self, client, regular_user_token
|
||||
):
|
||||
"""
|
||||
Test that accessing nonexistent org returns 403, not 404.
|
||||
@@ -254,7 +237,7 @@ class TestPermissionEdgeCases:
|
||||
fake_org_id = uuid4()
|
||||
response = await client.get(
|
||||
f"/api/v1/organizations/{fake_org_id}",
|
||||
headers={"Authorization": f"Bearer {regular_user_token}"}
|
||||
headers={"Authorization": f"Bearer {regular_user_token}"},
|
||||
)
|
||||
|
||||
# Should get 403 (not a member), not 404 (doesn't exist)
|
||||
@@ -264,18 +247,16 @@ class TestPermissionEdgeCases:
|
||||
|
||||
# ===== Admin Role Tests =====
|
||||
|
||||
|
||||
class TestAdminRolePermissions:
|
||||
"""Test admin role can perform admin actions."""
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_org_with_admin(self, async_test_db, async_test_user):
|
||||
"""Create org where user is ADMIN."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(
|
||||
name="Admin Org",
|
||||
slug="admin-org"
|
||||
)
|
||||
org = Organization(name="Admin Org", slug="admin-org")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
await session.refresh(org)
|
||||
@@ -284,7 +265,7 @@ class TestAdminRolePermissions:
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org.id,
|
||||
role=OrganizationRole.ADMIN,
|
||||
is_active=True
|
||||
is_active=True,
|
||||
)
|
||||
session.add(membership)
|
||||
await session.commit()
|
||||
@@ -293,16 +274,13 @@ class TestAdminRolePermissions:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_can_update_org(
|
||||
self,
|
||||
client,
|
||||
regular_user_token,
|
||||
test_org_with_admin
|
||||
self, client, regular_user_token, test_org_with_admin
|
||||
):
|
||||
"""Admin should be able to update organization."""
|
||||
response = await client.put(
|
||||
f"/api/v1/organizations/{test_org_with_admin.id}",
|
||||
json={"name": "Updated by Admin"},
|
||||
headers={"Authorization": f"Bearer {regular_user_token}"}
|
||||
headers={"Authorization": f"Bearer {regular_user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
@@ -7,13 +7,13 @@ Critical security tests covering:
|
||||
|
||||
These tests prevent unauthorized access and privilege escalation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.user import User
|
||||
from app.models.organization import Organization
|
||||
from app.crud.user import user as user_crud
|
||||
from app.models.organization import Organization
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
class TestInactiveUserBlocking:
|
||||
@@ -29,11 +29,7 @@ class TestInactiveUserBlocking:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inactive_user_cannot_access_protected_endpoints(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
async_test_db,
|
||||
async_test_user: User,
|
||||
user_token: str
|
||||
self, client: AsyncClient, async_test_db, async_test_user: User, user_token: str
|
||||
):
|
||||
"""
|
||||
Test that inactive users are blocked from protected endpoints.
|
||||
@@ -44,12 +40,11 @@ class TestInactiveUserBlocking:
|
||||
3. User tries to access protected endpoint with valid token
|
||||
4. System MUST reject (account inactive)
|
||||
"""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Step 1: Verify user can access endpoint while active
|
||||
response = await client.get(
|
||||
"/api/v1/users/me",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
"/api/v1/users/me", headers={"Authorization": f"Bearer {user_token}"}
|
||||
)
|
||||
assert response.status_code == 200, "Active user should have access"
|
||||
|
||||
@@ -61,8 +56,7 @@ class TestInactiveUserBlocking:
|
||||
|
||||
# Step 3: User tries to access endpoint with same token
|
||||
response = await client.get(
|
||||
"/api/v1/users/me",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
"/api/v1/users/me", headers={"Authorization": f"Bearer {user_token}"}
|
||||
)
|
||||
|
||||
# Step 4: System MUST reject (covers lines 52-57)
|
||||
@@ -75,18 +69,14 @@ class TestInactiveUserBlocking:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inactive_user_blocked_from_organization_endpoints(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
async_test_db,
|
||||
async_test_user: User,
|
||||
user_token: str
|
||||
self, client: AsyncClient, async_test_db, async_test_user: User, user_token: str
|
||||
):
|
||||
"""
|
||||
Test that inactive users can't access organization endpoints.
|
||||
|
||||
Ensures the inactive check applies to ALL protected endpoints.
|
||||
"""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Deactivate user
|
||||
async with SessionLocal() as session:
|
||||
@@ -97,7 +87,7 @@ class TestInactiveUserBlocking:
|
||||
# Try to list organizations
|
||||
response = await client.get(
|
||||
"/api/v1/organizations/me",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
# Must be blocked
|
||||
@@ -122,7 +112,7 @@ class TestSuperuserPrivilegeEscalation:
|
||||
client: AsyncClient,
|
||||
async_test_db,
|
||||
async_test_superuser: User,
|
||||
superuser_token: str
|
||||
superuser_token: str,
|
||||
):
|
||||
"""
|
||||
Test that superusers automatically get OWNER role in organizations.
|
||||
@@ -131,14 +121,11 @@ class TestSuperuserPrivilegeEscalation:
|
||||
Superusers can manage any organization without being explicitly added.
|
||||
This is for platform administration.
|
||||
"""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Step 1: Create an organization (owned by someone else)
|
||||
async with SessionLocal() as session:
|
||||
org = Organization(
|
||||
name="Test Organization",
|
||||
slug="test-org"
|
||||
)
|
||||
org = Organization(name="Test Organization", slug="test-org")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
await session.refresh(org)
|
||||
@@ -148,7 +135,7 @@ class TestSuperuserPrivilegeEscalation:
|
||||
# (They're not a member, but should auto-get OWNER role)
|
||||
response = await client.get(
|
||||
f"/api/v1/organizations/{org_id}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
# Step 3: Should have access (covers lines 154-157)
|
||||
@@ -161,21 +148,18 @@ class TestSuperuserPrivilegeEscalation:
|
||||
client: AsyncClient,
|
||||
async_test_db,
|
||||
async_test_superuser: User,
|
||||
superuser_token: str
|
||||
superuser_token: str,
|
||||
):
|
||||
"""
|
||||
Test that superusers have full management access to all organizations.
|
||||
|
||||
Ensures the OWNER role privilege escalation works end-to-end.
|
||||
"""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create an organization
|
||||
async with SessionLocal() as session:
|
||||
org = Organization(
|
||||
name="Test Organization",
|
||||
slug="test-org"
|
||||
)
|
||||
org = Organization(name="Test Organization", slug="test-org")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
await session.refresh(org)
|
||||
@@ -185,34 +169,29 @@ class TestSuperuserPrivilegeEscalation:
|
||||
response = await client.put(
|
||||
f"/api/v1/organizations/{org_id}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
json={"name": "Updated Name"}
|
||||
json={"name": "Updated Name"},
|
||||
)
|
||||
|
||||
# Should succeed (superuser has OWNER privileges)
|
||||
assert response.status_code in [200, 404], "Superuser should be able to manage any org"
|
||||
assert response.status_code in [200, 404], (
|
||||
"Superuser should be able to manage any org"
|
||||
)
|
||||
# Note: Might be 404 if org endpoints require membership, but the role check passes
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_does_not_get_owner_role(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
async_test_db,
|
||||
async_test_user: User,
|
||||
user_token: str
|
||||
self, client: AsyncClient, async_test_db, async_test_user: User, user_token: str
|
||||
):
|
||||
"""
|
||||
Sanity check: Regular users don't get automatic OWNER role.
|
||||
|
||||
Ensures the superuser check is working correctly (line 154).
|
||||
"""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create an organization
|
||||
async with SessionLocal() as session:
|
||||
org = Organization(
|
||||
name="Test Organization",
|
||||
slug="test-org"
|
||||
)
|
||||
org = Organization(name="Test Organization", slug="test-org")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
await session.refresh(org)
|
||||
@@ -221,8 +200,10 @@ class TestSuperuserPrivilegeEscalation:
|
||||
# Regular user tries to access it (not a member)
|
||||
response = await client.get(
|
||||
f"/api/v1/organizations/{org_id}",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
# Should be denied (not a member, not a superuser)
|
||||
assert response.status_code in [403, 404], "Regular user shouldn't access non-member org"
|
||||
assert response.status_code in [403, 404], (
|
||||
"Regular user shouldn't access non-member org"
|
||||
)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
# tests/api/test_security_headers.py
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.main import app
|
||||
|
||||
@@ -11,8 +12,10 @@ def client():
|
||||
"""Create a FastAPI test client for the main app (module-scoped for speed)."""
|
||||
# Mock get_db to avoid database connection issues
|
||||
with patch("app.core.database.get_db") as mock_get_db:
|
||||
|
||||
async def mock_session_generator():
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.execute = AsyncMock(return_value=None)
|
||||
mock_session.close = AsyncMock(return_value=None)
|
||||
@@ -77,8 +80,10 @@ class TestSecurityHeaders:
|
||||
"""Test that HSTS header is set in production (covers line 95)"""
|
||||
with patch("app.core.config.settings.ENVIRONMENT", "production"):
|
||||
with patch("app.core.database.get_db") as mock_get_db:
|
||||
|
||||
async def mock_session_generator():
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.execute = AsyncMock(return_value=None)
|
||||
mock_session.close = AsyncMock(return_value=None)
|
||||
@@ -88,20 +93,26 @@ class TestSecurityHeaders:
|
||||
|
||||
# Need to reimport app to pick up the new settings
|
||||
from importlib import reload
|
||||
|
||||
import app.main
|
||||
|
||||
reload(app.main)
|
||||
test_client = TestClient(app.main.app)
|
||||
|
||||
response = test_client.get("/health")
|
||||
assert "Strict-Transport-Security" in response.headers
|
||||
assert "max-age=31536000" in response.headers["Strict-Transport-Security"]
|
||||
assert (
|
||||
"max-age=31536000" in response.headers["Strict-Transport-Security"]
|
||||
)
|
||||
|
||||
def test_csp_strict_mode(self):
|
||||
"""Test CSP strict mode (covers line 121)"""
|
||||
with patch("app.core.config.settings.CSP_MODE", "strict"):
|
||||
with patch("app.core.database.get_db") as mock_get_db:
|
||||
|
||||
async def mock_session_generator():
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.execute = AsyncMock(return_value=None)
|
||||
mock_session.close = AsyncMock(return_value=None)
|
||||
@@ -110,7 +121,9 @@ class TestSecurityHeaders:
|
||||
mock_get_db.side_effect = lambda: mock_session_generator()
|
||||
|
||||
from importlib import reload
|
||||
|
||||
import app.main
|
||||
|
||||
reload(app.main)
|
||||
test_client = TestClient(app.main.app)
|
||||
|
||||
@@ -136,8 +149,10 @@ class TestRootEndpoint:
|
||||
def test_root_endpoint(self):
|
||||
"""Test root endpoint returns HTML (covers line 174)"""
|
||||
with patch("app.core.database.get_db") as mock_get_db:
|
||||
|
||||
async def mock_session_generator():
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.execute = AsyncMock(return_value=None)
|
||||
mock_session.close = AsyncMock(return_value=None)
|
||||
|
||||
@@ -2,23 +2,23 @@
|
||||
"""
|
||||
Comprehensive tests for session management API endpoints.
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import uuid4
|
||||
from unittest.mock import patch
|
||||
|
||||
from fastapi import status
|
||||
|
||||
from app.models.user_session import UserSession
|
||||
from app.schemas.users import UserCreate
|
||||
|
||||
|
||||
# Disable rate limiting for tests
|
||||
@pytest.fixture(autouse=True)
|
||||
def disable_rate_limit():
|
||||
"""Disable rate limiting for all tests in this module."""
|
||||
with patch('app.api.routes.sessions.limiter.enabled', False):
|
||||
with patch("app.api.routes.sessions.limiter.enabled", False):
|
||||
yield
|
||||
|
||||
|
||||
@@ -27,10 +27,7 @@ async def user_token(client, async_test_user):
|
||||
"""Create and return an access token for async_test_user."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
return response.json()["access_token"]
|
||||
@@ -39,7 +36,7 @@ async def user_token(client, async_test_user):
|
||||
@pytest_asyncio.fixture
|
||||
async def async_test_user2(async_test_db):
|
||||
"""Create a second test user."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
from app.crud.user import user as user_crud
|
||||
@@ -49,7 +46,7 @@ async def async_test_user2(async_test_db):
|
||||
email="testuser2@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Test",
|
||||
last_name="User2"
|
||||
last_name="User2",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
await session.commit()
|
||||
@@ -61,9 +58,11 @@ class TestListMySessions:
|
||||
"""Tests for GET /api/v1/sessions/me endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_my_sessions_success(self, client, async_test_user, async_test_db, user_token):
|
||||
async def test_list_my_sessions_success(
|
||||
self, client, async_test_user, async_test_db, user_token
|
||||
):
|
||||
"""Test successfully listing user's active sessions."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create some sessions for the user
|
||||
async with SessionLocal() as session:
|
||||
@@ -75,8 +74,8 @@ class TestListMySessions:
|
||||
ip_address="192.168.1.100",
|
||||
user_agent="Mozilla/5.0 (iPhone)",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
# Active session 2
|
||||
s2 = UserSession(
|
||||
@@ -86,8 +85,8 @@ class TestListMySessions:
|
||||
ip_address="192.168.1.101",
|
||||
user_agent="Mozilla/5.0 (Macintosh)",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC) - timedelta(hours=1),
|
||||
)
|
||||
# Inactive session (should not appear)
|
||||
s3 = UserSession(
|
||||
@@ -97,16 +96,15 @@ class TestListMySessions:
|
||||
ip_address="192.168.1.102",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(days=1)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC) - timedelta(days=1),
|
||||
)
|
||||
session.add_all([s1, s2, s3])
|
||||
await session.commit()
|
||||
|
||||
# Make request
|
||||
response = await client.get(
|
||||
"/api/v1/sessions/me",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
"/api/v1/sessions/me", headers={"Authorization": f"Bearer {user_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -128,11 +126,12 @@ class TestListMySessions:
|
||||
assert data["sessions"][0]["is_current"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_my_sessions_with_login_session(self, client, async_test_user, user_token):
|
||||
async def test_list_my_sessions_with_login_session(
|
||||
self, client, async_test_user, user_token
|
||||
):
|
||||
"""Test listing sessions shows the login session."""
|
||||
response = await client.get(
|
||||
"/api/v1/sessions/me",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
"/api/v1/sessions/me", headers={"Authorization": f"Bearer {user_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -155,9 +154,11 @@ class TestRevokeSession:
|
||||
"""Tests for DELETE /api/v1/sessions/{session_id} endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_revoke_session_success(self, client, async_test_user, async_test_db, user_token):
|
||||
async def test_revoke_session_success(
|
||||
self, client, async_test_user, async_test_db, user_token
|
||||
):
|
||||
"""Test successfully revoking a session."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create a session to revoke
|
||||
async with SessionLocal() as session:
|
||||
@@ -168,8 +169,8 @@ class TestRevokeSession:
|
||||
ip_address="192.168.1.103",
|
||||
user_agent="Mozilla/5.0 (iPad)",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
@@ -179,7 +180,7 @@ class TestRevokeSession:
|
||||
# Revoke the session
|
||||
response = await client.delete(
|
||||
f"/api/v1/sessions/{session_id}",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -191,6 +192,7 @@ class TestRevokeSession:
|
||||
# Verify session is deactivated
|
||||
async with SessionLocal() as session:
|
||||
from app.crud.session import session as session_crud
|
||||
|
||||
revoked_session = await session_crud.get(session, id=str(session_id))
|
||||
assert revoked_session.is_active is False
|
||||
|
||||
@@ -200,7 +202,7 @@ class TestRevokeSession:
|
||||
fake_id = uuid4()
|
||||
response = await client.delete(
|
||||
f"/api/v1/sessions/{fake_id}",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
@@ -222,7 +224,7 @@ class TestRevokeSession:
|
||||
self, client, async_test_user, async_test_user2, async_test_db, user_token
|
||||
):
|
||||
"""Test that users cannot revoke other users' sessions."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create a session for user2
|
||||
async with SessionLocal() as session:
|
||||
@@ -233,8 +235,8 @@ class TestRevokeSession:
|
||||
ip_address="192.168.1.200",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(other_user_session)
|
||||
await session.commit()
|
||||
@@ -244,7 +246,7 @@ class TestRevokeSession:
|
||||
# Try to revoke it as user1
|
||||
response = await client.delete(
|
||||
f"/api/v1/sessions/{session_id}",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
@@ -263,7 +265,7 @@ class TestCleanupExpiredSessions:
|
||||
self, client, async_test_user, async_test_db, user_token
|
||||
):
|
||||
"""Test successfully cleaning up expired sessions."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create expired and active sessions using CRUD to avoid greenlet issues
|
||||
from app.crud.session import session as session_crud
|
||||
@@ -277,8 +279,8 @@ class TestCleanupExpiredSessions:
|
||||
device_name="Expired 1",
|
||||
ip_address="192.168.1.201",
|
||||
user_agent="Mozilla/5.0",
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(days=2)
|
||||
expires_at=datetime.now(UTC) - timedelta(days=1),
|
||||
last_used_at=datetime.now(UTC) - timedelta(days=2),
|
||||
)
|
||||
e1 = await session_crud.create_session(db, obj_in=e1_data)
|
||||
e1.is_active = False
|
||||
@@ -291,8 +293,8 @@ class TestCleanupExpiredSessions:
|
||||
device_name="Expired 2",
|
||||
ip_address="192.168.1.202",
|
||||
user_agent="Mozilla/5.0",
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(hours=2)
|
||||
expires_at=datetime.now(UTC) - timedelta(hours=1),
|
||||
last_used_at=datetime.now(UTC) - timedelta(hours=2),
|
||||
)
|
||||
e2 = await session_crud.create_session(db, obj_in=e2_data)
|
||||
e2.is_active = False
|
||||
@@ -305,8 +307,8 @@ class TestCleanupExpiredSessions:
|
||||
device_name="Active",
|
||||
ip_address="192.168.1.203",
|
||||
user_agent="Mozilla/5.0",
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
await session_crud.create_session(db, obj_in=a1_data)
|
||||
await db.commit()
|
||||
@@ -314,7 +316,7 @@ class TestCleanupExpiredSessions:
|
||||
# Cleanup expired sessions
|
||||
response = await client.delete(
|
||||
"/api/v1/sessions/me/expired",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -329,7 +331,7 @@ class TestCleanupExpiredSessions:
|
||||
self, client, async_test_user, async_test_db, user_token
|
||||
):
|
||||
"""Test cleanup when no sessions are expired."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create only active sessions using CRUD
|
||||
from app.crud.session import session as session_crud
|
||||
@@ -342,15 +344,15 @@ class TestCleanupExpiredSessions:
|
||||
device_name="Active Device",
|
||||
ip_address="192.168.1.210",
|
||||
user_agent="Mozilla/5.0",
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
await session_crud.create_session(db, obj_in=a1_data)
|
||||
await db.commit()
|
||||
|
||||
response = await client.delete(
|
||||
"/api/v1/sessions/me/expired",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -369,13 +371,16 @@ class TestCleanupExpiredSessions:
|
||||
|
||||
# Additional tests for better coverage
|
||||
|
||||
|
||||
class TestSessionsAdditionalCases:
|
||||
"""Additional tests to improve sessions endpoint coverage."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_sessions_pagination(self, client, async_test_user, async_test_db, user_token):
|
||||
async def test_list_sessions_pagination(
|
||||
self, client, async_test_user, async_test_db, user_token
|
||||
):
|
||||
"""Test listing sessions with pagination."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create multiple sessions
|
||||
async with SessionLocal() as session:
|
||||
@@ -389,15 +394,15 @@ class TestSessionsAdditionalCases:
|
||||
device_name=f"Device {i}",
|
||||
ip_address=f"192.168.1.{i}",
|
||||
user_agent="Mozilla/5.0",
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
await session_crud.create_session(session, obj_in=session_data)
|
||||
await session.commit()
|
||||
|
||||
response = await client.get(
|
||||
"/api/v1/sessions/me?page=1&limit=3",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -410,16 +415,21 @@ class TestSessionsAdditionalCases:
|
||||
"""Test revoking session with invalid UUID."""
|
||||
response = await client.delete(
|
||||
"/api/v1/sessions/not-a-uuid",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
# Should return 422 for invalid UUID format
|
||||
assert response.status_code in [status.HTTP_422_UNPROCESSABLE_ENTITY, status.HTTP_404_NOT_FOUND]
|
||||
assert response.status_code in [
|
||||
status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
status.HTTP_404_NOT_FOUND,
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_sessions_with_mixed_states(self, client, async_test_user, async_test_db, user_token):
|
||||
async def test_cleanup_expired_sessions_with_mixed_states(
|
||||
self, client, async_test_user, async_test_db, user_token
|
||||
):
|
||||
"""Test cleanup with mix of active/inactive and expired/not-expired sessions."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
from app.crud.session import session as session_crud
|
||||
from app.schemas.sessions import SessionCreate
|
||||
@@ -432,8 +442,8 @@ class TestSessionsAdditionalCases:
|
||||
device_name="Expired Inactive",
|
||||
ip_address="192.168.1.100",
|
||||
user_agent="Mozilla/5.0",
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(days=2)
|
||||
expires_at=datetime.now(UTC) - timedelta(days=1),
|
||||
last_used_at=datetime.now(UTC) - timedelta(days=2),
|
||||
)
|
||||
e1 = await session_crud.create_session(db, obj_in=e1_data)
|
||||
e1.is_active = False
|
||||
@@ -446,8 +456,8 @@ class TestSessionsAdditionalCases:
|
||||
device_name="Expired Active",
|
||||
ip_address="192.168.1.101",
|
||||
user_agent="Mozilla/5.0",
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(hours=2)
|
||||
expires_at=datetime.now(UTC) - timedelta(hours=1),
|
||||
last_used_at=datetime.now(UTC) - timedelta(hours=2),
|
||||
)
|
||||
await session_crud.create_session(db, obj_in=e2_data)
|
||||
|
||||
@@ -455,7 +465,7 @@ class TestSessionsAdditionalCases:
|
||||
|
||||
response = await client.delete(
|
||||
"/api/v1/sessions/me/expired",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -476,10 +486,12 @@ class TestSessionExceptionHandlers:
|
||||
from unittest.mock import patch
|
||||
|
||||
# Patch decode_token to raise an exception
|
||||
with patch('app.api.routes.sessions.decode_token', side_effect=Exception("Token decode error")):
|
||||
with patch(
|
||||
"app.api.routes.sessions.decode_token",
|
||||
side_effect=Exception("Token decode error"),
|
||||
):
|
||||
response = await client.get(
|
||||
"/api/v1/sessions/me",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
"/api/v1/sessions/me", headers={"Authorization": f"Bearer {user_token}"}
|
||||
)
|
||||
|
||||
# Should still succeed (exception is caught and ignored in try/except at line 77)
|
||||
@@ -489,12 +501,16 @@ class TestSessionExceptionHandlers:
|
||||
async def test_list_sessions_database_error(self, client, user_token):
|
||||
"""Test list_sessions handles database errors (covers lines 104-106)."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.crud import session as session_module
|
||||
|
||||
with patch.object(session_module.session, 'get_user_sessions', side_effect=Exception("Database error")):
|
||||
with patch.object(
|
||||
session_module.session,
|
||||
"get_user_sessions",
|
||||
side_effect=Exception("Database error"),
|
||||
):
|
||||
response = await client.get(
|
||||
"/api/v1/sessions/me",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
"/api/v1/sessions/me", headers={"Authorization": f"Bearer {user_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
@@ -503,18 +519,21 @@ class TestSessionExceptionHandlers:
|
||||
assert data["errors"][0]["message"] == "Failed to retrieve sessions"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_revoke_session_database_error(self, client, user_token, async_test_db, async_test_user):
|
||||
async def test_revoke_session_database_error(
|
||||
self, client, user_token, async_test_db, async_test_user
|
||||
):
|
||||
"""Test revoke_session handles database errors (covers lines 181-183)."""
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from app.crud import session as session_module
|
||||
|
||||
# First create a session to revoke
|
||||
from app.crud.session import session as session_crud
|
||||
from app.schemas.sessions import SessionCreate
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as db:
|
||||
session_in = SessionCreate(
|
||||
@@ -523,17 +542,21 @@ class TestSessionExceptionHandlers:
|
||||
device_name="Test Device",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
last_used_at=datetime.now(timezone.utc),
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=60)
|
||||
last_used_at=datetime.now(UTC),
|
||||
expires_at=datetime.now(UTC) + timedelta(days=60),
|
||||
)
|
||||
user_session = await session_crud.create_session(db, obj_in=session_in)
|
||||
session_id = user_session.id
|
||||
|
||||
# Mock the deactivate method to raise an exception
|
||||
with patch.object(session_module.session, 'deactivate', side_effect=Exception("Database connection lost")):
|
||||
with patch.object(
|
||||
session_module.session,
|
||||
"deactivate",
|
||||
side_effect=Exception("Database connection lost"),
|
||||
):
|
||||
response = await client.delete(
|
||||
f"/api/v1/sessions/{session_id}",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
@@ -544,12 +567,17 @@ class TestSessionExceptionHandlers:
|
||||
async def test_cleanup_expired_sessions_database_error(self, client, user_token):
|
||||
"""Test cleanup_expired_sessions handles database errors (covers lines 233-236)."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.crud import session as session_module
|
||||
|
||||
with patch.object(session_module.session, 'cleanup_expired_for_user', side_effect=Exception("Cleanup failed")):
|
||||
with patch.object(
|
||||
session_module.session,
|
||||
"cleanup_expired_for_user",
|
||||
side_effect=Exception("Cleanup failed"),
|
||||
):
|
||||
response = await client.delete(
|
||||
"/api/v1/sessions/me/expired",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
|
||||
@@ -3,32 +3,29 @@
|
||||
Comprehensive tests for user management endpoints.
|
||||
These tests focus on finding potential bugs, not just coverage.
|
||||
"""
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from unittest.mock import patch
|
||||
from fastapi import status
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import select
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi import status
|
||||
|
||||
from app.models.user import User
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserUpdate
|
||||
|
||||
|
||||
# Disable rate limiting for tests
|
||||
@pytest.fixture(autouse=True)
|
||||
def disable_rate_limit():
|
||||
"""Disable rate limiting for all tests in this module."""
|
||||
with patch('app.api.routes.users.limiter.enabled', False):
|
||||
with patch('app.api.routes.auth.limiter.enabled', False):
|
||||
with patch("app.api.routes.users.limiter.enabled", False):
|
||||
with patch("app.api.routes.auth.limiter.enabled", False):
|
||||
yield
|
||||
|
||||
|
||||
async def get_auth_headers(client, email, password):
|
||||
"""Helper to get authentication headers."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": email, "password": password}
|
||||
"/api/v1/auth/login", json={"email": email, "password": password}
|
||||
)
|
||||
token = response.json()["access_token"]
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
@@ -40,7 +37,9 @@ class TestListUsers:
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_users_as_superuser(self, client, async_test_superuser):
|
||||
"""Test listing users as superuser."""
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_superuser.email, "SuperPassword123!"
|
||||
)
|
||||
|
||||
response = await client.get("/api/v1/users", headers=headers)
|
||||
|
||||
@@ -53,16 +52,20 @@ class TestListUsers:
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_users_as_regular_user(self, client, async_test_user):
|
||||
"""Test that regular users cannot list users."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_user.email, "TestPassword123!"
|
||||
)
|
||||
|
||||
response = await client.get("/api/v1/users", headers=headers)
|
||||
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_users_pagination(self, client, async_test_superuser, async_test_db):
|
||||
async def test_list_users_pagination(
|
||||
self, client, async_test_superuser, async_test_db
|
||||
):
|
||||
"""Test pagination works correctly."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -72,12 +75,14 @@ class TestListUsers:
|
||||
password_hash="hash",
|
||||
first_name=f"PagUser{i}",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
is_superuser=False,
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_superuser.email, "SuperPassword123!"
|
||||
)
|
||||
|
||||
# Get first page
|
||||
response = await client.get("/api/v1/users?page=1&limit=5", headers=headers)
|
||||
@@ -88,9 +93,11 @@ class TestListUsers:
|
||||
assert data["pagination"]["total"] >= 15
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_users_filter_active(self, client, async_test_superuser, async_test_db):
|
||||
async def test_list_users_filter_active(
|
||||
self, client, async_test_superuser, async_test_db
|
||||
):
|
||||
"""Test filtering by active status."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create active and inactive users
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -99,19 +106,21 @@ class TestListUsers:
|
||||
password_hash="hash",
|
||||
first_name="Active",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
is_superuser=False,
|
||||
)
|
||||
inactive_user = User(
|
||||
email="inactivefilter@example.com",
|
||||
password_hash="hash",
|
||||
first_name="Inactive",
|
||||
is_active=False,
|
||||
is_superuser=False
|
||||
is_superuser=False,
|
||||
)
|
||||
session.add_all([active_user, inactive_user])
|
||||
await session.commit()
|
||||
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_superuser.email, "SuperPassword123!"
|
||||
)
|
||||
|
||||
# Filter for active users
|
||||
response = await client.get("/api/v1/users?is_active=true", headers=headers)
|
||||
@@ -130,9 +139,13 @@ class TestListUsers:
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_users_sort_by_email(self, client, async_test_superuser):
|
||||
"""Test sorting users by email."""
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_superuser.email, "SuperPassword123!"
|
||||
)
|
||||
|
||||
response = await client.get("/api/v1/users?sort_by=email&sort_order=asc", headers=headers)
|
||||
response = await client.get(
|
||||
"/api/v1/users?sort_by=email&sort_order=asc", headers=headers
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
emails = [u["email"] for u in data["data"]]
|
||||
@@ -154,7 +167,9 @@ class TestGetCurrentUserProfile:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_own_profile(self, client, async_test_user):
|
||||
"""Test getting own profile."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_user.email, "TestPassword123!"
|
||||
)
|
||||
|
||||
response = await client.get("/api/v1/users/me", headers=headers)
|
||||
|
||||
@@ -176,12 +191,14 @@ class TestUpdateCurrentUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_own_profile(self, client, async_test_user):
|
||||
"""Test updating own profile."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_user.email, "TestPassword123!"
|
||||
)
|
||||
|
||||
response = await client.patch(
|
||||
"/api/v1/users/me",
|
||||
headers=headers,
|
||||
json={"first_name": "Updated", "last_name": "Name"}
|
||||
json={"first_name": "Updated", "last_name": "Name"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -192,12 +209,12 @@ class TestUpdateCurrentUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_profile_phone_number(self, client, async_test_user, test_db):
|
||||
"""Test updating phone number with validation."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_user.email, "TestPassword123!"
|
||||
)
|
||||
|
||||
response = await client.patch(
|
||||
"/api/v1/users/me",
|
||||
headers=headers,
|
||||
json={"phone_number": "+19876543210"}
|
||||
"/api/v1/users/me", headers=headers, json={"phone_number": "+19876543210"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -207,12 +224,12 @@ class TestUpdateCurrentUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_profile_invalid_phone(self, client, async_test_user):
|
||||
"""Test that invalid phone numbers are rejected."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_user.email, "TestPassword123!"
|
||||
)
|
||||
|
||||
response = await client.patch(
|
||||
"/api/v1/users/me",
|
||||
headers=headers,
|
||||
json={"phone_number": "invalid"}
|
||||
"/api/v1/users/me", headers=headers, json={"phone_number": "invalid"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
@@ -220,14 +237,16 @@ class TestUpdateCurrentUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_elevate_to_superuser(self, client, async_test_user):
|
||||
"""Test that users cannot make themselves superuser."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_user.email, "TestPassword123!"
|
||||
)
|
||||
|
||||
# Note: is_superuser is now in UserUpdate schema with explicit validation
|
||||
# This tests that Pydantic rejects the attempt at the schema level
|
||||
response = await client.patch(
|
||||
"/api/v1/users/me",
|
||||
headers=headers,
|
||||
json={"first_name": "Test", "is_superuser": True}
|
||||
json={"first_name": "Test", "is_superuser": True},
|
||||
)
|
||||
|
||||
# Pydantic validation should reject this at the schema level
|
||||
@@ -242,10 +261,7 @@ class TestUpdateCurrentUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_profile_no_auth(self, client):
|
||||
"""Test that unauthenticated requests are rejected."""
|
||||
response = await client.patch(
|
||||
"/api/v1/users/me",
|
||||
json={"first_name": "Hacker"}
|
||||
)
|
||||
response = await client.patch("/api/v1/users/me", json={"first_name": "Hacker"})
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
# Note: Removed test_update_profile_unexpected_error - see comment above
|
||||
@@ -257,16 +273,22 @@ class TestGetUserById:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_own_profile_by_id(self, client, async_test_user):
|
||||
"""Test getting own profile by ID."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_user.email, "TestPassword123!"
|
||||
)
|
||||
|
||||
response = await client.get(f"/api/v1/users/{async_test_user.id}", headers=headers)
|
||||
response = await client.get(
|
||||
f"/api/v1/users/{async_test_user.id}", headers=headers
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["email"] == async_test_user.email
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_other_user_as_regular_user(self, client, async_test_user, test_db):
|
||||
async def test_get_other_user_as_regular_user(
|
||||
self, client, async_test_user, test_db
|
||||
):
|
||||
"""Test that regular users cannot view other profiles."""
|
||||
# Create another user
|
||||
other_user = User(
|
||||
@@ -274,24 +296,32 @@ class TestGetUserById:
|
||||
password_hash="hash",
|
||||
first_name="Other",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
is_superuser=False,
|
||||
)
|
||||
test_db.add(other_user)
|
||||
test_db.commit()
|
||||
test_db.refresh(other_user)
|
||||
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_user.email, "TestPassword123!"
|
||||
)
|
||||
|
||||
response = await client.get(f"/api/v1/users/{other_user.id}", headers=headers)
|
||||
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_other_user_as_superuser(self, client, async_test_superuser, async_test_user):
|
||||
async def test_get_other_user_as_superuser(
|
||||
self, client, async_test_superuser, async_test_user
|
||||
):
|
||||
"""Test that superusers can view other profiles."""
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_superuser.email, "SuperPassword123!"
|
||||
)
|
||||
|
||||
response = await client.get(f"/api/v1/users/{async_test_user.id}", headers=headers)
|
||||
response = await client.get(
|
||||
f"/api/v1/users/{async_test_user.id}", headers=headers
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
@@ -300,7 +330,9 @@ class TestGetUserById:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_nonexistent_user(self, client, async_test_superuser):
|
||||
"""Test getting non-existent user."""
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_superuser.email, "SuperPassword123!"
|
||||
)
|
||||
fake_id = uuid.uuid4()
|
||||
|
||||
response = await client.get(f"/api/v1/users/{fake_id}", headers=headers)
|
||||
@@ -310,7 +342,9 @@ class TestGetUserById:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_invalid_uuid(self, client, async_test_superuser):
|
||||
"""Test getting user with invalid UUID format."""
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_superuser.email, "SuperPassword123!"
|
||||
)
|
||||
|
||||
response = await client.get("/api/v1/users/not-a-uuid", headers=headers)
|
||||
|
||||
@@ -323,12 +357,14 @@ class TestUpdateUserById:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_own_profile_by_id(self, client, async_test_user, test_db):
|
||||
"""Test updating own profile by ID."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_user.email, "TestPassword123!"
|
||||
)
|
||||
|
||||
response = await client.patch(
|
||||
f"/api/v1/users/{async_test_user.id}",
|
||||
headers=headers,
|
||||
json={"first_name": "SelfUpdated"}
|
||||
json={"first_name": "SelfUpdated"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -336,7 +372,9 @@ class TestUpdateUserById:
|
||||
assert data["first_name"] == "SelfUpdated"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_other_user_as_regular_user(self, client, async_test_user, test_db):
|
||||
async def test_update_other_user_as_regular_user(
|
||||
self, client, async_test_user, test_db
|
||||
):
|
||||
"""Test that regular users cannot update other profiles."""
|
||||
# Create another user
|
||||
other_user = User(
|
||||
@@ -344,18 +382,20 @@ class TestUpdateUserById:
|
||||
password_hash="hash",
|
||||
first_name="Other",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
is_superuser=False,
|
||||
)
|
||||
test_db.add(other_user)
|
||||
test_db.commit()
|
||||
test_db.refresh(other_user)
|
||||
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_user.email, "TestPassword123!"
|
||||
)
|
||||
|
||||
response = await client.patch(
|
||||
f"/api/v1/users/{other_user.id}",
|
||||
headers=headers,
|
||||
json={"first_name": "Hacked"}
|
||||
json={"first_name": "Hacked"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
@@ -365,14 +405,18 @@ class TestUpdateUserById:
|
||||
assert other_user.first_name == "Other"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_other_user_as_superuser(self, client, async_test_superuser, async_test_user, test_db):
|
||||
async def test_update_other_user_as_superuser(
|
||||
self, client, async_test_superuser, async_test_user, test_db
|
||||
):
|
||||
"""Test that superusers can update other profiles."""
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_superuser.email, "SuperPassword123!"
|
||||
)
|
||||
|
||||
response = await client.patch(
|
||||
f"/api/v1/users/{async_test_user.id}",
|
||||
headers=headers,
|
||||
json={"first_name": "AdminUpdated"}
|
||||
json={"first_name": "AdminUpdated"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -380,16 +424,20 @@ class TestUpdateUserById:
|
||||
assert data["first_name"] == "AdminUpdated"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_cannot_modify_superuser_status(self, client, async_test_user):
|
||||
async def test_regular_user_cannot_modify_superuser_status(
|
||||
self, client, async_test_user
|
||||
):
|
||||
"""Test that regular users cannot change superuser status even if they try."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_user.email, "TestPassword123!"
|
||||
)
|
||||
|
||||
# is_superuser not in UserUpdate schema, so it gets ignored by Pydantic
|
||||
# Just verify the user stays the same
|
||||
response = await client.patch(
|
||||
f"/api/v1/users/{async_test_user.id}",
|
||||
headers=headers,
|
||||
json={"first_name": "Test"}
|
||||
json={"first_name": "Test"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -397,14 +445,18 @@ class TestUpdateUserById:
|
||||
assert data["is_superuser"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_superuser_can_update_users(self, client, async_test_superuser, async_test_user, test_db):
|
||||
async def test_superuser_can_update_users(
|
||||
self, client, async_test_superuser, async_test_user, test_db
|
||||
):
|
||||
"""Test that superusers can update other users."""
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_superuser.email, "SuperPassword123!"
|
||||
)
|
||||
|
||||
response = await client.patch(
|
||||
f"/api/v1/users/{async_test_user.id}",
|
||||
headers=headers,
|
||||
json={"first_name": "AdminChanged", "is_active": False}
|
||||
json={"first_name": "AdminChanged", "is_active": False},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -415,13 +467,13 @@ class TestUpdateUserById:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_nonexistent_user(self, client, async_test_superuser):
|
||||
"""Test updating non-existent user."""
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_superuser.email, "SuperPassword123!"
|
||||
)
|
||||
fake_id = uuid.uuid4()
|
||||
|
||||
response = await client.patch(
|
||||
f"/api/v1/users/{fake_id}",
|
||||
headers=headers,
|
||||
json={"first_name": "Ghost"}
|
||||
f"/api/v1/users/{fake_id}", headers=headers, json={"first_name": "Ghost"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
@@ -435,15 +487,17 @@ class TestChangePassword:
|
||||
@pytest.mark.asyncio
|
||||
async def test_change_password_success(self, client, async_test_user, test_db):
|
||||
"""Test successful password change."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_user.email, "TestPassword123!"
|
||||
)
|
||||
|
||||
response = await client.patch(
|
||||
"/api/v1/users/me/password",
|
||||
headers=headers,
|
||||
json={
|
||||
"current_password": "TestPassword123!",
|
||||
"new_password": "NewPassword123!"
|
||||
}
|
||||
"new_password": "NewPassword123!",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -453,25 +507,24 @@ class TestChangePassword:
|
||||
# Verify can login with new password
|
||||
login_response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": async_test_user.email,
|
||||
"password": "NewPassword123!"
|
||||
}
|
||||
json={"email": async_test_user.email, "password": "NewPassword123!"},
|
||||
)
|
||||
assert login_response.status_code == status.HTTP_200_OK
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_change_password_wrong_current(self, client, async_test_user):
|
||||
"""Test that wrong current password is rejected."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_user.email, "TestPassword123!"
|
||||
)
|
||||
|
||||
response = await client.patch(
|
||||
"/api/v1/users/me/password",
|
||||
headers=headers,
|
||||
json={
|
||||
"current_password": "WrongPassword123",
|
||||
"new_password": "NewPassword123!"
|
||||
}
|
||||
"new_password": "NewPassword123!",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
@@ -479,15 +532,14 @@ class TestChangePassword:
|
||||
@pytest.mark.asyncio
|
||||
async def test_change_password_weak_new_password(self, client, async_test_user):
|
||||
"""Test that weak new passwords are rejected."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_user.email, "TestPassword123!"
|
||||
)
|
||||
|
||||
response = await client.patch(
|
||||
"/api/v1/users/me/password",
|
||||
headers=headers,
|
||||
json={
|
||||
"current_password": "TestPassword123!",
|
||||
"new_password": "weak"
|
||||
}
|
||||
json={"current_password": "TestPassword123!", "new_password": "weak"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
@@ -499,8 +551,8 @@ class TestChangePassword:
|
||||
"/api/v1/users/me/password",
|
||||
json={
|
||||
"current_password": "TestPassword123!",
|
||||
"new_password": "NewPassword123!"
|
||||
}
|
||||
"new_password": "NewPassword123!",
|
||||
},
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
@@ -511,9 +563,11 @@ class TestDeleteUser:
|
||||
"""Tests for DELETE /users/{user_id} endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_user_as_superuser(self, client, async_test_superuser, async_test_db):
|
||||
async def test_delete_user_as_superuser(
|
||||
self, client, async_test_superuser, async_test_db
|
||||
):
|
||||
"""Test deleting a user as superuser."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create a user to delete
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -522,14 +576,16 @@ class TestDeleteUser:
|
||||
password_hash="hash",
|
||||
first_name="Delete",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
is_superuser=False,
|
||||
)
|
||||
session.add(user_to_delete)
|
||||
await session.commit()
|
||||
await session.refresh(user_to_delete)
|
||||
user_id = user_to_delete.id
|
||||
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_superuser.email, "SuperPassword123!"
|
||||
)
|
||||
|
||||
response = await client.delete(f"/api/v1/users/{user_id}", headers=headers)
|
||||
|
||||
@@ -540,6 +596,7 @@ class TestDeleteUser:
|
||||
# Verify user is soft-deleted (has deleted_at timestamp)
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
from sqlalchemy import select
|
||||
|
||||
result = await session.execute(select(User).where(User.id == user_id))
|
||||
deleted_user = result.scalar_one_or_none()
|
||||
assert deleted_user.deleted_at is not None
|
||||
@@ -547,9 +604,13 @@ class TestDeleteUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_delete_self(self, client, async_test_superuser):
|
||||
"""Test that users cannot delete their own account."""
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_superuser.email, "SuperPassword123!"
|
||||
)
|
||||
|
||||
response = await client.delete(f"/api/v1/users/{async_test_superuser.id}", headers=headers)
|
||||
response = await client.delete(
|
||||
f"/api/v1/users/{async_test_superuser.id}", headers=headers
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
@@ -562,22 +623,28 @@ class TestDeleteUser:
|
||||
password_hash="hash",
|
||||
first_name="Protected",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
is_superuser=False,
|
||||
)
|
||||
test_db.add(other_user)
|
||||
test_db.commit()
|
||||
test_db.refresh(other_user)
|
||||
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_user.email, "TestPassword123!"
|
||||
)
|
||||
|
||||
response = await client.delete(f"/api/v1/users/{other_user.id}", headers=headers)
|
||||
response = await client.delete(
|
||||
f"/api/v1/users/{other_user.id}", headers=headers
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_nonexistent_user(self, client, async_test_superuser):
|
||||
"""Test deleting non-existent user."""
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_superuser.email, "SuperPassword123!"
|
||||
)
|
||||
fake_id = uuid.uuid4()
|
||||
|
||||
response = await client.delete(f"/api/v1/users/{fake_id}", headers=headers)
|
||||
|
||||
@@ -2,10 +2,12 @@
|
||||
"""
|
||||
Tests for user routes.
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from fastapi import status
|
||||
from uuid import uuid4
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
@@ -13,10 +15,7 @@ async def superuser_token(client, async_test_superuser):
|
||||
"""Get access token for superuser."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "superuser@example.com",
|
||||
"password": "SuperPassword123!"
|
||||
}
|
||||
json={"email": "superuser@example.com", "password": "SuperPassword123!"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
return response.json()["access_token"]
|
||||
@@ -27,10 +26,7 @@ async def user_token(client, async_test_user):
|
||||
"""Get access token for regular user."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
return response.json()["access_token"]
|
||||
@@ -43,8 +39,7 @@ class TestListUsers:
|
||||
async def test_list_users_success(self, client, superuser_token):
|
||||
"""Test listing users successfully (covers lines 87-100)."""
|
||||
response = await client.get(
|
||||
"/api/v1/users",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
"/api/v1/users", headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -58,7 +53,7 @@ class TestListUsers:
|
||||
"""Test listing users with is_superuser filter (covers line 74)."""
|
||||
response = await client.get(
|
||||
"/api/v1/users?is_superuser=true",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -73,8 +68,7 @@ class TestGetCurrentUser:
|
||||
async def test_get_current_user_success(self, client, async_test_user, user_token):
|
||||
"""Test getting current user profile."""
|
||||
response = await client.get(
|
||||
"/api/v1/users/me",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
"/api/v1/users/me", headers={"Authorization": f"Bearer {user_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -92,7 +86,7 @@ class TestUpdateCurrentUser:
|
||||
response = await client.patch(
|
||||
"/api/v1/users/me",
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
json={"first_name": "UpdatedName"}
|
||||
json={"first_name": "UpdatedName"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -104,12 +98,14 @@ class TestUpdateCurrentUser:
|
||||
"""Test database error handling during update (covers lines 162-169)."""
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch('app.api.routes.users.user_crud.update', side_effect=Exception("DB error")):
|
||||
with patch(
|
||||
"app.api.routes.users.user_crud.update", side_effect=Exception("DB error")
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
await client.patch(
|
||||
"/api/v1/users/me",
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
json={"first_name": "Updated"}
|
||||
json={"first_name": "Updated"},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -118,7 +114,7 @@ class TestUpdateCurrentUser:
|
||||
response = await client.patch(
|
||||
"/api/v1/users/me",
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
json={"is_superuser": True}
|
||||
json={"is_superuser": True},
|
||||
)
|
||||
|
||||
# Pydantic validation should reject this at the schema level
|
||||
@@ -137,12 +133,15 @@ class TestUpdateCurrentUser:
|
||||
"""Test ValueError handling during update (covers lines 165-166)."""
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch('app.api.routes.users.user_crud.update', side_effect=ValueError("Invalid value")):
|
||||
with patch(
|
||||
"app.api.routes.users.user_crud.update",
|
||||
side_effect=ValueError("Invalid value"),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
await client.patch(
|
||||
"/api/v1/users/me",
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
json={"first_name": "Updated"}
|
||||
json={"first_name": "Updated"},
|
||||
)
|
||||
|
||||
|
||||
@@ -154,7 +153,7 @@ class TestGetUser:
|
||||
"""Test getting user by ID."""
|
||||
response = await client.get(
|
||||
f"/api/v1/users/{async_test_user.id}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -167,7 +166,7 @@ class TestGetUser:
|
||||
fake_id = uuid4()
|
||||
response = await client.get(
|
||||
f"/api/v1/users/{fake_id}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
@@ -183,30 +182,34 @@ class TestUpdateUserById:
|
||||
response = await client.patch(
|
||||
f"/api/v1/users/{fake_id}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
json={"first_name": "Updated"}
|
||||
json={"first_name": "Updated"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_by_id_non_superuser_cannot_change_superuser_status(self, client, async_test_user, user_token):
|
||||
async def test_update_user_by_id_non_superuser_cannot_change_superuser_status(
|
||||
self, client, async_test_user, user_token
|
||||
):
|
||||
"""Test non-superuser cannot modify superuser status (Pydantic validation)."""
|
||||
response = await client.patch(
|
||||
f"/api/v1/users/{async_test_user.id}",
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
json={"is_superuser": True}
|
||||
json={"is_superuser": True},
|
||||
)
|
||||
|
||||
# Pydantic validation should reject this at the schema level
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_by_id_success(self, client, async_test_user, superuser_token):
|
||||
async def test_update_user_by_id_success(
|
||||
self, client, async_test_user, superuser_token
|
||||
):
|
||||
"""Test updating user successfully (covers lines 276-278)."""
|
||||
response = await client.patch(
|
||||
f"/api/v1/users/{async_test_user.id}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
json={"first_name": "SuperUpdated"}
|
||||
json={"first_name": "SuperUpdated"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -214,29 +217,37 @@ class TestUpdateUserById:
|
||||
assert data["first_name"] == "SuperUpdated"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_by_id_value_error(self, client, async_test_user, superuser_token):
|
||||
async def test_update_user_by_id_value_error(
|
||||
self, client, async_test_user, superuser_token
|
||||
):
|
||||
"""Test ValueError handling (covers lines 280-281)."""
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch('app.api.routes.users.user_crud.update', side_effect=ValueError("Invalid")):
|
||||
with patch(
|
||||
"app.api.routes.users.user_crud.update", side_effect=ValueError("Invalid")
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
await client.patch(
|
||||
f"/api/v1/users/{async_test_user.id}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
json={"first_name": "Updated"}
|
||||
json={"first_name": "Updated"},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_by_id_unexpected_error(self, client, async_test_user, superuser_token):
|
||||
async def test_update_user_by_id_unexpected_error(
|
||||
self, client, async_test_user, superuser_token
|
||||
):
|
||||
"""Test unexpected error handling (covers lines 283-284)."""
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch('app.api.routes.users.user_crud.update', side_effect=Exception("Unexpected")):
|
||||
with patch(
|
||||
"app.api.routes.users.user_crud.update", side_effect=Exception("Unexpected")
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
await client.patch(
|
||||
f"/api/v1/users/{async_test_user.id}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
json={"first_name": "Updated"}
|
||||
json={"first_name": "Updated"},
|
||||
)
|
||||
|
||||
|
||||
@@ -246,18 +257,18 @@ class TestChangePassword:
|
||||
@pytest.mark.asyncio
|
||||
async def test_change_password_success(self, client, async_test_db):
|
||||
"""Test changing password successfully."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create a fresh user
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
from app.models.user import User
|
||||
from app.core.auth import get_password_hash
|
||||
from app.models.user import User
|
||||
|
||||
new_user = User(
|
||||
email="changepass@example.com",
|
||||
password_hash=get_password_hash("OldPassword123!"),
|
||||
first_name="Change",
|
||||
last_name="Pass"
|
||||
last_name="Pass",
|
||||
)
|
||||
session.add(new_user)
|
||||
await session.commit()
|
||||
@@ -265,10 +276,7 @@ class TestChangePassword:
|
||||
# Login
|
||||
login_response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "changepass@example.com",
|
||||
"password": "OldPassword123!"
|
||||
}
|
||||
json={"email": "changepass@example.com", "password": "OldPassword123!"},
|
||||
)
|
||||
token = login_response.json()["access_token"]
|
||||
|
||||
@@ -278,8 +286,8 @@ class TestChangePassword:
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
json={
|
||||
"current_password": "OldPassword123!",
|
||||
"new_password": "NewPassword456!"
|
||||
}
|
||||
"new_password": "NewPassword456!",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -289,10 +297,7 @@ class TestChangePassword:
|
||||
# Verify new password works
|
||||
login_response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "changepass@example.com",
|
||||
"password": "NewPassword456!"
|
||||
}
|
||||
json={"email": "changepass@example.com", "password": "NewPassword456!"},
|
||||
)
|
||||
assert login_response.status_code == status.HTTP_200_OK
|
||||
|
||||
@@ -306,7 +311,7 @@ class TestDeleteUserById:
|
||||
fake_id = uuid4()
|
||||
response = await client.delete(
|
||||
f"/api/v1/users/{fake_id}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
@@ -314,18 +319,18 @@ class TestDeleteUserById:
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_user_success(self, client, async_test_db, superuser_token):
|
||||
"""Test deleting user successfully (covers lines 383-388)."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create a user to delete
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
from app.models.user import User
|
||||
from app.core.auth import get_password_hash
|
||||
from app.models.user import User
|
||||
|
||||
user_to_delete = User(
|
||||
email=f"delete{uuid4().hex[:8]}@example.com",
|
||||
password_hash=get_password_hash("Password123!"),
|
||||
first_name="Delete",
|
||||
last_name="Me"
|
||||
last_name="Me",
|
||||
)
|
||||
session.add(user_to_delete)
|
||||
await session.commit()
|
||||
@@ -334,7 +339,7 @@ class TestDeleteUserById:
|
||||
|
||||
response = await client.delete(
|
||||
f"/api/v1/users/{user_id}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -342,25 +347,35 @@ class TestDeleteUserById:
|
||||
assert data["success"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_user_value_error(self, client, async_test_user, superuser_token):
|
||||
async def test_delete_user_value_error(
|
||||
self, client, async_test_user, superuser_token
|
||||
):
|
||||
"""Test ValueError handling during delete (covers lines 390-391)."""
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch('app.api.routes.users.user_crud.soft_delete', side_effect=ValueError("Cannot delete")):
|
||||
with patch(
|
||||
"app.api.routes.users.user_crud.soft_delete",
|
||||
side_effect=ValueError("Cannot delete"),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
await client.delete(
|
||||
f"/api/v1/users/{async_test_user.id}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_user_unexpected_error(self, client, async_test_user, superuser_token):
|
||||
async def test_delete_user_unexpected_error(
|
||||
self, client, async_test_user, superuser_token
|
||||
):
|
||||
"""Test unexpected error handling during delete (covers lines 393-394)."""
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch('app.api.routes.users.user_crud.soft_delete', side_effect=Exception("Unexpected")):
|
||||
with patch(
|
||||
"app.api.routes.users.user_crud.soft_delete",
|
||||
side_effect=Exception("Unexpected"),
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
await client.delete(
|
||||
f"/api/v1/users/{async_test_user.id}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
@@ -1,28 +1,32 @@
|
||||
# tests/conftest.py
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
# Set IS_TEST environment variable BEFORE importing app
|
||||
# This prevents the scheduler from starting during tests
|
||||
os.environ["IS_TEST"] = "True"
|
||||
|
||||
from app.main import app
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
from app.core.auth import get_password_hash
|
||||
from app.utils.test_utils import setup_test_db, teardown_test_db, setup_async_test_db, teardown_async_test_db
|
||||
from app.core.database import get_db
|
||||
from app.main import app
|
||||
from app.models.user import User
|
||||
from app.utils.test_utils import (
|
||||
setup_async_test_db,
|
||||
setup_test_db,
|
||||
teardown_async_test_db,
|
||||
teardown_test_db,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def db_session():
|
||||
"""
|
||||
Creates a fresh SQLite in-memory database for each test function.
|
||||
|
||||
|
||||
Yields a SQLAlchemy session that can be used for testing.
|
||||
"""
|
||||
# Set up the database
|
||||
@@ -46,6 +50,7 @@ async def async_test_db():
|
||||
yield test_engine, AsyncTestingSessionLocal
|
||||
await teardown_async_test_db(test_engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_create_data():
|
||||
return {
|
||||
@@ -55,7 +60,7 @@ def user_create_data():
|
||||
"last_name": "User",
|
||||
"phone_number": "+1234567890",
|
||||
"is_superuser": False,
|
||||
"preferences": None
|
||||
"preferences": None,
|
||||
}
|
||||
|
||||
|
||||
@@ -102,7 +107,7 @@ async def client(async_test_db):
|
||||
|
||||
This overrides the get_db dependency to use the test database.
|
||||
"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async def override_get_db():
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -176,7 +181,7 @@ async def async_test_user(async_test_db):
|
||||
|
||||
Password: TestPassword123
|
||||
"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
@@ -202,7 +207,7 @@ async def async_test_superuser(async_test_db):
|
||||
|
||||
Password: SuperPassword123
|
||||
"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
@@ -256,4 +261,4 @@ async def superuser_token(client, async_test_superuser):
|
||||
)
|
||||
assert response.status_code == 200, f"Login failed: {response.text}"
|
||||
tokens = response.json()
|
||||
return tokens["access_token"]
|
||||
return tokens["access_token"]
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
# tests/core/test_auth.py
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from jose import jwt
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.core.auth import (
|
||||
verify_password,
|
||||
get_password_hash,
|
||||
TokenExpiredError,
|
||||
TokenInvalidError,
|
||||
TokenMissingClaimError,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_token,
|
||||
get_password_hash,
|
||||
get_token_data,
|
||||
TokenExpiredError,
|
||||
TokenInvalidError,
|
||||
TokenMissingClaimError
|
||||
verify_password,
|
||||
)
|
||||
from app.core.config import settings
|
||||
|
||||
@@ -58,15 +58,13 @@ class TestTokenCreation:
|
||||
custom_claims = {
|
||||
"email": "test@example.com",
|
||||
"first_name": "Test",
|
||||
"is_superuser": True
|
||||
"is_superuser": True,
|
||||
}
|
||||
token = create_access_token(subject=user_id, claims=custom_claims)
|
||||
|
||||
# Decode token to verify claims
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.SECRET_KEY,
|
||||
algorithms=[settings.ALGORITHM]
|
||||
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
|
||||
# Check standard claims
|
||||
@@ -87,9 +85,7 @@ class TestTokenCreation:
|
||||
|
||||
# Decode token to verify claims
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.SECRET_KEY,
|
||||
algorithms=[settings.ALGORITHM]
|
||||
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
|
||||
# Check standard claims
|
||||
@@ -105,23 +101,18 @@ class TestTokenCreation:
|
||||
expires = timedelta(minutes=5)
|
||||
|
||||
# Create token with specific expiration
|
||||
token = create_access_token(
|
||||
subject=user_id,
|
||||
expires_delta=expires
|
||||
)
|
||||
token = create_access_token(subject=user_id, expires_delta=expires)
|
||||
|
||||
# Decode token
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.SECRET_KEY,
|
||||
algorithms=[settings.ALGORITHM]
|
||||
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
|
||||
# Get actual expiration time from token
|
||||
expiration = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
|
||||
expiration = datetime.fromtimestamp(payload["exp"], tz=UTC)
|
||||
|
||||
# Calculate expected expiration (approximately)
|
||||
now = datetime.now(timezone.utc)
|
||||
now = datetime.now(UTC)
|
||||
expected_expiration = now + expires
|
||||
|
||||
# Difference should be small (less than 1 second)
|
||||
@@ -148,7 +139,7 @@ class TestTokenDecoding:
|
||||
user_id = str(uuid.uuid4())
|
||||
|
||||
# Create a token that's already expired by directly manipulating the payload
|
||||
now = datetime.now(timezone.utc)
|
||||
now = datetime.now(UTC)
|
||||
expired_time = now - timedelta(hours=1) # 1 hour in the past
|
||||
|
||||
# Create the expired token manually
|
||||
@@ -157,13 +148,11 @@ class TestTokenDecoding:
|
||||
"exp": int(expired_time.timestamp()), # Set expiration in the past
|
||||
"iat": int(now.timestamp()),
|
||||
"jti": str(uuid.uuid4()),
|
||||
"type": "access"
|
||||
"type": "access",
|
||||
}
|
||||
|
||||
expired_token = jwt.encode(
|
||||
payload,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
# Attempting to decode should raise TokenExpiredError
|
||||
@@ -180,20 +169,16 @@ class TestTokenDecoding:
|
||||
def test_decode_token_with_missing_sub(self):
|
||||
"""Test that a token without 'sub' claim raises TokenMissingClaimError"""
|
||||
# Create a token without a subject
|
||||
now = datetime.now(timezone.utc)
|
||||
now = datetime.now(UTC)
|
||||
payload = {
|
||||
"exp": int((now + timedelta(minutes=30)).timestamp()),
|
||||
"iat": int(now.timestamp()),
|
||||
"jti": str(uuid.uuid4()),
|
||||
"type": "access"
|
||||
"type": "access",
|
||||
# No 'sub' claim
|
||||
}
|
||||
|
||||
token = jwt.encode(
|
||||
payload,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
)
|
||||
token = jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
|
||||
with pytest.raises(TokenMissingClaimError):
|
||||
decode_token(token)
|
||||
@@ -211,20 +196,16 @@ class TestTokenDecoding:
|
||||
"""Test that a token with invalid payload structure raises TokenInvalidError"""
|
||||
# Create a token with an invalid payload structure - missing 'sub' which is required
|
||||
# but including 'exp' to avoid the expiration check
|
||||
now = datetime.now(timezone.utc)
|
||||
now = datetime.now(UTC)
|
||||
payload = {
|
||||
# Missing "sub" field which is required
|
||||
"exp": int((now + timedelta(minutes=30)).timestamp()),
|
||||
"iat": int(now.timestamp()),
|
||||
"jti": str(uuid.uuid4()),
|
||||
"invalid_field": "test"
|
||||
"invalid_field": "test",
|
||||
}
|
||||
|
||||
token = jwt.encode(
|
||||
payload,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
)
|
||||
token = jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
|
||||
# Should raise TokenMissingClaimError due to missing 'sub'
|
||||
with pytest.raises(TokenMissingClaimError):
|
||||
@@ -236,11 +217,7 @@ class TestTokenDecoding:
|
||||
"exp": int((now + timedelta(minutes=30)).timestamp()),
|
||||
}
|
||||
|
||||
token = jwt.encode(
|
||||
payload,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
)
|
||||
token = jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
|
||||
# Should raise TokenInvalidError due to ValidationError
|
||||
with pytest.raises(TokenInvalidError):
|
||||
@@ -249,12 +226,9 @@ class TestTokenDecoding:
|
||||
def test_get_token_data(self):
|
||||
"""Test extracting TokenData from a token"""
|
||||
user_id = uuid.uuid4()
|
||||
token = create_access_token(
|
||||
subject=str(user_id),
|
||||
claims={"is_superuser": True}
|
||||
)
|
||||
token = create_access_token(subject=str(user_id), claims={"is_superuser": True})
|
||||
|
||||
token_data = get_token_data(token)
|
||||
|
||||
assert token_data.user_id == user_id
|
||||
assert token_data.is_superuser is True
|
||||
assert token_data.is_superuser is True
|
||||
|
||||
@@ -8,11 +8,11 @@ Critical security tests covering:
|
||||
|
||||
These tests cover critical security vulnerabilities that could be exploited.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from jose import jwt
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from app.core.auth import decode_token, create_access_token, TokenInvalidError
|
||||
from app.core.auth import TokenInvalidError, create_access_token, decode_token
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
@@ -46,13 +46,14 @@ class TestJWTAlgorithmSecurityAttacks:
|
||||
"""
|
||||
# Create a payload that would normally be valid (using timestamps)
|
||||
import time
|
||||
|
||||
now = int(time.time())
|
||||
|
||||
payload = {
|
||||
"sub": "user123",
|
||||
"exp": now + 3600, # 1 hour from now
|
||||
"iat": now,
|
||||
"type": "access"
|
||||
"type": "access",
|
||||
}
|
||||
|
||||
# Craft a malicious token with "alg: none"
|
||||
@@ -61,13 +62,13 @@ class TestJWTAlgorithmSecurityAttacks:
|
||||
import json
|
||||
|
||||
header = {"alg": "none", "typ": "JWT"}
|
||||
header_encoded = base64.urlsafe_b64encode(
|
||||
json.dumps(header).encode()
|
||||
).decode().rstrip("=")
|
||||
header_encoded = (
|
||||
base64.urlsafe_b64encode(json.dumps(header).encode()).decode().rstrip("=")
|
||||
)
|
||||
|
||||
payload_encoded = base64.urlsafe_b64encode(
|
||||
json.dumps(payload).encode()
|
||||
).decode().rstrip("=")
|
||||
payload_encoded = (
|
||||
base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=")
|
||||
)
|
||||
|
||||
# Token with no signature (algorithm "none")
|
||||
malicious_token = f"{header_encoded}.{payload_encoded}."
|
||||
@@ -85,22 +86,17 @@ class TestJWTAlgorithmSecurityAttacks:
|
||||
import time
|
||||
|
||||
now = int(time.time())
|
||||
payload = {
|
||||
"sub": "user123",
|
||||
"exp": now + 3600,
|
||||
"iat": now,
|
||||
"type": "access"
|
||||
}
|
||||
payload = {"sub": "user123", "exp": now + 3600, "iat": now, "type": "access"}
|
||||
|
||||
# Try uppercase "NONE"
|
||||
header = {"alg": "NONE", "typ": "JWT"}
|
||||
header_encoded = base64.urlsafe_b64encode(
|
||||
json.dumps(header).encode()
|
||||
).decode().rstrip("=")
|
||||
header_encoded = (
|
||||
base64.urlsafe_b64encode(json.dumps(header).encode()).decode().rstrip("=")
|
||||
)
|
||||
|
||||
payload_encoded = base64.urlsafe_b64encode(
|
||||
json.dumps(payload).encode()
|
||||
).decode().rstrip("=")
|
||||
payload_encoded = (
|
||||
base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=")
|
||||
)
|
||||
|
||||
malicious_token = f"{header_encoded}.{payload_encoded}."
|
||||
|
||||
@@ -121,15 +117,11 @@ class TestJWTAlgorithmSecurityAttacks:
|
||||
before our defensive checks at line 212. This is good for security!
|
||||
"""
|
||||
import time
|
||||
|
||||
now = int(time.time())
|
||||
|
||||
# Create a valid payload
|
||||
payload = {
|
||||
"sub": "user123",
|
||||
"exp": now + 3600,
|
||||
"iat": now,
|
||||
"type": "access"
|
||||
}
|
||||
payload = {"sub": "user123", "exp": now + 3600, "iat": now, "type": "access"}
|
||||
|
||||
# Encode with wrong algorithm (RS256 instead of HS256)
|
||||
# This simulates an attacker trying algorithm substitution
|
||||
@@ -137,9 +129,7 @@ class TestJWTAlgorithmSecurityAttacks:
|
||||
|
||||
try:
|
||||
malicious_token = jwt.encode(
|
||||
payload,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=wrong_algorithm
|
||||
payload, settings.SECRET_KEY, algorithm=wrong_algorithm
|
||||
)
|
||||
|
||||
# Should reject the token (library catches mismatch)
|
||||
@@ -156,21 +146,15 @@ class TestJWTAlgorithmSecurityAttacks:
|
||||
Prevents algorithm downgrade/upgrade attacks.
|
||||
"""
|
||||
import time
|
||||
|
||||
now = int(time.time())
|
||||
|
||||
payload = {
|
||||
"sub": "user123",
|
||||
"exp": now + 3600,
|
||||
"iat": now,
|
||||
"type": "access"
|
||||
}
|
||||
payload = {"sub": "user123", "exp": now + 3600, "iat": now, "type": "access"}
|
||||
|
||||
# Create token with HS384 instead of HS256
|
||||
try:
|
||||
malicious_token = jwt.encode(
|
||||
payload,
|
||||
settings.SECRET_KEY,
|
||||
algorithm="HS384"
|
||||
payload, settings.SECRET_KEY, algorithm="HS384"
|
||||
)
|
||||
|
||||
with pytest.raises(TokenInvalidError):
|
||||
@@ -223,20 +207,15 @@ class TestJWTSecurityEdgeCases:
|
||||
|
||||
# Create token without "alg" in header
|
||||
header = {"typ": "JWT"} # Missing "alg"
|
||||
payload = {
|
||||
"sub": "user123",
|
||||
"exp": now + 3600,
|
||||
"iat": now,
|
||||
"type": "access"
|
||||
}
|
||||
payload = {"sub": "user123", "exp": now + 3600, "iat": now, "type": "access"}
|
||||
|
||||
header_encoded = base64.urlsafe_b64encode(
|
||||
json.dumps(header).encode()
|
||||
).decode().rstrip("=")
|
||||
header_encoded = (
|
||||
base64.urlsafe_b64encode(json.dumps(header).encode()).decode().rstrip("=")
|
||||
)
|
||||
|
||||
payload_encoded = base64.urlsafe_b64encode(
|
||||
json.dumps(payload).encode()
|
||||
).decode().rstrip("=")
|
||||
payload_encoded = (
|
||||
base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=")
|
||||
)
|
||||
|
||||
malicious_token = f"{header_encoded}.{payload_encoded}.fake_signature"
|
||||
|
||||
@@ -253,15 +232,20 @@ class TestJWTSecurityEdgeCases:
|
||||
"""Test token with malformed JSON in payload."""
|
||||
import base64
|
||||
|
||||
header = {"alg": "HS256", "typ": "JWT"}
|
||||
header_encoded = base64.urlsafe_b64encode(
|
||||
b'{"alg":"HS256","typ":"JWT"}'
|
||||
).decode().rstrip("=")
|
||||
header_encoded = (
|
||||
base64.urlsafe_b64encode(b'{"alg":"HS256","typ":"JWT"}')
|
||||
.decode()
|
||||
.rstrip("=")
|
||||
)
|
||||
|
||||
# Invalid JSON (missing closing brace)
|
||||
invalid_payload_encoded = base64.urlsafe_b64encode(
|
||||
b'{"sub":"user123"' # Invalid JSON
|
||||
).decode().rstrip("=")
|
||||
invalid_payload_encoded = (
|
||||
base64.urlsafe_b64encode(
|
||||
b'{"sub":"user123"' # Invalid JSON
|
||||
)
|
||||
.decode()
|
||||
.rstrip("=")
|
||||
)
|
||||
|
||||
malicious_token = f"{header_encoded}.{invalid_payload_encoded}.fake_sig"
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# tests/core/test_config.py
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.core.config import Settings
|
||||
|
||||
|
||||
@@ -22,11 +23,15 @@ class TestSecretKeyValidation:
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
Settings(SECRET_KEY=default_key, ENVIRONMENT="production")
|
||||
|
||||
assert "must be set to a secure random value in production" in str(exc_info.value)
|
||||
assert "must be set to a secure random value in production" in str(
|
||||
exc_info.value
|
||||
)
|
||||
|
||||
def test_default_secret_key_in_development_allows_with_warning(self, caplog):
|
||||
"""Test that default SECRET_KEY in development is allowed but warns"""
|
||||
settings = Settings(SECRET_KEY="your_secret_key_here" + "x" * 14, ENVIRONMENT="development")
|
||||
settings = Settings(
|
||||
SECRET_KEY="your_secret_key_here" + "x" * 14, ENVIRONMENT="development"
|
||||
)
|
||||
|
||||
assert settings.SECRET_KEY == "your_secret_key_here" + "x" * 14
|
||||
# Note: The warning happens during validation, which we've seen works
|
||||
@@ -44,19 +49,13 @@ class TestSuperuserPasswordValidation:
|
||||
|
||||
def test_none_password_accepted(self):
|
||||
"""Test that None password is accepted (optional field)"""
|
||||
settings = Settings(
|
||||
SECRET_KEY="a" * 32,
|
||||
FIRST_SUPERUSER_PASSWORD=None
|
||||
)
|
||||
settings = Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD=None)
|
||||
assert settings.FIRST_SUPERUSER_PASSWORD is None
|
||||
|
||||
def test_password_too_short_raises_error(self):
|
||||
"""Test that password shorter than 12 characters raises error"""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
Settings(
|
||||
SECRET_KEY="a" * 32,
|
||||
FIRST_SUPERUSER_PASSWORD="Short1"
|
||||
)
|
||||
Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD="Short1")
|
||||
|
||||
assert "must be at least 12 characters" in str(exc_info.value)
|
||||
|
||||
@@ -64,14 +63,11 @@ class TestSuperuserPasswordValidation:
|
||||
"""Test that common weak passwords are rejected"""
|
||||
# Test with the exact weak passwords from the validator
|
||||
# These are in the weak_passwords set and should be rejected
|
||||
weak_passwords = ['123456789012'] # Exactly 12 chars, in the weak set
|
||||
weak_passwords = ["123456789012"] # Exactly 12 chars, in the weak set
|
||||
|
||||
for weak_pwd in weak_passwords:
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
Settings(
|
||||
SECRET_KEY="a" * 32,
|
||||
FIRST_SUPERUSER_PASSWORD=weak_pwd
|
||||
)
|
||||
Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD=weak_pwd)
|
||||
# Should get "too weak" message
|
||||
error_str = str(exc_info.value)
|
||||
assert "too weak" in error_str
|
||||
@@ -79,30 +75,21 @@ class TestSuperuserPasswordValidation:
|
||||
def test_password_without_lowercase_rejected(self):
|
||||
"""Test that password without lowercase is rejected"""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
Settings(
|
||||
SECRET_KEY="a" * 32,
|
||||
FIRST_SUPERUSER_PASSWORD="ALLUPPERCASE123"
|
||||
)
|
||||
Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD="ALLUPPERCASE123")
|
||||
|
||||
assert "must contain lowercase, uppercase, and digits" in str(exc_info.value)
|
||||
|
||||
def test_password_without_uppercase_rejected(self):
|
||||
"""Test that password without uppercase is rejected"""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
Settings(
|
||||
SECRET_KEY="a" * 32,
|
||||
FIRST_SUPERUSER_PASSWORD="alllowercase123"
|
||||
)
|
||||
Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD="alllowercase123")
|
||||
|
||||
assert "must contain lowercase, uppercase, and digits" in str(exc_info.value)
|
||||
|
||||
def test_password_without_digit_rejected(self):
|
||||
"""Test that password without digit is rejected"""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
Settings(
|
||||
SECRET_KEY="a" * 32,
|
||||
FIRST_SUPERUSER_PASSWORD="NoDigitsHere"
|
||||
)
|
||||
Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD="NoDigitsHere")
|
||||
|
||||
assert "must contain lowercase, uppercase, and digits" in str(exc_info.value)
|
||||
|
||||
@@ -110,8 +97,7 @@ class TestSuperuserPasswordValidation:
|
||||
"""Test that strong password is accepted"""
|
||||
strong_password = "StrongPassword123!"
|
||||
settings = Settings(
|
||||
SECRET_KEY="a" * 32,
|
||||
FIRST_SUPERUSER_PASSWORD=strong_password
|
||||
SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD=strong_password
|
||||
)
|
||||
|
||||
assert settings.FIRST_SUPERUSER_PASSWORD == strong_password
|
||||
@@ -150,7 +136,7 @@ class TestDatabaseConfiguration:
|
||||
POSTGRES_HOST="testhost",
|
||||
POSTGRES_PORT="5432",
|
||||
POSTGRES_DB="testdb",
|
||||
DATABASE_URL=None # Don't use explicit URL
|
||||
DATABASE_URL=None, # Don't use explicit URL
|
||||
)
|
||||
|
||||
expected_url = "postgresql://testuser:testpass@testhost:5432/testdb"
|
||||
@@ -159,10 +145,7 @@ class TestDatabaseConfiguration:
|
||||
def test_explicit_database_url_used_when_set(self):
|
||||
"""Test that explicit DATABASE_URL is used when provided"""
|
||||
explicit_url = "postgresql://explicit:pass@host:5432/db"
|
||||
settings = Settings(
|
||||
SECRET_KEY="a" * 32,
|
||||
DATABASE_URL=explicit_url
|
||||
)
|
||||
settings = Settings(SECRET_KEY="a" * 32, DATABASE_URL=explicit_url)
|
||||
|
||||
assert settings.database_url == explicit_url
|
||||
|
||||
|
||||
@@ -6,8 +6,10 @@ Critical security tests covering:
|
||||
|
||||
These tests prevent security misconfigurations.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
||||
@@ -43,6 +45,7 @@ class TestSecretKeySecurityValidation:
|
||||
# Import Settings class fresh (to pick up new env var)
|
||||
# The ValidationError should be raised during reload when Settings() is instantiated
|
||||
import importlib
|
||||
|
||||
from app.core import config
|
||||
|
||||
# Reload will raise ValidationError because Settings() is instantiated at module level
|
||||
@@ -58,7 +61,9 @@ class TestSecretKeySecurityValidation:
|
||||
|
||||
# Reload config to restore original settings
|
||||
import importlib
|
||||
|
||||
from app.core import config
|
||||
|
||||
importlib.reload(config)
|
||||
|
||||
def test_secret_key_exactly_32_characters_accepted(self):
|
||||
@@ -75,7 +80,9 @@ class TestSecretKeySecurityValidation:
|
||||
os.environ["SECRET_KEY"] = key_32
|
||||
|
||||
import importlib
|
||||
|
||||
from app.core import config
|
||||
|
||||
importlib.reload(config)
|
||||
|
||||
# Should work
|
||||
@@ -89,7 +96,9 @@ class TestSecretKeySecurityValidation:
|
||||
os.environ.pop("SECRET_KEY", None)
|
||||
|
||||
import importlib
|
||||
|
||||
from app.core import config
|
||||
|
||||
importlib.reload(config)
|
||||
|
||||
def test_secret_key_long_enough_accepted(self):
|
||||
@@ -106,7 +115,9 @@ class TestSecretKeySecurityValidation:
|
||||
os.environ["SECRET_KEY"] = key_64
|
||||
|
||||
import importlib
|
||||
|
||||
from app.core import config
|
||||
|
||||
importlib.reload(config)
|
||||
|
||||
# Should work
|
||||
@@ -120,7 +131,9 @@ class TestSecretKeySecurityValidation:
|
||||
os.environ.pop("SECRET_KEY", None)
|
||||
|
||||
import importlib
|
||||
|
||||
from app.core import config
|
||||
|
||||
importlib.reload(config)
|
||||
|
||||
def test_default_secret_key_meets_requirements(self):
|
||||
@@ -132,4 +145,6 @@ class TestSecretKeySecurityValidation:
|
||||
from app.core.config import settings
|
||||
|
||||
# Current settings should have valid SECRET_KEY
|
||||
assert len(settings.SECRET_KEY) >= 32, "Default SECRET_KEY must be at least 32 chars"
|
||||
assert len(settings.SECRET_KEY) >= 32, (
|
||||
"Default SECRET_KEY must be at least 32 chars"
|
||||
)
|
||||
|
||||
@@ -9,18 +9,19 @@ Covers:
|
||||
- init_async_db
|
||||
- close_async_db
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import (
|
||||
get_async_database_url,
|
||||
get_db,
|
||||
async_transaction_scope,
|
||||
check_async_database_health,
|
||||
init_async_db,
|
||||
close_async_db,
|
||||
get_async_database_url,
|
||||
get_db,
|
||||
init_async_db,
|
||||
)
|
||||
|
||||
|
||||
@@ -88,12 +89,13 @@ class TestAsyncTransactionScope:
|
||||
async def test_transaction_scope_commits_on_success(self, async_test_db):
|
||||
"""Test that successful operations are committed (covers line 138)."""
|
||||
# Mock the transaction scope to use test database
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
with patch('app.core.database.SessionLocal', SessionLocal):
|
||||
with patch("app.core.database.SessionLocal", SessionLocal):
|
||||
async with async_transaction_scope() as db:
|
||||
# Execute a simple query to verify transaction works
|
||||
from sqlalchemy import text
|
||||
|
||||
result = await db.execute(text("SELECT 1"))
|
||||
assert result is not None
|
||||
# Transaction should be committed (covers line 138 debug log)
|
||||
@@ -101,12 +103,13 @@ class TestAsyncTransactionScope:
|
||||
@pytest.mark.asyncio
|
||||
async def test_transaction_scope_rollback_on_error(self, async_test_db):
|
||||
"""Test that transaction rolls back on exception."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
with patch('app.core.database.SessionLocal', SessionLocal):
|
||||
with patch("app.core.database.SessionLocal", SessionLocal):
|
||||
with pytest.raises(RuntimeError, match="Test error"):
|
||||
async with async_transaction_scope() as db:
|
||||
from sqlalchemy import text
|
||||
|
||||
await db.execute(text("SELECT 1"))
|
||||
raise RuntimeError("Test error")
|
||||
|
||||
@@ -117,9 +120,9 @@ class TestCheckAsyncDatabaseHealth:
|
||||
@pytest.mark.asyncio
|
||||
async def test_database_health_check_success(self, async_test_db):
|
||||
"""Test health check returns True on success (covers line 156)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
with patch('app.core.database.SessionLocal', SessionLocal):
|
||||
with patch("app.core.database.SessionLocal", SessionLocal):
|
||||
result = await check_async_database_health()
|
||||
assert result is True
|
||||
|
||||
@@ -127,7 +130,7 @@ class TestCheckAsyncDatabaseHealth:
|
||||
async def test_database_health_check_failure(self):
|
||||
"""Test health check returns False on database error."""
|
||||
# Mock async_transaction_scope to raise an error
|
||||
with patch('app.core.database.async_transaction_scope') as mock_scope:
|
||||
with patch("app.core.database.async_transaction_scope") as mock_scope:
|
||||
mock_scope.side_effect = Exception("Database connection failed")
|
||||
|
||||
result = await check_async_database_health()
|
||||
@@ -140,10 +143,10 @@ class TestInitAsyncDb:
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_async_db_creates_tables(self, async_test_db):
|
||||
"""Test init_async_db creates tables (covers lines 174-176)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
test_engine, _SessionLocal = async_test_db
|
||||
|
||||
# Mock the engine to use test engine
|
||||
with patch('app.core.database.engine', test_engine):
|
||||
with patch("app.core.database.engine", test_engine):
|
||||
await init_async_db()
|
||||
# If no exception, tables were created successfully
|
||||
|
||||
@@ -155,7 +158,6 @@ class TestCloseAsyncDb:
|
||||
async def test_close_async_db_disposes_engine(self):
|
||||
"""Test close_async_db disposes engine (covers lines 185-186)."""
|
||||
# Create a fresh engine to test closing
|
||||
from app.core.database import engine
|
||||
|
||||
# Close connections
|
||||
await close_async_db()
|
||||
|
||||
@@ -2,14 +2,16 @@
|
||||
"""
|
||||
Comprehensive tests for CRUDBase class covering all error paths and edge cases.
|
||||
"""
|
||||
|
||||
from datetime import UTC
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from uuid import uuid4, UUID
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
|
||||
from sqlalchemy.exc import DataError, IntegrityError, OperationalError
|
||||
from sqlalchemy.orm import joinedload
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
from app.crud.user import user as user_crud
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
|
||||
|
||||
@@ -19,7 +21,7 @@ class TestCRUDBaseGet:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_with_invalid_uuid_string(self, async_test_db):
|
||||
"""Test get with invalid UUID string returns None."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.get(session, id="invalid-uuid")
|
||||
@@ -28,7 +30,7 @@ class TestCRUDBaseGet:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_with_invalid_uuid_type(self, async_test_db):
|
||||
"""Test get with invalid UUID type returns None."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.get(session, id=12345) # int instead of UUID
|
||||
@@ -37,7 +39,7 @@ class TestCRUDBaseGet:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_with_uuid_object(self, async_test_db, async_test_user):
|
||||
"""Test get with UUID object instead of string."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
# Pass UUID object directly
|
||||
@@ -48,26 +50,24 @@ class TestCRUDBaseGet:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_with_options(self, async_test_db, async_test_user):
|
||||
"""Test get with eager loading options (tests lines 76-78)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
# Test that options parameter is accepted and doesn't error
|
||||
# We pass an empty list which still tests the code path
|
||||
result = await user_crud.get(
|
||||
session,
|
||||
id=str(async_test_user.id),
|
||||
options=[]
|
||||
session, id=str(async_test_user.id), options=[]
|
||||
)
|
||||
assert result is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_database_error(self, async_test_db):
|
||||
"""Test get handles database errors properly."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
# Mock execute to raise an exception
|
||||
with patch.object(session, 'execute', side_effect=Exception("DB error")):
|
||||
with patch.object(session, "execute", side_effect=Exception("DB error")):
|
||||
with pytest.raises(Exception, match="DB error"):
|
||||
await user_crud.get(session, id=str(uuid4()))
|
||||
|
||||
@@ -78,7 +78,7 @@ class TestCRUDBaseGetMulti:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_negative_skip(self, async_test_db):
|
||||
"""Test get_multi with negative skip raises ValueError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="skip must be non-negative"):
|
||||
@@ -87,7 +87,7 @@ class TestCRUDBaseGetMulti:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_negative_limit(self, async_test_db):
|
||||
"""Test get_multi with negative limit raises ValueError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="limit must be non-negative"):
|
||||
@@ -96,7 +96,7 @@ class TestCRUDBaseGetMulti:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_limit_too_large(self, async_test_db):
|
||||
"""Test get_multi with limit > 1000 raises ValueError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="Maximum limit is 1000"):
|
||||
@@ -105,25 +105,20 @@ class TestCRUDBaseGetMulti:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_options(self, async_test_db, async_test_user):
|
||||
"""Test get_multi with eager loading options (tests lines 118-120)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
# Test that options parameter is accepted
|
||||
results = await user_crud.get_multi(
|
||||
session,
|
||||
skip=0,
|
||||
limit=10,
|
||||
options=[]
|
||||
)
|
||||
results = await user_crud.get_multi(session, skip=0, limit=10, options=[])
|
||||
assert isinstance(results, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_database_error(self, async_test_db):
|
||||
"""Test get_multi handles database errors."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with patch.object(session, 'execute', side_effect=Exception("DB error")):
|
||||
with patch.object(session, "execute", side_effect=Exception("DB error")):
|
||||
with pytest.raises(Exception, match="DB error"):
|
||||
await user_crud.get_multi(session)
|
||||
|
||||
@@ -134,7 +129,7 @@ class TestCRUDBaseCreate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_duplicate_unique_field(self, async_test_db, async_test_user):
|
||||
"""Test create with duplicate unique field raises ValueError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
# Try to create user with duplicate email
|
||||
@@ -142,7 +137,7 @@ class TestCRUDBaseCreate:
|
||||
email=async_test_user.email, # Duplicate!
|
||||
password="TestPassword123!",
|
||||
first_name="Test",
|
||||
last_name="Duplicate"
|
||||
last_name="Duplicate",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
@@ -151,22 +146,23 @@ class TestCRUDBaseCreate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_integrity_error_non_duplicate(self, async_test_db):
|
||||
"""Test create with non-duplicate IntegrityError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
# Mock commit to raise IntegrityError without "unique" in message
|
||||
original_commit = session.commit
|
||||
|
||||
async def mock_commit():
|
||||
error = IntegrityError("statement", {}, Exception("foreign key violation"))
|
||||
error = IntegrityError(
|
||||
"statement", {}, Exception("foreign key violation")
|
||||
)
|
||||
raise error
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
user_data = UserCreate(
|
||||
email="test@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Test",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Database integrity error"):
|
||||
@@ -175,15 +171,21 @@ class TestCRUDBaseCreate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_operational_error(self, async_test_db):
|
||||
"""Test create with OperationalError (user CRUD catches as generic Exception)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with patch.object(session, 'commit', side_effect=OperationalError("statement", {}, Exception("connection lost"))):
|
||||
with patch.object(
|
||||
session,
|
||||
"commit",
|
||||
side_effect=OperationalError(
|
||||
"statement", {}, Exception("connection lost")
|
||||
),
|
||||
):
|
||||
user_data = UserCreate(
|
||||
email="test@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Test",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
|
||||
# User CRUD catches this as generic Exception and re-raises
|
||||
@@ -193,15 +195,19 @@ class TestCRUDBaseCreate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_data_error(self, async_test_db):
|
||||
"""Test create with DataError (user CRUD catches as generic Exception)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with patch.object(session, 'commit', side_effect=DataError("statement", {}, Exception("invalid data"))):
|
||||
with patch.object(
|
||||
session,
|
||||
"commit",
|
||||
side_effect=DataError("statement", {}, Exception("invalid data")),
|
||||
):
|
||||
user_data = UserCreate(
|
||||
email="test@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Test",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
|
||||
# User CRUD catches this as generic Exception and re-raises
|
||||
@@ -211,15 +217,17 @@ class TestCRUDBaseCreate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_unexpected_error(self, async_test_db):
|
||||
"""Test create with unexpected exception."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with patch.object(session, 'commit', side_effect=RuntimeError("Unexpected error")):
|
||||
with patch.object(
|
||||
session, "commit", side_effect=RuntimeError("Unexpected error")
|
||||
):
|
||||
user_data = UserCreate(
|
||||
email="test@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Test",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Unexpected error"):
|
||||
@@ -232,16 +240,17 @@ class TestCRUDBaseUpdate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_duplicate_unique_field(self, async_test_db, async_test_user):
|
||||
"""Test update with duplicate unique field raises ValueError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create another user
|
||||
async with SessionLocal() as session:
|
||||
from app.crud.user import user as user_crud
|
||||
|
||||
user2_data = UserCreate(
|
||||
email="user2@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="User",
|
||||
last_name="Two"
|
||||
last_name="Two",
|
||||
)
|
||||
user2 = await user_crud.create(session, obj_in=user2_data)
|
||||
await session.commit()
|
||||
@@ -250,63 +259,89 @@ class TestCRUDBaseUpdate:
|
||||
async with SessionLocal() as session:
|
||||
user2_obj = await user_crud.get(session, id=str(user2.id))
|
||||
|
||||
with patch.object(session, 'commit', side_effect=IntegrityError("statement", {}, Exception("UNIQUE constraint failed"))):
|
||||
with patch.object(
|
||||
session,
|
||||
"commit",
|
||||
side_effect=IntegrityError(
|
||||
"statement", {}, Exception("UNIQUE constraint failed")
|
||||
),
|
||||
):
|
||||
update_data = UserUpdate(email=async_test_user.email)
|
||||
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
await user_crud.update(session, db_obj=user2_obj, obj_in=update_data)
|
||||
await user_crud.update(
|
||||
session, db_obj=user2_obj, obj_in=update_data
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_with_dict(self, async_test_db, async_test_user):
|
||||
"""Test update with dict instead of schema."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
|
||||
# Update with dict (tests lines 164-165)
|
||||
updated = await user_crud.update(
|
||||
session,
|
||||
db_obj=user,
|
||||
obj_in={"first_name": "UpdatedName"}
|
||||
session, db_obj=user, obj_in={"first_name": "UpdatedName"}
|
||||
)
|
||||
assert updated.first_name == "UpdatedName"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_integrity_error(self, async_test_db, async_test_user):
|
||||
"""Test update with IntegrityError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
|
||||
with patch.object(session, 'commit', side_effect=IntegrityError("statement", {}, Exception("constraint failed"))):
|
||||
with patch.object(
|
||||
session,
|
||||
"commit",
|
||||
side_effect=IntegrityError(
|
||||
"statement", {}, Exception("constraint failed")
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError, match="Database integrity error"):
|
||||
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Test"})
|
||||
await user_crud.update(
|
||||
session, db_obj=user, obj_in={"first_name": "Test"}
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_operational_error(self, async_test_db, async_test_user):
|
||||
"""Test update with OperationalError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
|
||||
with patch.object(session, 'commit', side_effect=OperationalError("statement", {}, Exception("connection error"))):
|
||||
with patch.object(
|
||||
session,
|
||||
"commit",
|
||||
side_effect=OperationalError(
|
||||
"statement", {}, Exception("connection error")
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError, match="Database operation failed"):
|
||||
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Test"})
|
||||
await user_crud.update(
|
||||
session, db_obj=user, obj_in={"first_name": "Test"}
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_unexpected_error(self, async_test_db, async_test_user):
|
||||
"""Test update with unexpected error."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
|
||||
with patch.object(session, 'commit', side_effect=RuntimeError("Unexpected")):
|
||||
with patch.object(
|
||||
session, "commit", side_effect=RuntimeError("Unexpected")
|
||||
):
|
||||
with pytest.raises(RuntimeError):
|
||||
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Test"})
|
||||
await user_crud.update(
|
||||
session, db_obj=user, obj_in={"first_name": "Test"}
|
||||
)
|
||||
|
||||
|
||||
class TestCRUDBaseRemove:
|
||||
@@ -315,7 +350,7 @@ class TestCRUDBaseRemove:
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_invalid_uuid(self, async_test_db):
|
||||
"""Test remove with invalid UUID returns None."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.remove(session, id="invalid-uuid")
|
||||
@@ -324,7 +359,7 @@ class TestCRUDBaseRemove:
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_with_uuid_object(self, async_test_db, async_test_user):
|
||||
"""Test remove with UUID object."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create a user to delete
|
||||
async with SessionLocal() as session:
|
||||
@@ -332,7 +367,7 @@ class TestCRUDBaseRemove:
|
||||
email="todelete@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="To",
|
||||
last_name="Delete"
|
||||
last_name="Delete",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
@@ -347,7 +382,7 @@ class TestCRUDBaseRemove:
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_nonexistent(self, async_test_db):
|
||||
"""Test remove of nonexistent record returns None."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.remove(session, id=str(uuid4()))
|
||||
@@ -356,21 +391,31 @@ class TestCRUDBaseRemove:
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_integrity_error(self, async_test_db, async_test_user):
|
||||
"""Test remove with IntegrityError (foreign key constraint)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
# Mock delete to raise IntegrityError
|
||||
with patch.object(session, 'commit', side_effect=IntegrityError("statement", {}, Exception("FOREIGN KEY constraint"))):
|
||||
with pytest.raises(ValueError, match="Cannot delete.*referenced by other records"):
|
||||
with patch.object(
|
||||
session,
|
||||
"commit",
|
||||
side_effect=IntegrityError(
|
||||
"statement", {}, Exception("FOREIGN KEY constraint")
|
||||
),
|
||||
):
|
||||
with pytest.raises(
|
||||
ValueError, match="Cannot delete.*referenced by other records"
|
||||
):
|
||||
await user_crud.remove(session, id=str(async_test_user.id))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_unexpected_error(self, async_test_db, async_test_user):
|
||||
"""Test remove with unexpected error."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with patch.object(session, 'commit', side_effect=RuntimeError("Unexpected")):
|
||||
with patch.object(
|
||||
session, "commit", side_effect=RuntimeError("Unexpected")
|
||||
):
|
||||
with pytest.raises(RuntimeError):
|
||||
await user_crud.remove(session, id=str(async_test_user.id))
|
||||
|
||||
@@ -381,10 +426,12 @@ class TestCRUDBaseGetMultiWithTotal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_basic(self, async_test_db, async_test_user):
|
||||
"""Test get_multi_with_total basic functionality."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
items, total = await user_crud.get_multi_with_total(session, skip=0, limit=10)
|
||||
items, total = await user_crud.get_multi_with_total(
|
||||
session, skip=0, limit=10
|
||||
)
|
||||
assert isinstance(items, list)
|
||||
assert isinstance(total, int)
|
||||
assert total >= 1 # At least the test user
|
||||
@@ -392,7 +439,7 @@ class TestCRUDBaseGetMultiWithTotal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_negative_skip(self, async_test_db):
|
||||
"""Test get_multi_with_total with negative skip raises ValueError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="skip must be non-negative"):
|
||||
@@ -401,7 +448,7 @@ class TestCRUDBaseGetMultiWithTotal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_negative_limit(self, async_test_db):
|
||||
"""Test get_multi_with_total with negative limit raises ValueError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="limit must be non-negative"):
|
||||
@@ -410,28 +457,34 @@ class TestCRUDBaseGetMultiWithTotal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_limit_too_large(self, async_test_db):
|
||||
"""Test get_multi_with_total with limit > 1000 raises ValueError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="Maximum limit is 1000"):
|
||||
await user_crud.get_multi_with_total(session, limit=1001)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_with_filters(self, async_test_db, async_test_user):
|
||||
async def test_get_multi_with_total_with_filters(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test get_multi_with_total with filters."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
filters = {"email": async_test_user.email}
|
||||
items, total = await user_crud.get_multi_with_total(session, filters=filters)
|
||||
items, total = await user_crud.get_multi_with_total(
|
||||
session, filters=filters
|
||||
)
|
||||
assert total == 1
|
||||
assert len(items) == 1
|
||||
assert items[0].email == async_test_user.email
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_with_sorting_asc(self, async_test_db, async_test_user):
|
||||
async def test_get_multi_with_total_with_sorting_asc(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test get_multi_with_total with ascending sort."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create additional users
|
||||
async with SessionLocal() as session:
|
||||
@@ -439,13 +492,13 @@ class TestCRUDBaseGetMultiWithTotal:
|
||||
email="aaa@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="AAA",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
user_data2 = UserCreate(
|
||||
email="zzz@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="ZZZ",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
await user_crud.create(session, obj_in=user_data1)
|
||||
await user_crud.create(session, obj_in=user_data2)
|
||||
@@ -460,9 +513,11 @@ class TestCRUDBaseGetMultiWithTotal:
|
||||
assert items[0].email == "aaa@example.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_with_sorting_desc(self, async_test_db, async_test_user):
|
||||
async def test_get_multi_with_total_with_sorting_desc(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test get_multi_with_total with descending sort."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create additional users
|
||||
async with SessionLocal() as session:
|
||||
@@ -470,20 +525,20 @@ class TestCRUDBaseGetMultiWithTotal:
|
||||
email="bbb@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="BBB",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
user_data2 = UserCreate(
|
||||
email="ccc@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="CCC",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
await user_crud.create(session, obj_in=user_data1)
|
||||
await user_crud.create(session, obj_in=user_data2)
|
||||
await session.commit()
|
||||
|
||||
async with SessionLocal() as session:
|
||||
items, total = await user_crud.get_multi_with_total(
|
||||
items, _total = await user_crud.get_multi_with_total(
|
||||
session, sort_by="email", sort_order="desc", limit=1
|
||||
)
|
||||
assert len(items) == 1
|
||||
@@ -492,7 +547,7 @@ class TestCRUDBaseGetMultiWithTotal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_with_pagination(self, async_test_db):
|
||||
"""Test get_multi_with_total pagination works correctly."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create minimal users for pagination test (3 instead of 5)
|
||||
async with SessionLocal() as session:
|
||||
@@ -501,19 +556,23 @@ class TestCRUDBaseGetMultiWithTotal:
|
||||
email=f"user{i}@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name=f"User{i}",
|
||||
last_name="Test"
|
||||
last_name="Test",
|
||||
)
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
await session.commit()
|
||||
|
||||
async with SessionLocal() as session:
|
||||
# Get first page
|
||||
items1, total = await user_crud.get_multi_with_total(session, skip=0, limit=2)
|
||||
items1, total = await user_crud.get_multi_with_total(
|
||||
session, skip=0, limit=2
|
||||
)
|
||||
assert len(items1) == 2
|
||||
assert total >= 3
|
||||
|
||||
# Get second page
|
||||
items2, total2 = await user_crud.get_multi_with_total(session, skip=2, limit=2)
|
||||
items2, total2 = await user_crud.get_multi_with_total(
|
||||
session, skip=2, limit=2
|
||||
)
|
||||
assert len(items2) >= 1
|
||||
assert total2 == total
|
||||
|
||||
@@ -529,7 +588,7 @@ class TestCRUDBaseCount:
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_basic(self, async_test_db, async_test_user):
|
||||
"""Test count returns correct number."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
count = await user_crud.count(session)
|
||||
@@ -539,7 +598,7 @@ class TestCRUDBaseCount:
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_multiple_users(self, async_test_db, async_test_user):
|
||||
"""Test count with multiple users."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create additional users
|
||||
async with SessionLocal() as session:
|
||||
@@ -549,13 +608,13 @@ class TestCRUDBaseCount:
|
||||
email="count1@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Count",
|
||||
last_name="One"
|
||||
last_name="One",
|
||||
)
|
||||
user_data2 = UserCreate(
|
||||
email="count2@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Count",
|
||||
last_name="Two"
|
||||
last_name="Two",
|
||||
)
|
||||
await user_crud.create(session, obj_in=user_data1)
|
||||
await user_crud.create(session, obj_in=user_data2)
|
||||
@@ -568,10 +627,10 @@ class TestCRUDBaseCount:
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_database_error(self, async_test_db):
|
||||
"""Test count handles database errors."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with patch.object(session, 'execute', side_effect=Exception("DB error")):
|
||||
with patch.object(session, "execute", side_effect=Exception("DB error")):
|
||||
with pytest.raises(Exception, match="DB error"):
|
||||
await user_crud.count(session)
|
||||
|
||||
@@ -582,7 +641,7 @@ class TestCRUDBaseExists:
|
||||
@pytest.mark.asyncio
|
||||
async def test_exists_true(self, async_test_db, async_test_user):
|
||||
"""Test exists returns True for existing record."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.exists(session, id=str(async_test_user.id))
|
||||
@@ -591,7 +650,7 @@ class TestCRUDBaseExists:
|
||||
@pytest.mark.asyncio
|
||||
async def test_exists_false(self, async_test_db):
|
||||
"""Test exists returns False for non-existent record."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.exists(session, id=str(uuid4()))
|
||||
@@ -600,7 +659,7 @@ class TestCRUDBaseExists:
|
||||
@pytest.mark.asyncio
|
||||
async def test_exists_invalid_uuid(self, async_test_db):
|
||||
"""Test exists returns False for invalid UUID."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.exists(session, id="invalid-uuid")
|
||||
@@ -613,7 +672,7 @@ class TestCRUDBaseSoftDelete:
|
||||
@pytest.mark.asyncio
|
||||
async def test_soft_delete_success(self, async_test_db):
|
||||
"""Test soft delete sets deleted_at timestamp."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create a user to soft delete
|
||||
async with SessionLocal() as session:
|
||||
@@ -621,7 +680,7 @@ class TestCRUDBaseSoftDelete:
|
||||
email="softdelete@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Soft",
|
||||
last_name="Delete"
|
||||
last_name="Delete",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
@@ -636,7 +695,7 @@ class TestCRUDBaseSoftDelete:
|
||||
@pytest.mark.asyncio
|
||||
async def test_soft_delete_invalid_uuid(self, async_test_db):
|
||||
"""Test soft delete with invalid UUID returns None."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.soft_delete(session, id="invalid-uuid")
|
||||
@@ -645,7 +704,7 @@ class TestCRUDBaseSoftDelete:
|
||||
@pytest.mark.asyncio
|
||||
async def test_soft_delete_nonexistent(self, async_test_db):
|
||||
"""Test soft delete of nonexistent record returns None."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.soft_delete(session, id=str(uuid4()))
|
||||
@@ -654,7 +713,7 @@ class TestCRUDBaseSoftDelete:
|
||||
@pytest.mark.asyncio
|
||||
async def test_soft_delete_with_uuid_object(self, async_test_db):
|
||||
"""Test soft delete with UUID object."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create a user to soft delete
|
||||
async with SessionLocal() as session:
|
||||
@@ -662,7 +721,7 @@ class TestCRUDBaseSoftDelete:
|
||||
email="softdelete2@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Soft",
|
||||
last_name="Delete2"
|
||||
last_name="Delete2",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
@@ -681,7 +740,7 @@ class TestCRUDBaseRestore:
|
||||
@pytest.mark.asyncio
|
||||
async def test_restore_success(self, async_test_db):
|
||||
"""Test restore clears deleted_at timestamp."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create and soft delete a user
|
||||
async with SessionLocal() as session:
|
||||
@@ -689,7 +748,7 @@ class TestCRUDBaseRestore:
|
||||
email="restore@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Restore",
|
||||
last_name="Test"
|
||||
last_name="Test",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
@@ -707,7 +766,7 @@ class TestCRUDBaseRestore:
|
||||
@pytest.mark.asyncio
|
||||
async def test_restore_invalid_uuid(self, async_test_db):
|
||||
"""Test restore with invalid UUID returns None."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.restore(session, id="invalid-uuid")
|
||||
@@ -716,7 +775,7 @@ class TestCRUDBaseRestore:
|
||||
@pytest.mark.asyncio
|
||||
async def test_restore_nonexistent(self, async_test_db):
|
||||
"""Test restore of nonexistent record returns None."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.restore(session, id=str(uuid4()))
|
||||
@@ -725,7 +784,7 @@ class TestCRUDBaseRestore:
|
||||
@pytest.mark.asyncio
|
||||
async def test_restore_not_deleted(self, async_test_db, async_test_user):
|
||||
"""Test restore of non-deleted record returns None."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
# Try to restore a user that's not deleted
|
||||
@@ -735,7 +794,7 @@ class TestCRUDBaseRestore:
|
||||
@pytest.mark.asyncio
|
||||
async def test_restore_with_uuid_object(self, async_test_db):
|
||||
"""Test restore with UUID object."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create and soft delete a user
|
||||
async with SessionLocal() as session:
|
||||
@@ -743,7 +802,7 @@ class TestCRUDBaseRestore:
|
||||
email="restore2@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Restore",
|
||||
last_name="Test2"
|
||||
last_name="Test2",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
@@ -765,7 +824,7 @@ class TestCRUDBasePaginationValidation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_negative_skip(self, async_test_db):
|
||||
"""Test that negative skip raises ValueError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="skip must be non-negative"):
|
||||
@@ -774,7 +833,7 @@ class TestCRUDBasePaginationValidation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_negative_limit(self, async_test_db):
|
||||
"""Test that negative limit raises ValueError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="limit must be non-negative"):
|
||||
@@ -783,23 +842,22 @@ class TestCRUDBasePaginationValidation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_limit_too_large(self, async_test_db):
|
||||
"""Test that limit > 1000 raises ValueError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="Maximum limit is 1000"):
|
||||
await user_crud.get_multi_with_total(session, skip=0, limit=1001)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_with_filters(self, async_test_db, async_test_user):
|
||||
async def test_get_multi_with_total_with_filters(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test pagination with filters (covers lines 270-273)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
users, total = await user_crud.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=10,
|
||||
filters={"is_active": True}
|
||||
session, skip=0, limit=10, filters={"is_active": True}
|
||||
)
|
||||
assert isinstance(users, list)
|
||||
assert total >= 0
|
||||
@@ -807,30 +865,22 @@ class TestCRUDBasePaginationValidation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_with_sorting_desc(self, async_test_db):
|
||||
"""Test pagination with descending sort (covers lines 283-284)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
users, total = await user_crud.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=10,
|
||||
sort_by="created_at",
|
||||
sort_order="desc"
|
||||
users, _total = await user_crud.get_multi_with_total(
|
||||
session, skip=0, limit=10, sort_by="created_at", sort_order="desc"
|
||||
)
|
||||
assert isinstance(users, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_with_sorting_asc(self, async_test_db):
|
||||
"""Test pagination with ascending sort (covers lines 285-286)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
users, total = await user_crud.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=10,
|
||||
sort_by="created_at",
|
||||
sort_order="asc"
|
||||
users, _total = await user_crud.get_multi_with_total(
|
||||
session, skip=0, limit=10, sort_by="created_at", sort_order="asc"
|
||||
)
|
||||
assert isinstance(users, list)
|
||||
|
||||
@@ -842,13 +892,15 @@ class TestCRUDBaseModelsWithoutSoftDelete:
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_soft_delete_model_without_deleted_at(self, async_test_db, async_test_user):
|
||||
async def test_soft_delete_model_without_deleted_at(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test soft_delete on Organization model (no deleted_at) raises ValueError (covers lines 342-343)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create an organization (which doesn't have deleted_at)
|
||||
from app.models.organization import Organization
|
||||
from app.crud.organization import organization as org_crud
|
||||
from app.models.organization import Organization
|
||||
|
||||
async with SessionLocal() as session:
|
||||
org = Organization(name="Test Org", slug="test-org")
|
||||
@@ -864,11 +916,11 @@ class TestCRUDBaseModelsWithoutSoftDelete:
|
||||
@pytest.mark.asyncio
|
||||
async def test_restore_model_without_deleted_at(self, async_test_db):
|
||||
"""Test restore on Organization model (no deleted_at) raises ValueError (covers lines 383-384)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create an organization (which doesn't have deleted_at)
|
||||
from app.models.organization import Organization
|
||||
from app.crud.organization import organization as org_crud
|
||||
from app.models.organization import Organization
|
||||
|
||||
async with SessionLocal() as session:
|
||||
org = Organization(name="Restore Test", slug="restore-test")
|
||||
@@ -889,14 +941,17 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_with_real_eager_loading_options(self, async_test_db, async_test_user):
|
||||
async def test_get_with_real_eager_loading_options(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test get() with actual eager loading options (covers lines 77-78)."""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
test_engine, SessionLocal = async_test_db
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create a session for the user
|
||||
from app.models.user_session import UserSession
|
||||
from app.crud.session import session as session_crud
|
||||
from app.models.user_session import UserSession
|
||||
|
||||
async with SessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
@@ -905,8 +960,8 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
|
||||
device_id="test-device",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Test Agent",
|
||||
last_used_at=datetime.now(timezone.utc),
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=60)
|
||||
last_used_at=datetime.now(UTC),
|
||||
expires_at=datetime.now(UTC) + timedelta(days=60),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
@@ -917,7 +972,7 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
|
||||
result = await session_crud.get(
|
||||
session,
|
||||
id=str(session_id),
|
||||
options=[joinedload(UserSession.user)] # Real option, not empty list
|
||||
options=[joinedload(UserSession.user)], # Real option, not empty list
|
||||
)
|
||||
assert result is not None
|
||||
assert result.id == session_id
|
||||
@@ -925,14 +980,17 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
|
||||
assert result.user.email == async_test_user.email
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_real_eager_loading_options(self, async_test_db, async_test_user):
|
||||
async def test_get_multi_with_real_eager_loading_options(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test get_multi() with actual eager loading options (covers lines 119-120)."""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
test_engine, SessionLocal = async_test_db
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create multiple sessions for the user
|
||||
from app.models.user_session import UserSession
|
||||
from app.crud.session import session as session_crud
|
||||
from app.models.user_session import UserSession
|
||||
|
||||
async with SessionLocal() as session:
|
||||
for i in range(3):
|
||||
@@ -942,8 +1000,8 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
|
||||
device_id=f"device-{i}",
|
||||
ip_address=f"192.168.1.{i}",
|
||||
user_agent=f"Agent {i}",
|
||||
last_used_at=datetime.now(timezone.utc),
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=60)
|
||||
last_used_at=datetime.now(UTC),
|
||||
expires_at=datetime.now(UTC) + timedelta(days=60),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
@@ -954,7 +1012,7 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
|
||||
session,
|
||||
skip=0,
|
||||
limit=10,
|
||||
options=[joinedload(UserSession.user)] # Real option, not empty list
|
||||
options=[joinedload(UserSession.user)], # Real option, not empty list
|
||||
)
|
||||
assert len(results) >= 3
|
||||
# Verify we can access user without additional queries
|
||||
|
||||
@@ -3,13 +3,15 @@
|
||||
Comprehensive tests for base CRUD database failure scenarios.
|
||||
Tests exception handling, rollbacks, and error messages.
|
||||
"""
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.exc import DataError, OperationalError
|
||||
|
||||
from app.crud.user import user as user_crud
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
from app.schemas.users import UserCreate
|
||||
|
||||
|
||||
class TestBaseCRUDCreateFailures:
|
||||
@@ -18,19 +20,24 @@ class TestBaseCRUDCreateFailures:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_operational_error_triggers_rollback(self, async_test_db):
|
||||
"""Test that OperationalError triggers rollback (User CRUD catches as Exception)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
async def mock_commit():
|
||||
raise OperationalError("Connection lost", {}, Exception("DB connection failed"))
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
async def mock_commit():
|
||||
raise OperationalError(
|
||||
"Connection lost", {}, Exception("DB connection failed")
|
||||
)
|
||||
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
user_data = UserCreate(
|
||||
email="operror@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Test",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
|
||||
# User CRUD catches this as generic Exception and re-raises
|
||||
@@ -43,19 +50,22 @@ class TestBaseCRUDCreateFailures:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_data_error_triggers_rollback(self, async_test_db):
|
||||
"""Test that DataError triggers rollback (User CRUD catches as Exception)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_commit():
|
||||
raise DataError("Invalid data type", {}, Exception("Data overflow"))
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
user_data = UserCreate(
|
||||
email="dataerror@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Test",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
|
||||
# User CRUD catches this as generic Exception and re-raises
|
||||
@@ -67,19 +77,22 @@ class TestBaseCRUDCreateFailures:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_unexpected_exception_triggers_rollback(self, async_test_db):
|
||||
"""Test that unexpected exceptions trigger rollback and re-raise."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_commit():
|
||||
raise RuntimeError("Unexpected database error")
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
user_data = UserCreate(
|
||||
email="unexpected@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Test",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Unexpected database error"):
|
||||
@@ -94,7 +107,7 @@ class TestBaseCRUDUpdateFailures:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_operational_error(self, async_test_db, async_test_user):
|
||||
"""Test update with OperationalError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
@@ -102,17 +115,21 @@ class TestBaseCRUDUpdateFailures:
|
||||
async def mock_commit():
|
||||
raise OperationalError("Connection timeout", {}, Exception("Timeout"))
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(ValueError, match="Database operation failed"):
|
||||
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Updated"})
|
||||
await user_crud.update(
|
||||
session, db_obj=user, obj_in={"first_name": "Updated"}
|
||||
)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_data_error(self, async_test_db, async_test_user):
|
||||
"""Test update with DataError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
@@ -120,17 +137,21 @@ class TestBaseCRUDUpdateFailures:
|
||||
async def mock_commit():
|
||||
raise DataError("Invalid data", {}, Exception("Data type mismatch"))
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(ValueError, match="Database operation failed"):
|
||||
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Updated"})
|
||||
await user_crud.update(
|
||||
session, db_obj=user, obj_in={"first_name": "Updated"}
|
||||
)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_unexpected_error(self, async_test_db, async_test_user):
|
||||
"""Test update with unexpected error."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
@@ -138,10 +159,14 @@ class TestBaseCRUDUpdateFailures:
|
||||
async def mock_commit():
|
||||
raise KeyError("Unexpected error")
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(KeyError):
|
||||
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Updated"})
|
||||
await user_crud.update(
|
||||
session, db_obj=user, obj_in={"first_name": "Updated"}
|
||||
)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
@@ -150,16 +175,21 @@ class TestBaseCRUDRemoveFailures:
|
||||
"""Test base CRUD remove method exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_unexpected_error_triggers_rollback(self, async_test_db, async_test_user):
|
||||
async def test_remove_unexpected_error_triggers_rollback(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test that unexpected errors in remove trigger rollback."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_commit():
|
||||
raise RuntimeError("Database write failed")
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(RuntimeError, match="Database write failed"):
|
||||
await user_crud.remove(session, id=str(async_test_user.id))
|
||||
|
||||
@@ -172,16 +202,15 @@ class TestBaseCRUDGetMultiWithTotalFailures:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_database_error(self, async_test_db):
|
||||
"""Test get_multi_with_total handles database errors."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
# Mock execute to raise an error
|
||||
original_execute = session.execute
|
||||
|
||||
async def mock_execute(*args, **kwargs):
|
||||
raise OperationalError("Query failed", {}, Exception("Database error"))
|
||||
|
||||
with patch.object(session, 'execute', side_effect=mock_execute):
|
||||
with patch.object(session, "execute", side_effect=mock_execute):
|
||||
with pytest.raises(OperationalError):
|
||||
await user_crud.get_multi_with_total(session, skip=0, limit=10)
|
||||
|
||||
@@ -192,13 +221,14 @@ class TestBaseCRUDCountFailures:
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_database_error_propagates(self, async_test_db):
|
||||
"""Test count propagates database errors."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_execute(*args, **kwargs):
|
||||
raise OperationalError("Count failed", {}, Exception("DB error"))
|
||||
|
||||
with patch.object(session, 'execute', side_effect=mock_execute):
|
||||
with patch.object(session, "execute", side_effect=mock_execute):
|
||||
with pytest.raises(OperationalError):
|
||||
await user_crud.count(session)
|
||||
|
||||
@@ -207,16 +237,21 @@ class TestBaseCRUDSoftDeleteFailures:
|
||||
"""Test soft_delete method exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_soft_delete_unexpected_error_triggers_rollback(self, async_test_db, async_test_user):
|
||||
async def test_soft_delete_unexpected_error_triggers_rollback(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test soft_delete handles unexpected errors with rollback."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_commit():
|
||||
raise RuntimeError("Soft delete failed")
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(RuntimeError, match="Soft delete failed"):
|
||||
await user_crud.soft_delete(session, id=str(async_test_user.id))
|
||||
|
||||
@@ -229,7 +264,7 @@ class TestBaseCRUDRestoreFailures:
|
||||
@pytest.mark.asyncio
|
||||
async def test_restore_unexpected_error_triggers_rollback(self, async_test_db):
|
||||
"""Test restore handles unexpected errors with rollback."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# First create and soft delete a user
|
||||
async with SessionLocal() as session:
|
||||
@@ -237,7 +272,7 @@ class TestBaseCRUDRestoreFailures:
|
||||
email="restore_test@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Restore",
|
||||
last_name="Test"
|
||||
last_name="Test",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
@@ -248,11 +283,14 @@ class TestBaseCRUDRestoreFailures:
|
||||
|
||||
# Now test restore failure
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_commit():
|
||||
raise RuntimeError("Restore failed")
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(RuntimeError, match="Restore failed"):
|
||||
await user_crud.restore(session, id=str(user_id))
|
||||
|
||||
@@ -265,13 +303,14 @@ class TestBaseCRUDGetFailures:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_database_error_propagates(self, async_test_db):
|
||||
"""Test get propagates database errors."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_execute(*args, **kwargs):
|
||||
raise OperationalError("Get failed", {}, Exception("DB error"))
|
||||
|
||||
with patch.object(session, 'execute', side_effect=mock_execute):
|
||||
with patch.object(session, "execute", side_effect=mock_execute):
|
||||
with pytest.raises(OperationalError):
|
||||
await user_crud.get(session, id=str(uuid4()))
|
||||
|
||||
@@ -282,12 +321,13 @@ class TestBaseCRUDGetMultiFailures:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_database_error_propagates(self, async_test_db):
|
||||
"""Test get_multi propagates database errors."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_execute(*args, **kwargs):
|
||||
raise OperationalError("Query failed", {}, Exception("DB error"))
|
||||
|
||||
with patch.object(session, 'execute', side_effect=mock_execute):
|
||||
with patch.object(session, "execute", side_effect=mock_execute):
|
||||
with pytest.raises(OperationalError):
|
||||
await user_crud.get_multi(session, skip=0, limit=10)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -2,10 +2,12 @@
|
||||
"""
|
||||
Comprehensive tests for async session CRUD operations.
|
||||
"""
|
||||
import pytest
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.crud.session import session as session_crud
|
||||
from app.models.user_session import UserSession
|
||||
from app.schemas.sessions import SessionCreate
|
||||
@@ -17,7 +19,7 @@ class TestGetByJti:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_jti_success(self, async_test_db, async_test_user):
|
||||
"""Test getting session by JTI."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
@@ -27,8 +29,8 @@ class TestGetByJti:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
@@ -41,7 +43,7 @@ class TestGetByJti:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_jti_not_found(self, async_test_db):
|
||||
"""Test getting non-existent JTI returns None."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_crud.get_by_jti(session, jti="nonexistent")
|
||||
@@ -54,7 +56,7 @@ class TestGetActiveByJti:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_by_jti_success(self, async_test_db, async_test_user):
|
||||
"""Test getting active session by JTI."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
@@ -64,8 +66,8 @@ class TestGetActiveByJti:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
@@ -78,7 +80,7 @@ class TestGetActiveByJti:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_by_jti_inactive(self, async_test_db, async_test_user):
|
||||
"""Test getting inactive session by JTI returns None."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
@@ -88,8 +90,8 @@ class TestGetActiveByJti:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
@@ -105,7 +107,7 @@ class TestGetUserSessions:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_sessions_active_only(self, async_test_db, async_test_user):
|
||||
"""Test getting only active user sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
active = UserSession(
|
||||
@@ -115,8 +117,8 @@ class TestGetUserSessions:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
inactive = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
@@ -125,17 +127,15 @@ class TestGetUserSessions:
|
||||
ip_address="192.168.1.2",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add_all([active, inactive])
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
results = await session_crud.get_user_sessions(
|
||||
session,
|
||||
user_id=str(async_test_user.id),
|
||||
active_only=True
|
||||
session, user_id=str(async_test_user.id), active_only=True
|
||||
)
|
||||
assert len(results) == 1
|
||||
assert results[0].is_active is True
|
||||
@@ -143,7 +143,7 @@ class TestGetUserSessions:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_sessions_all(self, async_test_db, async_test_user):
|
||||
"""Test getting all user sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(3):
|
||||
@@ -154,17 +154,15 @@ class TestGetUserSessions:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=i % 2 == 0,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(sess)
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
results = await session_crud.get_user_sessions(
|
||||
session,
|
||||
user_id=str(async_test_user.id),
|
||||
active_only=False
|
||||
session, user_id=str(async_test_user.id), active_only=False
|
||||
)
|
||||
assert len(results) == 3
|
||||
|
||||
@@ -175,7 +173,7 @@ class TestCreateSession:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session_success(self, async_test_db, async_test_user):
|
||||
"""Test successfully creating a session_crud."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
session_data = SessionCreate(
|
||||
@@ -185,10 +183,10 @@ class TestCreateSession:
|
||||
device_id="device_123",
|
||||
ip_address="192.168.1.100",
|
||||
user_agent="Mozilla/5.0",
|
||||
last_used_at=datetime.now(timezone.utc),
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
location_city="San Francisco",
|
||||
location_country="USA"
|
||||
location_country="USA",
|
||||
)
|
||||
result = await session_crud.create_session(session, obj_in=session_data)
|
||||
|
||||
@@ -204,7 +202,7 @@ class TestDeactivate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_deactivate_success(self, async_test_db, async_test_user):
|
||||
"""Test successfully deactivating a session_crud."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
@@ -214,8 +212,8 @@ class TestDeactivate:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
@@ -229,7 +227,7 @@ class TestDeactivate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_deactivate_not_found(self, async_test_db):
|
||||
"""Test deactivating non-existent session returns None."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_crud.deactivate(session, session_id=str(uuid4()))
|
||||
@@ -240,9 +238,11 @@ class TestDeactivateAllUserSessions:
|
||||
"""Tests for deactivate_all_user_sessions method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deactivate_all_user_sessions_success(self, async_test_db, async_test_user):
|
||||
async def test_deactivate_all_user_sessions_success(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test deactivating all user sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Create minimal sessions for test (2 instead of 5)
|
||||
@@ -254,16 +254,15 @@ class TestDeactivateAllUserSessions:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(sess)
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await session_crud.deactivate_all_user_sessions(
|
||||
session,
|
||||
user_id=str(async_test_user.id)
|
||||
session, user_id=str(async_test_user.id)
|
||||
)
|
||||
assert count == 2
|
||||
|
||||
@@ -274,7 +273,7 @@ class TestUpdateLastUsed:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_last_used_success(self, async_test_db, async_test_user):
|
||||
"""Test updating last_used_at timestamp."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
@@ -284,8 +283,8 @@ class TestUpdateLastUsed:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC) - timedelta(hours=1),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
@@ -303,7 +302,7 @@ class TestGetUserSessionCount:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_session_count_success(self, async_test_db, async_test_user):
|
||||
"""Test getting user session count."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(3):
|
||||
@@ -314,28 +313,26 @@ class TestGetUserSessionCount:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(sess)
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await session_crud.get_user_session_count(
|
||||
session,
|
||||
user_id=str(async_test_user.id)
|
||||
session, user_id=str(async_test_user.id)
|
||||
)
|
||||
assert count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_session_count_empty(self, async_test_db):
|
||||
"""Test getting session count for user with no sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await session_crud.get_user_session_count(
|
||||
session,
|
||||
user_id=str(uuid4())
|
||||
session, user_id=str(uuid4())
|
||||
)
|
||||
assert count == 0
|
||||
|
||||
@@ -346,7 +343,7 @@ class TestUpdateRefreshToken:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_refresh_token_success(self, async_test_db, async_test_user):
|
||||
"""Test updating refresh token JTI and expiration."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
@@ -356,26 +353,34 @@ class TestUpdateRefreshToken:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC) - timedelta(hours=1),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
await session.refresh(user_session)
|
||||
|
||||
new_jti = "new_jti_123"
|
||||
new_expires = datetime.now(timezone.utc) + timedelta(days=14)
|
||||
new_expires = datetime.now(UTC) + timedelta(days=14)
|
||||
|
||||
result = await session_crud.update_refresh_token(
|
||||
session,
|
||||
session=user_session,
|
||||
new_jti=new_jti,
|
||||
new_expires_at=new_expires
|
||||
new_expires_at=new_expires,
|
||||
)
|
||||
|
||||
assert result.refresh_token_jti == new_jti
|
||||
# Compare timestamps ignoring timezone info
|
||||
assert abs((result.expires_at.replace(tzinfo=None) - new_expires.replace(tzinfo=None)).total_seconds()) < 1
|
||||
assert (
|
||||
abs(
|
||||
(
|
||||
result.expires_at.replace(tzinfo=None)
|
||||
- new_expires.replace(tzinfo=None)
|
||||
).total_seconds()
|
||||
)
|
||||
< 1
|
||||
)
|
||||
|
||||
|
||||
class TestCleanupExpired:
|
||||
@@ -384,7 +389,7 @@ class TestCleanupExpired:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_success(self, async_test_db, async_test_user):
|
||||
"""Test cleaning up old expired inactive sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create old expired inactive session
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -395,9 +400,9 @@ class TestCleanupExpired:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=5),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(days=35),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=35)
|
||||
expires_at=datetime.now(UTC) - timedelta(days=5),
|
||||
last_used_at=datetime.now(UTC) - timedelta(days=35),
|
||||
created_at=datetime.now(UTC) - timedelta(days=35),
|
||||
)
|
||||
session.add(old_session)
|
||||
await session.commit()
|
||||
@@ -410,7 +415,7 @@ class TestCleanupExpired:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_keeps_recent(self, async_test_db, async_test_user):
|
||||
"""Test that cleanup keeps recent expired sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create recent expired inactive session (less than keep_days old)
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -421,9 +426,9 @@ class TestCleanupExpired:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(hours=2),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=1)
|
||||
expires_at=datetime.now(UTC) - timedelta(hours=1),
|
||||
last_used_at=datetime.now(UTC) - timedelta(hours=2),
|
||||
created_at=datetime.now(UTC) - timedelta(days=1),
|
||||
)
|
||||
session.add(recent_session)
|
||||
await session.commit()
|
||||
@@ -436,7 +441,7 @@ class TestCleanupExpired:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_keeps_active(self, async_test_db, async_test_user):
|
||||
"""Test that cleanup does not delete active sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create old expired but ACTIVE session
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -447,9 +452,9 @@ class TestCleanupExpired:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True, # Active
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=5),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(days=35),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=35)
|
||||
expires_at=datetime.now(UTC) - timedelta(days=5),
|
||||
last_used_at=datetime.now(UTC) - timedelta(days=35),
|
||||
created_at=datetime.now(UTC) - timedelta(days=35),
|
||||
)
|
||||
session.add(active_session)
|
||||
await session.commit()
|
||||
@@ -464,9 +469,11 @@ class TestCleanupExpiredForUser:
|
||||
"""Tests for cleanup_expired_for_user method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_for_user_success(self, async_test_db, async_test_user):
|
||||
async def test_cleanup_expired_for_user_success(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test cleaning up expired sessions for specific user."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create expired inactive session for user
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -477,8 +484,8 @@ class TestCleanupExpiredForUser:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(days=2)
|
||||
expires_at=datetime.now(UTC) - timedelta(days=1),
|
||||
last_used_at=datetime.now(UTC) - timedelta(days=2),
|
||||
)
|
||||
session.add(expired_session)
|
||||
await session.commit()
|
||||
@@ -486,27 +493,27 @@ class TestCleanupExpiredForUser:
|
||||
# Cleanup for user
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await session_crud.cleanup_expired_for_user(
|
||||
session,
|
||||
user_id=str(async_test_user.id)
|
||||
session, user_id=str(async_test_user.id)
|
||||
)
|
||||
assert count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_for_user_invalid_uuid(self, async_test_db):
|
||||
"""Test cleanup with invalid user UUID."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="Invalid user ID format"):
|
||||
await session_crud.cleanup_expired_for_user(
|
||||
session,
|
||||
user_id="not-a-valid-uuid"
|
||||
session, user_id="not-a-valid-uuid"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_for_user_keeps_active(self, async_test_db, async_test_user):
|
||||
async def test_cleanup_expired_for_user_keeps_active(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test that cleanup for user keeps active sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create expired but active session
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -517,8 +524,8 @@ class TestCleanupExpiredForUser:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True, # Active
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(days=2)
|
||||
expires_at=datetime.now(UTC) - timedelta(days=1),
|
||||
last_used_at=datetime.now(UTC) - timedelta(days=2),
|
||||
)
|
||||
session.add(active_session)
|
||||
await session.commit()
|
||||
@@ -526,8 +533,7 @@ class TestCleanupExpiredForUser:
|
||||
# Cleanup
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await session_crud.cleanup_expired_for_user(
|
||||
session,
|
||||
user_id=str(async_test_user.id)
|
||||
session, user_id=str(async_test_user.id)
|
||||
)
|
||||
assert count == 0 # Should not delete active sessions
|
||||
|
||||
@@ -536,9 +542,11 @@ class TestGetUserSessionsWithUser:
|
||||
"""Tests for get_user_sessions with eager loading."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_sessions_with_user_relationship(self, async_test_db, async_test_user):
|
||||
async def test_get_user_sessions_with_user_relationship(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test getting sessions with user relationship loaded."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
@@ -548,8 +556,8 @@ class TestGetUserSessionsWithUser:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
@@ -557,8 +565,6 @@ class TestGetUserSessionsWithUser:
|
||||
# Get with user relationship
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
results = await session_crud.get_user_sessions(
|
||||
session,
|
||||
user_id=str(async_test_user.id),
|
||||
with_user=True
|
||||
session, user_id=str(async_test_user.id), with_user=True
|
||||
)
|
||||
assert len(results) >= 1
|
||||
|
||||
@@ -2,12 +2,14 @@
|
||||
"""
|
||||
Comprehensive tests for session CRUD database failure scenarios.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from sqlalchemy.exc import OperationalError, IntegrityError
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
from app.crud.session import session as session_crud
|
||||
from app.models.user_session import UserSession
|
||||
from app.schemas.sessions import SessionCreate
|
||||
@@ -19,13 +21,14 @@ class TestSessionCRUDGetByJtiFailures:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_jti_database_error(self, async_test_db):
|
||||
"""Test get_by_jti handles database errors."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_execute(*args, **kwargs):
|
||||
raise OperationalError("DB connection lost", {}, Exception())
|
||||
|
||||
with patch.object(session, 'execute', side_effect=mock_execute):
|
||||
with patch.object(session, "execute", side_effect=mock_execute):
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.get_by_jti(session, jti="test_jti")
|
||||
|
||||
@@ -36,13 +39,14 @@ class TestSessionCRUDGetActiveByJtiFailures:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_by_jti_database_error(self, async_test_db):
|
||||
"""Test get_active_by_jti handles database errors."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_execute(*args, **kwargs):
|
||||
raise OperationalError("Query timeout", {}, Exception())
|
||||
|
||||
with patch.object(session, 'execute', side_effect=mock_execute):
|
||||
with patch.object(session, "execute", side_effect=mock_execute):
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.get_active_by_jti(session, jti="test_jti")
|
||||
|
||||
@@ -51,19 +55,21 @@ class TestSessionCRUDGetUserSessionsFailures:
|
||||
"""Test get_user_sessions exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_sessions_database_error(self, async_test_db, async_test_user):
|
||||
async def test_get_user_sessions_database_error(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test get_user_sessions handles database errors."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_execute(*args, **kwargs):
|
||||
raise OperationalError("Database error", {}, Exception())
|
||||
|
||||
with patch.object(session, 'execute', side_effect=mock_execute):
|
||||
with patch.object(session, "execute", side_effect=mock_execute):
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.get_user_sessions(
|
||||
session,
|
||||
user_id=str(async_test_user.id)
|
||||
session, user_id=str(async_test_user.id)
|
||||
)
|
||||
|
||||
|
||||
@@ -71,24 +77,29 @@ class TestSessionCRUDCreateSessionFailures:
|
||||
"""Test create_session exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
|
||||
async def test_create_session_commit_failure_triggers_rollback(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test create_session handles commit failures with rollback."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_commit():
|
||||
raise OperationalError("Commit failed", {}, Exception())
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
session_data = SessionCreate(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=str(uuid4()),
|
||||
device_name="Test Device",
|
||||
ip_address="127.0.0.1",
|
||||
user_agent="Test Agent",
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Failed to create session"):
|
||||
@@ -97,24 +108,29 @@ class TestSessionCRUDCreateSessionFailures:
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session_unexpected_error_triggers_rollback(self, async_test_db, async_test_user):
|
||||
async def test_create_session_unexpected_error_triggers_rollback(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test create_session handles unexpected errors."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_commit():
|
||||
raise RuntimeError("Unexpected error")
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
session_data = SessionCreate(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=str(uuid4()),
|
||||
device_name="Test Device",
|
||||
ip_address="127.0.0.1",
|
||||
user_agent="Test Agent",
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Failed to create session"):
|
||||
@@ -127,9 +143,11 @@ class TestSessionCRUDDeactivateFailures:
|
||||
"""Test deactivate exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deactivate_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
|
||||
async def test_deactivate_commit_failure_triggers_rollback(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test deactivate handles commit failures."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create a session first
|
||||
async with SessionLocal() as session:
|
||||
@@ -140,8 +158,8 @@ class TestSessionCRUDDeactivateFailures:
|
||||
ip_address="127.0.0.1",
|
||||
user_agent="Test Agent",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
@@ -150,13 +168,18 @@ class TestSessionCRUDDeactivateFailures:
|
||||
|
||||
# Test deactivate failure
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_commit():
|
||||
raise OperationalError("Deactivate failed", {}, Exception())
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.deactivate(session, session_id=str(session_id))
|
||||
await session_crud.deactivate(
|
||||
session, session_id=str(session_id)
|
||||
)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
@@ -165,20 +188,24 @@ class TestSessionCRUDDeactivateAllFailures:
|
||||
"""Test deactivate_all_user_sessions exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deactivate_all_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
|
||||
async def test_deactivate_all_commit_failure_triggers_rollback(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test deactivate_all handles commit failures."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_commit():
|
||||
raise OperationalError("Bulk deactivate failed", {}, Exception())
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.deactivate_all_user_sessions(
|
||||
session,
|
||||
user_id=str(async_test_user.id)
|
||||
session, user_id=str(async_test_user.id)
|
||||
)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
@@ -188,9 +215,11 @@ class TestSessionCRUDUpdateLastUsedFailures:
|
||||
"""Test update_last_used exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_last_used_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
|
||||
async def test_update_last_used_commit_failure_triggers_rollback(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test update_last_used handles commit failures."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create a session
|
||||
async with SessionLocal() as session:
|
||||
@@ -201,8 +230,8 @@ class TestSessionCRUDUpdateLastUsedFailures:
|
||||
ip_address="127.0.0.1",
|
||||
user_agent="Test Agent",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC) - timedelta(hours=1),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
@@ -211,15 +240,19 @@ class TestSessionCRUDUpdateLastUsedFailures:
|
||||
# Test update failure
|
||||
async with SessionLocal() as session:
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models.user_session import UserSession as US
|
||||
|
||||
result = await session.execute(select(US).where(US.id == user_session.id))
|
||||
sess = result.scalar_one()
|
||||
|
||||
async def mock_commit():
|
||||
raise OperationalError("Update failed", {}, Exception())
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.update_last_used(session, session=sess)
|
||||
|
||||
@@ -230,9 +263,11 @@ class TestSessionCRUDUpdateRefreshTokenFailures:
|
||||
"""Test update_refresh_token exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_refresh_token_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
|
||||
async def test_update_refresh_token_commit_failure_triggers_rollback(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test update_refresh_token handles commit failures."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create a session
|
||||
async with SessionLocal() as session:
|
||||
@@ -243,8 +278,8 @@ class TestSessionCRUDUpdateRefreshTokenFailures:
|
||||
ip_address="127.0.0.1",
|
||||
user_agent="Test Agent",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
@@ -253,21 +288,25 @@ class TestSessionCRUDUpdateRefreshTokenFailures:
|
||||
# Test update failure
|
||||
async with SessionLocal() as session:
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models.user_session import UserSession as US
|
||||
|
||||
result = await session.execute(select(US).where(US.id == user_session.id))
|
||||
sess = result.scalar_one()
|
||||
|
||||
async def mock_commit():
|
||||
raise OperationalError("Token update failed", {}, Exception())
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.update_refresh_token(
|
||||
session,
|
||||
session=sess,
|
||||
new_jti=str(uuid4()),
|
||||
new_expires_at=datetime.now(timezone.utc) + timedelta(days=14)
|
||||
new_expires_at=datetime.now(UTC) + timedelta(days=14),
|
||||
)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
@@ -277,16 +316,21 @@ class TestSessionCRUDCleanupExpiredFailures:
|
||||
"""Test cleanup_expired exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_commit_failure_triggers_rollback(self, async_test_db):
|
||||
async def test_cleanup_expired_commit_failure_triggers_rollback(
|
||||
self, async_test_db
|
||||
):
|
||||
"""Test cleanup_expired handles commit failures."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_commit():
|
||||
raise OperationalError("Cleanup failed", {}, Exception())
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.cleanup_expired(session, keep_days=30)
|
||||
|
||||
@@ -297,20 +341,24 @@ class TestSessionCRUDCleanupExpiredForUserFailures:
|
||||
"""Test cleanup_expired_for_user exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_for_user_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
|
||||
async def test_cleanup_expired_for_user_commit_failure_triggers_rollback(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test cleanup_expired_for_user handles commit failures."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_commit():
|
||||
raise OperationalError("User cleanup failed", {}, Exception())
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.cleanup_expired_for_user(
|
||||
session,
|
||||
user_id=str(async_test_user.id)
|
||||
session, user_id=str(async_test_user.id)
|
||||
)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
@@ -320,17 +368,19 @@ class TestSessionCRUDGetUserSessionCountFailures:
|
||||
"""Test get_user_session_count exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_session_count_database_error(self, async_test_db, async_test_user):
|
||||
async def test_get_user_session_count_database_error(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test get_user_session_count handles database errors."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_execute(*args, **kwargs):
|
||||
raise OperationalError("Count query failed", {}, Exception())
|
||||
|
||||
with patch.object(session, 'execute', side_effect=mock_execute):
|
||||
with patch.object(session, "execute", side_effect=mock_execute):
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.get_user_session_count(
|
||||
session,
|
||||
user_id=str(async_test_user.id)
|
||||
session, user_id=str(async_test_user.id)
|
||||
)
|
||||
|
||||
@@ -2,12 +2,10 @@
|
||||
"""
|
||||
Comprehensive tests for async user CRUD operations.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from app.crud.user import user as user_crud
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
|
||||
|
||||
@@ -17,7 +15,7 @@ class TestGetByEmail:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_email_success(self, async_test_db, async_test_user):
|
||||
"""Test getting user by email."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await user_crud.get_by_email(session, email=async_test_user.email)
|
||||
@@ -28,10 +26,12 @@ class TestGetByEmail:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_email_not_found(self, async_test_db):
|
||||
"""Test getting non-existent email returns None."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await user_crud.get_by_email(session, email="nonexistent@example.com")
|
||||
result = await user_crud.get_by_email(
|
||||
session, email="nonexistent@example.com"
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ class TestCreate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_success(self, async_test_db):
|
||||
"""Test successfully creating a user_crud."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
@@ -49,7 +49,7 @@ class TestCreate:
|
||||
password="SecurePass123!",
|
||||
first_name="New",
|
||||
last_name="User",
|
||||
phone_number="+1234567890"
|
||||
phone_number="+1234567890",
|
||||
)
|
||||
result = await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
@@ -65,7 +65,7 @@ class TestCreate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_superuser_success(self, async_test_db):
|
||||
"""Test creating a superuser."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
@@ -73,7 +73,7 @@ class TestCreate:
|
||||
password="SuperPass123!",
|
||||
first_name="Super",
|
||||
last_name="User",
|
||||
is_superuser=True
|
||||
is_superuser=True,
|
||||
)
|
||||
result = await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
@@ -83,14 +83,14 @@ class TestCreate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_duplicate_email_fails(self, async_test_db, async_test_user):
|
||||
"""Test creating user with duplicate email raises ValueError."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email=async_test_user.email, # Duplicate email
|
||||
password="AnotherPass123!",
|
||||
first_name="Duplicate",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
@@ -105,16 +105,14 @@ class TestUpdate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_basic_fields(self, async_test_db, async_test_user):
|
||||
"""Test updating basic user fields."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Get fresh copy of user
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
|
||||
update_data = UserUpdate(
|
||||
first_name="Updated",
|
||||
last_name="Name",
|
||||
phone_number="+9876543210"
|
||||
first_name="Updated", last_name="Name", phone_number="+9876543210"
|
||||
)
|
||||
result = await user_crud.update(session, db_obj=user, obj_in=update_data)
|
||||
|
||||
@@ -125,7 +123,7 @@ class TestUpdate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_password(self, async_test_db):
|
||||
"""Test updating user password."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create a fresh user for this test
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -133,7 +131,7 @@ class TestUpdate:
|
||||
email="passwordtest@example.com",
|
||||
password="OldPassword123!",
|
||||
first_name="Pass",
|
||||
last_name="Test"
|
||||
last_name="Test",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
@@ -149,12 +147,14 @@ class TestUpdate:
|
||||
await session.refresh(result)
|
||||
assert result.password_hash != old_password_hash
|
||||
assert result.password_hash is not None
|
||||
assert "NewDifferentPassword123!" not in result.password_hash # Should be hashed
|
||||
assert (
|
||||
"NewDifferentPassword123!" not in result.password_hash
|
||||
) # Should be hashed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_with_dict(self, async_test_db, async_test_user):
|
||||
"""Test updating user with dictionary."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
@@ -171,13 +171,11 @@ class TestGetMultiWithTotal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_basic(self, async_test_db, async_test_user):
|
||||
"""Test basic pagination."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
users, total = await user_crud.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=10
|
||||
session, skip=0, limit=10
|
||||
)
|
||||
assert total >= 1
|
||||
assert len(users) >= 1
|
||||
@@ -186,7 +184,7 @@ class TestGetMultiWithTotal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_sorting_asc(self, async_test_db):
|
||||
"""Test sorting in ascending order."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -195,17 +193,13 @@ class TestGetMultiWithTotal:
|
||||
email=f"sort{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"User{i}",
|
||||
last_name="Test"
|
||||
last_name="Test",
|
||||
)
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
users, total = await user_crud.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=10,
|
||||
sort_by="email",
|
||||
sort_order="asc"
|
||||
users, _total = await user_crud.get_multi_with_total(
|
||||
session, skip=0, limit=10, sort_by="email", sort_order="asc"
|
||||
)
|
||||
|
||||
# Check if sorted (at least the test users)
|
||||
@@ -216,7 +210,7 @@ class TestGetMultiWithTotal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_sorting_desc(self, async_test_db):
|
||||
"""Test sorting in descending order."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -225,17 +219,13 @@ class TestGetMultiWithTotal:
|
||||
email=f"desc{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"User{i}",
|
||||
last_name="Test"
|
||||
last_name="Test",
|
||||
)
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
users, total = await user_crud.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=10,
|
||||
sort_by="email",
|
||||
sort_order="desc"
|
||||
users, _total = await user_crud.get_multi_with_total(
|
||||
session, skip=0, limit=10, sort_by="email", sort_order="desc"
|
||||
)
|
||||
|
||||
# Check if sorted descending (at least the test users)
|
||||
@@ -246,7 +236,7 @@ class TestGetMultiWithTotal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_filtering(self, async_test_db):
|
||||
"""Test filtering by field."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create active and inactive users
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -254,7 +244,7 @@ class TestGetMultiWithTotal:
|
||||
email="active@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Active",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
await user_crud.create(session, obj_in=active_user)
|
||||
|
||||
@@ -262,23 +252,18 @@ class TestGetMultiWithTotal:
|
||||
email="inactive@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Inactive",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
created_inactive = await user_crud.create(session, obj_in=inactive_user)
|
||||
|
||||
# Deactivate the user
|
||||
await user_crud.update(
|
||||
session,
|
||||
db_obj=created_inactive,
|
||||
obj_in={"is_active": False}
|
||||
session, db_obj=created_inactive, obj_in={"is_active": False}
|
||||
)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
users, total = await user_crud.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=100,
|
||||
filters={"is_active": True}
|
||||
users, _total = await user_crud.get_multi_with_total(
|
||||
session, skip=0, limit=100, filters={"is_active": True}
|
||||
)
|
||||
|
||||
# All returned users should be active
|
||||
@@ -287,7 +272,7 @@ class TestGetMultiWithTotal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_search(self, async_test_db):
|
||||
"""Test search functionality."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create user with unique name
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -295,16 +280,13 @@ class TestGetMultiWithTotal:
|
||||
email="searchable@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Searchable",
|
||||
last_name="UserName"
|
||||
last_name="UserName",
|
||||
)
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
users, total = await user_crud.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=100,
|
||||
search="Searchable"
|
||||
session, skip=0, limit=100, search="Searchable"
|
||||
)
|
||||
|
||||
assert total >= 1
|
||||
@@ -313,7 +295,7 @@ class TestGetMultiWithTotal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_pagination(self, async_test_db):
|
||||
"""Test pagination with skip and limit."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -322,23 +304,19 @@ class TestGetMultiWithTotal:
|
||||
email=f"page{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"Page{i}",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Get first page
|
||||
users_page1, total = await user_crud.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=2
|
||||
session, skip=0, limit=2
|
||||
)
|
||||
|
||||
# Get second page
|
||||
users_page2, total2 = await user_crud.get_multi_with_total(
|
||||
session,
|
||||
skip=2,
|
||||
limit=2
|
||||
session, skip=2, limit=2
|
||||
)
|
||||
|
||||
# Total should be same
|
||||
@@ -349,7 +327,7 @@ class TestGetMultiWithTotal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_validation_negative_skip(self, async_test_db):
|
||||
"""Test validation fails for negative skip."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
@@ -360,7 +338,7 @@ class TestGetMultiWithTotal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_validation_negative_limit(self, async_test_db):
|
||||
"""Test validation fails for negative limit."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
@@ -371,7 +349,7 @@ class TestGetMultiWithTotal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_validation_max_limit(self, async_test_db):
|
||||
"""Test validation fails for limit > 1000."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
@@ -386,7 +364,7 @@ class TestBulkUpdateStatus:
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_update_status_success(self, async_test_db):
|
||||
"""Test bulk updating user status."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
user_ids = []
|
||||
@@ -396,7 +374,7 @@ class TestBulkUpdateStatus:
|
||||
email=f"bulk{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"Bulk{i}",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_ids.append(user.id)
|
||||
@@ -404,9 +382,7 @@ class TestBulkUpdateStatus:
|
||||
# Bulk deactivate
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_crud.bulk_update_status(
|
||||
session,
|
||||
user_ids=user_ids,
|
||||
is_active=False
|
||||
session, user_ids=user_ids, is_active=False
|
||||
)
|
||||
assert count == 3
|
||||
|
||||
@@ -419,20 +395,18 @@ class TestBulkUpdateStatus:
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_update_status_empty_list(self, async_test_db):
|
||||
"""Test bulk update with empty list returns 0."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_crud.bulk_update_status(
|
||||
session,
|
||||
user_ids=[],
|
||||
is_active=False
|
||||
session, user_ids=[], is_active=False
|
||||
)
|
||||
assert count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_update_status_reactivate(self, async_test_db):
|
||||
"""Test bulk reactivating users."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create inactive user
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -440,7 +414,7 @@ class TestBulkUpdateStatus:
|
||||
email="reactivate@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Reactivate",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
# Deactivate
|
||||
@@ -450,9 +424,7 @@ class TestBulkUpdateStatus:
|
||||
# Reactivate
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_crud.bulk_update_status(
|
||||
session,
|
||||
user_ids=[user_id],
|
||||
is_active=True
|
||||
session, user_ids=[user_id], is_active=True
|
||||
)
|
||||
assert count == 1
|
||||
|
||||
@@ -468,7 +440,7 @@ class TestBulkSoftDelete:
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_soft_delete_success(self, async_test_db):
|
||||
"""Test bulk soft deleting users."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
user_ids = []
|
||||
@@ -478,17 +450,14 @@ class TestBulkSoftDelete:
|
||||
email=f"delete{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"Delete{i}",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_ids.append(user.id)
|
||||
|
||||
# Bulk delete
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_crud.bulk_soft_delete(
|
||||
session,
|
||||
user_ids=user_ids
|
||||
)
|
||||
count = await user_crud.bulk_soft_delete(session, user_ids=user_ids)
|
||||
assert count == 3
|
||||
|
||||
# Verify all are soft deleted
|
||||
@@ -501,7 +470,7 @@ class TestBulkSoftDelete:
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_soft_delete_with_exclusion(self, async_test_db):
|
||||
"""Test bulk soft delete with excluded user_crud."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
user_ids = []
|
||||
@@ -511,7 +480,7 @@ class TestBulkSoftDelete:
|
||||
email=f"exclude{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"Exclude{i}",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_ids.append(user.id)
|
||||
@@ -520,9 +489,7 @@ class TestBulkSoftDelete:
|
||||
exclude_id = user_ids[0]
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_crud.bulk_soft_delete(
|
||||
session,
|
||||
user_ids=user_ids,
|
||||
exclude_user_id=exclude_id
|
||||
session, user_ids=user_ids, exclude_user_id=exclude_id
|
||||
)
|
||||
assert count == 2 # Only 2 deleted
|
||||
|
||||
@@ -534,19 +501,16 @@ class TestBulkSoftDelete:
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_soft_delete_empty_list(self, async_test_db):
|
||||
"""Test bulk delete with empty list returns 0."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_crud.bulk_soft_delete(
|
||||
session,
|
||||
user_ids=[]
|
||||
)
|
||||
count = await user_crud.bulk_soft_delete(session, user_ids=[])
|
||||
assert count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_soft_delete_all_excluded(self, async_test_db):
|
||||
"""Test bulk delete where all users are excluded."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create user
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -554,7 +518,7 @@ class TestBulkSoftDelete:
|
||||
email="onlyuser@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Only",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
@@ -562,16 +526,14 @@ class TestBulkSoftDelete:
|
||||
# Try to delete but exclude
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_crud.bulk_soft_delete(
|
||||
session,
|
||||
user_ids=[user_id],
|
||||
exclude_user_id=user_id
|
||||
session, user_ids=[user_id], exclude_user_id=user_id
|
||||
)
|
||||
assert count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_soft_delete_already_deleted(self, async_test_db):
|
||||
"""Test bulk delete doesn't re-delete already deleted users."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create and delete user
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -579,7 +541,7 @@ class TestBulkSoftDelete:
|
||||
email="predeleted@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="PreDeleted",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
@@ -589,10 +551,7 @@ class TestBulkSoftDelete:
|
||||
|
||||
# Try to delete again
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_crud.bulk_soft_delete(
|
||||
session,
|
||||
user_ids=[user_id]
|
||||
)
|
||||
count = await user_crud.bulk_soft_delete(session, user_ids=[user_id])
|
||||
assert count == 0 # Already deleted
|
||||
|
||||
|
||||
@@ -602,7 +561,7 @@ class TestUtilityMethods:
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_active_true(self, async_test_db, async_test_user):
|
||||
"""Test is_active returns True for active user_crud."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
@@ -611,14 +570,14 @@ class TestUtilityMethods:
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_active_false(self, async_test_db):
|
||||
"""Test is_active returns False for inactive user_crud."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="inactive2@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Inactive",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
await user_crud.update(session, db_obj=user, obj_in={"is_active": False})
|
||||
@@ -628,7 +587,7 @@ class TestUtilityMethods:
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_superuser_true(self, async_test_db, async_test_superuser):
|
||||
"""Test is_superuser returns True for superuser."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_superuser.id))
|
||||
@@ -637,7 +596,7 @@ class TestUtilityMethods:
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_superuser_false(self, async_test_db, async_test_user):
|
||||
"""Test is_superuser returns False for regular user_crud."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
@@ -654,42 +613,52 @@ class TestUserExceptionHandlers:
|
||||
async def test_get_by_email_database_error(self, async_test_db):
|
||||
"""Test get_by_email handles database errors (covers lines 30-32)."""
|
||||
from unittest.mock import patch
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with patch.object(session, 'execute', side_effect=Exception("Database query failed")):
|
||||
with patch.object(
|
||||
session, "execute", side_effect=Exception("Database query failed")
|
||||
):
|
||||
with pytest.raises(Exception, match="Database query failed"):
|
||||
await user_crud.get_by_email(session, email="test@example.com")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_update_status_database_error(self, async_test_db, async_test_user):
|
||||
async def test_bulk_update_status_database_error(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test bulk_update_status handles database errors (covers lines 205-208)."""
|
||||
from unittest.mock import patch, AsyncMock
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Mock execute to fail
|
||||
with patch.object(session, 'execute', side_effect=Exception("Bulk update failed")):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock):
|
||||
with patch.object(
|
||||
session, "execute", side_effect=Exception("Bulk update failed")
|
||||
):
|
||||
with patch.object(session, "rollback", new_callable=AsyncMock):
|
||||
with pytest.raises(Exception, match="Bulk update failed"):
|
||||
await user_crud.bulk_update_status(
|
||||
session,
|
||||
user_ids=[async_test_user.id],
|
||||
is_active=False
|
||||
session, user_ids=[async_test_user.id], is_active=False
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_soft_delete_database_error(self, async_test_db, async_test_user):
|
||||
async def test_bulk_soft_delete_database_error(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test bulk_soft_delete handles database errors (covers lines 257-260)."""
|
||||
from unittest.mock import patch, AsyncMock
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Mock execute to fail
|
||||
with patch.object(session, 'execute', side_effect=Exception("Bulk delete failed")):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock):
|
||||
with patch.object(
|
||||
session, "execute", side_effect=Exception("Bulk delete failed")
|
||||
):
|
||||
with patch.object(session, "rollback", new_callable=AsyncMock):
|
||||
with pytest.raises(Exception, match="Bulk delete failed"):
|
||||
await user_crud.bulk_soft_delete(
|
||||
session,
|
||||
user_ids=[async_test_user.id]
|
||||
session, user_ids=[async_test_user.id]
|
||||
)
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
# tests/models/test_user.py
|
||||
import uuid
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
@@ -166,7 +168,6 @@ def test_user_required_fields(db_session):
|
||||
db_session.rollback()
|
||||
|
||||
|
||||
|
||||
def test_user_defaults(db_session):
|
||||
"""Test that default values are correctly set."""
|
||||
# Arrange - Create a minimal user with only required fields
|
||||
@@ -210,22 +211,13 @@ def test_user_with_complex_json_preferences(db_session):
|
||||
"""Test storing and retrieving complex JSON preferences."""
|
||||
# Arrange - Create a user with nested JSON preferences
|
||||
complex_preferences = {
|
||||
"theme": {
|
||||
"mode": "dark",
|
||||
"colors": {
|
||||
"primary": "#333",
|
||||
"secondary": "#666"
|
||||
}
|
||||
},
|
||||
"theme": {"mode": "dark", "colors": {"primary": "#333", "secondary": "#666"}},
|
||||
"notifications": {
|
||||
"email": True,
|
||||
"sms": False,
|
||||
"push": {
|
||||
"enabled": True,
|
||||
"quiet_hours": [22, 7]
|
||||
}
|
||||
"push": {"enabled": True, "quiet_hours": [22, 7]},
|
||||
},
|
||||
"tags": ["important", "family", "events"]
|
||||
"tags": ["important", "family", "events"],
|
||||
}
|
||||
|
||||
user = User(
|
||||
@@ -234,16 +226,18 @@ def test_user_with_complex_json_preferences(db_session):
|
||||
password_hash="hashedpassword",
|
||||
first_name="Complex",
|
||||
last_name="JSON",
|
||||
preferences=complex_preferences
|
||||
preferences=complex_preferences,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
# Act - Retrieve the user
|
||||
retrieved_user = db_session.query(User).filter_by(email="complex@example.com").first()
|
||||
retrieved_user = (
|
||||
db_session.query(User).filter_by(email="complex@example.com").first()
|
||||
)
|
||||
|
||||
# Assert - The complex JSON should be preserved
|
||||
assert retrieved_user.preferences == complex_preferences
|
||||
assert retrieved_user.preferences["theme"]["colors"]["primary"] == "#333"
|
||||
assert retrieved_user.preferences["notifications"]["push"]["quiet_hours"] == [22, 7]
|
||||
assert "important" in retrieved_user.preferences["tags"]
|
||||
assert "important" in retrieved_user.preferences["tags"]
|
||||
|
||||
@@ -5,6 +5,7 @@ Covers Pydantic validators for:
|
||||
- Slug validation (lines 26, 28, 30, 32, 62-70)
|
||||
- Name validation (lines 40, 77)
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
@@ -20,19 +21,13 @@ class TestOrganizationBaseValidators:
|
||||
|
||||
def test_valid_organization_base(self):
|
||||
"""Test that valid data passes validation."""
|
||||
org = OrganizationBase(
|
||||
name="Test Organization",
|
||||
slug="test-org"
|
||||
)
|
||||
org = OrganizationBase(name="Test Organization", slug="test-org")
|
||||
assert org.name == "Test Organization"
|
||||
assert org.slug == "test-org"
|
||||
|
||||
def test_slug_none_returns_none(self):
|
||||
"""Test that None slug is allowed (covers line 26)."""
|
||||
org = OrganizationBase(
|
||||
name="Test Organization",
|
||||
slug=None
|
||||
)
|
||||
org = OrganizationBase(name="Test Organization", slug=None)
|
||||
assert org.slug is None
|
||||
|
||||
def test_slug_invalid_characters_rejected(self):
|
||||
@@ -40,57 +35,46 @@ class TestOrganizationBaseValidators:
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
OrganizationBase(
|
||||
name="Test Organization",
|
||||
slug="Test_Org!" # Uppercase and special chars
|
||||
slug="Test_Org!", # Uppercase and special chars
|
||||
)
|
||||
errors = exc_info.value.errors()
|
||||
assert any("lowercase letters, numbers, and hyphens" in str(e['msg']) for e in errors)
|
||||
assert any(
|
||||
"lowercase letters, numbers, and hyphens" in str(e["msg"]) for e in errors
|
||||
)
|
||||
|
||||
def test_slug_starts_with_hyphen_rejected(self):
|
||||
"""Test slug starting with hyphen is rejected (covers line 30)."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
OrganizationBase(
|
||||
name="Test Organization",
|
||||
slug="-test-org"
|
||||
)
|
||||
OrganizationBase(name="Test Organization", slug="-test-org")
|
||||
errors = exc_info.value.errors()
|
||||
assert any("cannot start or end with a hyphen" in str(e['msg']) for e in errors)
|
||||
assert any("cannot start or end with a hyphen" in str(e["msg"]) for e in errors)
|
||||
|
||||
def test_slug_ends_with_hyphen_rejected(self):
|
||||
"""Test slug ending with hyphen is rejected (covers line 30)."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
OrganizationBase(
|
||||
name="Test Organization",
|
||||
slug="test-org-"
|
||||
)
|
||||
OrganizationBase(name="Test Organization", slug="test-org-")
|
||||
errors = exc_info.value.errors()
|
||||
assert any("cannot start or end with a hyphen" in str(e['msg']) for e in errors)
|
||||
assert any("cannot start or end with a hyphen" in str(e["msg"]) for e in errors)
|
||||
|
||||
def test_slug_consecutive_hyphens_rejected(self):
|
||||
"""Test slug with consecutive hyphens is rejected (covers line 32)."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
OrganizationBase(
|
||||
name="Test Organization",
|
||||
slug="test--org"
|
||||
)
|
||||
OrganizationBase(name="Test Organization", slug="test--org")
|
||||
errors = exc_info.value.errors()
|
||||
assert any("cannot contain consecutive hyphens" in str(e['msg']) for e in errors)
|
||||
assert any(
|
||||
"cannot contain consecutive hyphens" in str(e["msg"]) for e in errors
|
||||
)
|
||||
|
||||
def test_name_whitespace_only_rejected(self):
|
||||
"""Test whitespace-only name is rejected (covers line 40)."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
OrganizationBase(
|
||||
name=" ",
|
||||
slug="test-org"
|
||||
)
|
||||
OrganizationBase(name=" ", slug="test-org")
|
||||
errors = exc_info.value.errors()
|
||||
assert any("name cannot be empty" in str(e['msg']) for e in errors)
|
||||
assert any("name cannot be empty" in str(e["msg"]) for e in errors)
|
||||
|
||||
def test_name_trimmed(self):
|
||||
"""Test that name is trimmed."""
|
||||
org = OrganizationBase(
|
||||
name=" Test Organization ",
|
||||
slug="test-org"
|
||||
)
|
||||
org = OrganizationBase(name=" Test Organization ", slug="test-org")
|
||||
assert org.name == "Test Organization"
|
||||
|
||||
|
||||
@@ -99,22 +83,18 @@ class TestOrganizationCreateValidators:
|
||||
|
||||
def test_valid_organization_create(self):
|
||||
"""Test that valid data passes validation."""
|
||||
org = OrganizationCreate(
|
||||
name="Test Organization",
|
||||
slug="test-org"
|
||||
)
|
||||
org = OrganizationCreate(name="Test Organization", slug="test-org")
|
||||
assert org.name == "Test Organization"
|
||||
assert org.slug == "test-org"
|
||||
|
||||
def test_slug_validation_inherited(self):
|
||||
"""Test that slug validation is inherited from base."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
OrganizationCreate(
|
||||
name="Test",
|
||||
slug="Invalid_Slug!"
|
||||
)
|
||||
OrganizationCreate(name="Test", slug="Invalid_Slug!")
|
||||
errors = exc_info.value.errors()
|
||||
assert any("lowercase letters, numbers, and hyphens" in str(e['msg']) for e in errors)
|
||||
assert any(
|
||||
"lowercase letters, numbers, and hyphens" in str(e["msg"]) for e in errors
|
||||
)
|
||||
|
||||
|
||||
class TestOrganizationUpdateValidators:
|
||||
@@ -122,10 +102,7 @@ class TestOrganizationUpdateValidators:
|
||||
|
||||
def test_valid_organization_update(self):
|
||||
"""Test that valid update data passes validation."""
|
||||
org = OrganizationUpdate(
|
||||
name="Updated Name",
|
||||
slug="updated-slug"
|
||||
)
|
||||
org = OrganizationUpdate(name="Updated Name", slug="updated-slug")
|
||||
assert org.name == "Updated Name"
|
||||
assert org.slug == "updated-slug"
|
||||
|
||||
@@ -139,35 +116,39 @@ class TestOrganizationUpdateValidators:
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
OrganizationUpdate(slug="Test_Org!")
|
||||
errors = exc_info.value.errors()
|
||||
assert any("lowercase letters, numbers, and hyphens" in str(e['msg']) for e in errors)
|
||||
assert any(
|
||||
"lowercase letters, numbers, and hyphens" in str(e["msg"]) for e in errors
|
||||
)
|
||||
|
||||
def test_update_slug_starts_with_hyphen_rejected(self):
|
||||
"""Test update slug starting with hyphen is rejected (covers line 66)."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
OrganizationUpdate(slug="-test-org")
|
||||
errors = exc_info.value.errors()
|
||||
assert any("cannot start or end with a hyphen" in str(e['msg']) for e in errors)
|
||||
assert any("cannot start or end with a hyphen" in str(e["msg"]) for e in errors)
|
||||
|
||||
def test_update_slug_ends_with_hyphen_rejected(self):
|
||||
"""Test update slug ending with hyphen is rejected (covers line 66)."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
OrganizationUpdate(slug="test-org-")
|
||||
errors = exc_info.value.errors()
|
||||
assert any("cannot start or end with a hyphen" in str(e['msg']) for e in errors)
|
||||
assert any("cannot start or end with a hyphen" in str(e["msg"]) for e in errors)
|
||||
|
||||
def test_update_slug_consecutive_hyphens_rejected(self):
|
||||
"""Test update slug with consecutive hyphens is rejected (covers line 68)."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
OrganizationUpdate(slug="test--org")
|
||||
errors = exc_info.value.errors()
|
||||
assert any("cannot contain consecutive hyphens" in str(e['msg']) for e in errors)
|
||||
assert any(
|
||||
"cannot contain consecutive hyphens" in str(e["msg"]) for e in errors
|
||||
)
|
||||
|
||||
def test_update_name_whitespace_only_rejected(self):
|
||||
"""Test whitespace-only name in update is rejected (covers line 77)."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
OrganizationUpdate(name=" ")
|
||||
errors = exc_info.value.errors()
|
||||
assert any("name cannot be empty" in str(e['msg']) for e in errors)
|
||||
assert any("name cannot be empty" in str(e["msg"]) for e in errors)
|
||||
|
||||
def test_update_name_none_allowed(self):
|
||||
"""Test that None name is allowed in update."""
|
||||
|
||||
@@ -1,80 +1,177 @@
|
||||
# tests/schemas/test_user_schemas.py
|
||||
import pytest
|
||||
import re
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.schemas.users import UserBase, UserCreate
|
||||
|
||||
|
||||
class TestPhoneNumberValidation:
|
||||
"""Tests for phone number validation in user schemas"""
|
||||
|
||||
def test_valid_swiss_numbers(self):
|
||||
"""Test valid Swiss phone numbers are accepted"""
|
||||
# International format
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41791234567")
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="+41791234567",
|
||||
)
|
||||
assert user.phone_number == "+41791234567"
|
||||
|
||||
# Local format
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0791234567")
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="0791234567",
|
||||
)
|
||||
assert user.phone_number == "0791234567"
|
||||
|
||||
# With formatting characters
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41 79 123 45 67")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+41791234567"
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="+41 79 123 45 67",
|
||||
)
|
||||
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+41791234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="079 123 45 67")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "0791234567"
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="079 123 45 67",
|
||||
)
|
||||
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "0791234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41-79-123-45-67")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+41791234567"
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="+41-79-123-45-67",
|
||||
)
|
||||
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+41791234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="079-123-45-67")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "0791234567"
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="079-123-45-67",
|
||||
)
|
||||
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "0791234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41 (79) 123 45 67")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+41791234567"
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="+41 (79) 123 45 67",
|
||||
)
|
||||
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+41791234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="079 (123) 45 67")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "0791234567"
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="079 (123) 45 67",
|
||||
)
|
||||
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "0791234567"
|
||||
|
||||
def test_valid_italian_numbers(self):
|
||||
"""Test valid Italian phone numbers are accepted"""
|
||||
# International format
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+393451234567")
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="+393451234567",
|
||||
)
|
||||
assert user.phone_number == "+393451234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39345123456")
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="+39345123456",
|
||||
)
|
||||
assert user.phone_number == "+39345123456"
|
||||
|
||||
# Local format
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="03451234567")
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="03451234567",
|
||||
)
|
||||
assert user.phone_number == "03451234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345123456789")
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="0345123456789",
|
||||
)
|
||||
assert user.phone_number == "0345123456789"
|
||||
|
||||
# With formatting characters
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39 345 123 4567")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+393451234567"
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="+39 345 123 4567",
|
||||
)
|
||||
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+393451234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345 123 4567")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "03451234567"
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="0345 123 4567",
|
||||
)
|
||||
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "03451234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39-345-123-4567")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+393451234567"
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="+39-345-123-4567",
|
||||
)
|
||||
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+393451234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345-123-4567")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "03451234567"
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="0345-123-4567",
|
||||
)
|
||||
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "03451234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39 (345) 123 4567")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+393451234567"
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="+39 (345) 123 4567",
|
||||
)
|
||||
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+393451234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345 (123) 4567")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "03451234567"
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="0345 (123) 4567",
|
||||
)
|
||||
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "03451234567"
|
||||
|
||||
def test_none_phone_number(self):
|
||||
"""Test that None is accepted as a valid value (optional phone number)"""
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number=None)
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number=None,
|
||||
)
|
||||
assert user.phone_number is None
|
||||
|
||||
def test_invalid_phone_numbers(self):
|
||||
@@ -83,17 +180,14 @@ class TestPhoneNumberValidation:
|
||||
# Too short
|
||||
"+12",
|
||||
"012",
|
||||
|
||||
# Invalid characters
|
||||
"+41xyz123456",
|
||||
"079abc4567",
|
||||
"123-abc-7890",
|
||||
"+1(800)CALL-NOW",
|
||||
|
||||
# Completely invalid formats
|
||||
"++4412345678", # Double plus
|
||||
# Note: "()+41123456" becomes "+41123456" after cleaning, which is valid
|
||||
|
||||
# Empty string
|
||||
"",
|
||||
# Spaces only
|
||||
@@ -102,7 +196,12 @@ class TestPhoneNumberValidation:
|
||||
|
||||
for number in invalid_numbers:
|
||||
with pytest.raises(ValidationError):
|
||||
UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number=number)
|
||||
UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number=number,
|
||||
)
|
||||
|
||||
def test_phone_validation_in_user_create(self):
|
||||
"""Test that phone validation also works in UserCreate schema"""
|
||||
@@ -112,7 +211,7 @@ class TestPhoneNumberValidation:
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
password="Password123!",
|
||||
phone_number="+41791234567"
|
||||
phone_number="+41791234567",
|
||||
)
|
||||
assert user.phone_number == "+41791234567"
|
||||
|
||||
@@ -123,5 +222,5 @@ class TestPhoneNumberValidation:
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
password="Password123!",
|
||||
phone_number="invalid-number"
|
||||
)
|
||||
phone_number="invalid-number",
|
||||
)
|
||||
|
||||
@@ -7,12 +7,13 @@ Covers all edge cases in validation functions:
|
||||
- validate_email_format (line 148)
|
||||
- validate_slug (lines 170-183)
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.schemas.validators import (
|
||||
validate_email_format,
|
||||
validate_password_strength,
|
||||
validate_phone_number,
|
||||
validate_email_format,
|
||||
validate_slug,
|
||||
)
|
||||
|
||||
@@ -108,12 +109,14 @@ class TestPhoneNumberValidator:
|
||||
validate_phone_number("+123456789012345") # 15 digits after +
|
||||
|
||||
def test_multiple_plus_symbols_rejected(self):
|
||||
"""Test phone number with multiple + symbols.
|
||||
r"""Test phone number with multiple + symbols.
|
||||
|
||||
Note: Line 115 is defensive code - the regex check at line 110 catches this first.
|
||||
The regex ^(?:\+[0-9]{8,14}|0[0-9]{8,14})$ only allows + at the start.
|
||||
"""
|
||||
with pytest.raises(ValueError, match="must start with \\+ or 0 followed by 8-14 digits"):
|
||||
with pytest.raises(
|
||||
ValueError, match="must start with \\+ or 0 followed by 8-14 digits"
|
||||
):
|
||||
validate_phone_number("+1234+5678901")
|
||||
|
||||
def test_non_digit_after_prefix_rejected(self):
|
||||
|
||||
@@ -1,14 +1,18 @@
|
||||
# tests/services/test_auth_service.py
|
||||
import uuid
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.auth import get_password_hash, verify_password, TokenExpiredError, TokenInvalidError
|
||||
from app.core.auth import (
|
||||
TokenInvalidError,
|
||||
get_password_hash,
|
||||
verify_password,
|
||||
)
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserCreate, Token
|
||||
from app.services.auth_service import AuthService, AuthenticationError
|
||||
from app.schemas.users import Token, UserCreate
|
||||
from app.services.auth_service import AuthenticationError, AuthService
|
||||
|
||||
|
||||
class TestAuthServiceAuthentication:
|
||||
@@ -17,12 +21,14 @@ class TestAuthServiceAuthentication:
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_valid_user(self, async_test_db, async_test_user):
|
||||
"""Test authenticating a user with valid credentials"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Set a known password for the mock user
|
||||
password = "TestPassword123!"
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
user.password_hash = get_password_hash(password)
|
||||
await session.commit()
|
||||
@@ -30,9 +36,7 @@ class TestAuthServiceAuthentication:
|
||||
# Authenticate with correct credentials
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
auth_user = await AuthService.authenticate_user(
|
||||
db=session,
|
||||
email=async_test_user.email,
|
||||
password=password
|
||||
db=session, email=async_test_user.email, password=password
|
||||
)
|
||||
|
||||
assert auth_user is not None
|
||||
@@ -42,26 +46,28 @@ class TestAuthServiceAuthentication:
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_nonexistent_user(self, async_test_db):
|
||||
"""Test authenticating with an email that doesn't exist"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await AuthService.authenticate_user(
|
||||
db=session,
|
||||
email="nonexistent@example.com",
|
||||
password="password"
|
||||
db=session, email="nonexistent@example.com", password="password"
|
||||
)
|
||||
|
||||
assert user is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_with_wrong_password(self, async_test_db, async_test_user):
|
||||
async def test_authenticate_with_wrong_password(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test authenticating with the wrong password"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Set a known password for the mock user
|
||||
password = "TestPassword123!"
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
user.password_hash = get_password_hash(password)
|
||||
await session.commit()
|
||||
@@ -69,9 +75,7 @@ class TestAuthServiceAuthentication:
|
||||
# Authenticate with wrong password
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
auth_user = await AuthService.authenticate_user(
|
||||
db=session,
|
||||
email=async_test_user.email,
|
||||
password="WrongPassword123"
|
||||
db=session, email=async_test_user.email, password="WrongPassword123"
|
||||
)
|
||||
|
||||
assert auth_user is None
|
||||
@@ -79,12 +83,14 @@ class TestAuthServiceAuthentication:
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_inactive_user(self, async_test_db, async_test_user):
|
||||
"""Test authenticating an inactive user"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Set a known password and make user inactive
|
||||
password = "TestPassword123!"
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
user.password_hash = get_password_hash(password)
|
||||
user.is_active = False
|
||||
@@ -94,9 +100,7 @@ class TestAuthServiceAuthentication:
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(AuthenticationError):
|
||||
await AuthService.authenticate_user(
|
||||
db=session,
|
||||
email=async_test_user.email,
|
||||
password=password
|
||||
db=session, email=async_test_user.email, password=password
|
||||
)
|
||||
|
||||
|
||||
@@ -106,14 +110,14 @@ class TestAuthServiceUserCreation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_new_user(self, async_test_db):
|
||||
"""Test creating a new user"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
user_data = UserCreate(
|
||||
email="newuser@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="New",
|
||||
last_name="User",
|
||||
phone_number="+1234567890"
|
||||
phone_number="+1234567890",
|
||||
)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -135,15 +139,17 @@ class TestAuthServiceUserCreation:
|
||||
assert user.is_superuser is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_with_existing_email(self, async_test_db, async_test_user):
|
||||
async def test_create_user_with_existing_email(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test creating a user with an email that already exists"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
user_data = UserCreate(
|
||||
email=async_test_user.email, # Use existing email
|
||||
password="TestPassword123!",
|
||||
first_name="Duplicate",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
|
||||
# Should raise AuthenticationError
|
||||
@@ -169,7 +175,7 @@ class TestAuthServiceTokens:
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_tokens(self, async_test_db, async_test_user):
|
||||
"""Test refreshing tokens with a valid refresh token"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create initial tokens
|
||||
initial_tokens = AuthService.create_tokens(async_test_user)
|
||||
@@ -177,8 +183,7 @@ class TestAuthServiceTokens:
|
||||
# Refresh tokens
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
new_tokens = await AuthService.refresh_tokens(
|
||||
db=session,
|
||||
refresh_token=initial_tokens.refresh_token
|
||||
db=session, refresh_token=initial_tokens.refresh_token
|
||||
)
|
||||
|
||||
# Verify new tokens are different from old ones
|
||||
@@ -188,7 +193,7 @@ class TestAuthServiceTokens:
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_tokens_with_invalid_token(self, async_test_db):
|
||||
"""Test refreshing tokens with an invalid token"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create an invalid token
|
||||
invalid_token = "invalid.token.string"
|
||||
@@ -197,14 +202,15 @@ class TestAuthServiceTokens:
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(TokenInvalidError):
|
||||
await AuthService.refresh_tokens(
|
||||
db=session,
|
||||
refresh_token=invalid_token
|
||||
db=session, refresh_token=invalid_token
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_tokens_with_access_token(self, async_test_db, async_test_user):
|
||||
async def test_refresh_tokens_with_access_token(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test refreshing tokens with an access token instead of refresh token"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create tokens
|
||||
tokens = AuthService.create_tokens(async_test_user)
|
||||
@@ -213,18 +219,20 @@ class TestAuthServiceTokens:
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(TokenInvalidError):
|
||||
await AuthService.refresh_tokens(
|
||||
db=session,
|
||||
refresh_token=tokens.access_token
|
||||
db=session, refresh_token=tokens.access_token
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_tokens_with_nonexistent_user(self, async_test_db):
|
||||
"""Test refreshing tokens for a user that doesn't exist in the database"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create a token for a non-existent user
|
||||
non_existent_id = str(uuid.uuid4())
|
||||
with patch('app.core.auth.decode_token'), patch('app.core.auth.get_token_data') as mock_get_data:
|
||||
with (
|
||||
patch("app.core.auth.decode_token"),
|
||||
patch("app.core.auth.get_token_data") as mock_get_data,
|
||||
):
|
||||
# Mock the token data to return a non-existent user ID
|
||||
mock_get_data.return_value.user_id = uuid.UUID(non_existent_id)
|
||||
|
||||
@@ -232,8 +240,7 @@ class TestAuthServiceTokens:
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(TokenInvalidError):
|
||||
await AuthService.refresh_tokens(
|
||||
db=session,
|
||||
refresh_token="some.refresh.token"
|
||||
db=session, refresh_token="some.refresh.token"
|
||||
)
|
||||
|
||||
|
||||
@@ -243,12 +250,14 @@ class TestAuthServicePasswordChange:
|
||||
@pytest.mark.asyncio
|
||||
async def test_change_password(self, async_test_db, async_test_user):
|
||||
"""Test changing a user's password"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Set a known password for the mock user
|
||||
current_password = "CurrentPassword123"
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
user.password_hash = get_password_hash(current_password)
|
||||
await session.commit()
|
||||
@@ -260,7 +269,7 @@ class TestAuthServicePasswordChange:
|
||||
db=session,
|
||||
user_id=async_test_user.id,
|
||||
current_password=current_password,
|
||||
new_password=new_password
|
||||
new_password=new_password,
|
||||
)
|
||||
|
||||
# Verify operation was successful
|
||||
@@ -268,7 +277,9 @@ class TestAuthServicePasswordChange:
|
||||
|
||||
# Verify password was changed
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
updated_user = result.scalar_one_or_none()
|
||||
|
||||
# Verify old password no longer works
|
||||
@@ -278,14 +289,18 @@ class TestAuthServicePasswordChange:
|
||||
assert verify_password(new_password, updated_user.password_hash)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_change_password_wrong_current_password(self, async_test_db, async_test_user):
|
||||
async def test_change_password_wrong_current_password(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test changing password with incorrect current password"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Set a known password for the mock user
|
||||
current_password = "CurrentPassword123"
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
user.password_hash = get_password_hash(current_password)
|
||||
await session.commit()
|
||||
@@ -298,19 +313,21 @@ class TestAuthServicePasswordChange:
|
||||
db=session,
|
||||
user_id=async_test_user.id,
|
||||
current_password=wrong_password,
|
||||
new_password="NewPassword456"
|
||||
new_password="NewPassword456",
|
||||
)
|
||||
|
||||
# Verify password was not changed
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
assert verify_password(current_password, user.password_hash)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_change_password_nonexistent_user(self, async_test_db):
|
||||
"""Test changing password for a user that doesn't exist"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
non_existent_id = uuid.uuid4()
|
||||
|
||||
@@ -320,5 +337,5 @@ class TestAuthServicePasswordChange:
|
||||
db=session,
|
||||
user_id=non_existent_id,
|
||||
current_password="CurrentPassword123",
|
||||
new_password="NewPassword456"
|
||||
new_password="NewPassword456",
|
||||
)
|
||||
|
||||
@@ -2,13 +2,15 @@
|
||||
"""
|
||||
Tests for email service functionality.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, AsyncMock, MagicMock
|
||||
|
||||
from app.services.email_service import (
|
||||
EmailService,
|
||||
ConsoleEmailBackend,
|
||||
SMTPEmailBackend
|
||||
EmailService,
|
||||
SMTPEmailBackend,
|
||||
)
|
||||
|
||||
|
||||
@@ -24,7 +26,7 @@ class TestConsoleEmailBackend:
|
||||
to=["user@example.com"],
|
||||
subject="Test Subject",
|
||||
html_content="<p>Test HTML</p>",
|
||||
text_content="Test Text"
|
||||
text_content="Test Text",
|
||||
)
|
||||
|
||||
assert result is True
|
||||
@@ -37,7 +39,7 @@ class TestConsoleEmailBackend:
|
||||
result = await backend.send_email(
|
||||
to=["user@example.com"],
|
||||
subject="Test Subject",
|
||||
html_content="<p>Test HTML</p>"
|
||||
html_content="<p>Test HTML</p>",
|
||||
)
|
||||
|
||||
assert result is True
|
||||
@@ -50,7 +52,7 @@ class TestConsoleEmailBackend:
|
||||
result = await backend.send_email(
|
||||
to=["user1@example.com", "user2@example.com"],
|
||||
subject="Test Subject",
|
||||
html_content="<p>Test HTML</p>"
|
||||
html_content="<p>Test HTML</p>",
|
||||
)
|
||||
|
||||
assert result is True
|
||||
@@ -66,7 +68,7 @@ class TestSMTPEmailBackend:
|
||||
host="smtp.example.com",
|
||||
port=587,
|
||||
username="test@example.com",
|
||||
password="password"
|
||||
password="password",
|
||||
)
|
||||
|
||||
assert backend.host == "smtp.example.com"
|
||||
@@ -81,14 +83,14 @@ class TestSMTPEmailBackend:
|
||||
host="smtp.example.com",
|
||||
port=587,
|
||||
username="test@example.com",
|
||||
password="password"
|
||||
password="password",
|
||||
)
|
||||
|
||||
# Should fall back to console backend since SMTP is not implemented
|
||||
result = await backend.send_email(
|
||||
to=["user@example.com"],
|
||||
subject="Test Subject",
|
||||
html_content="<p>Test HTML</p>"
|
||||
html_content="<p>Test HTML</p>",
|
||||
)
|
||||
|
||||
assert result is True
|
||||
@@ -114,9 +116,7 @@ class TestEmailService:
|
||||
service = EmailService()
|
||||
|
||||
result = await service.send_password_reset_email(
|
||||
to_email="user@example.com",
|
||||
reset_token="test_token_123",
|
||||
user_name="John"
|
||||
to_email="user@example.com", reset_token="test_token_123", user_name="John"
|
||||
)
|
||||
|
||||
assert result is True
|
||||
@@ -127,8 +127,7 @@ class TestEmailService:
|
||||
service = EmailService()
|
||||
|
||||
result = await service.send_password_reset_email(
|
||||
to_email="user@example.com",
|
||||
reset_token="test_token_123"
|
||||
to_email="user@example.com", reset_token="test_token_123"
|
||||
)
|
||||
|
||||
assert result is True
|
||||
@@ -142,8 +141,7 @@ class TestEmailService:
|
||||
|
||||
token = "test_reset_token_xyz"
|
||||
await service.send_password_reset_email(
|
||||
to_email="user@example.com",
|
||||
reset_token=token
|
||||
to_email="user@example.com", reset_token=token
|
||||
)
|
||||
|
||||
# Verify send_email was called
|
||||
@@ -151,7 +149,7 @@ class TestEmailService:
|
||||
call_args = backend_mock.send_email.call_args
|
||||
|
||||
# Check that token is in the HTML content
|
||||
html_content = call_args.kwargs['html_content']
|
||||
html_content = call_args.kwargs["html_content"]
|
||||
assert token in html_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -162,8 +160,7 @@ class TestEmailService:
|
||||
service = EmailService(backend=backend_mock)
|
||||
|
||||
result = await service.send_password_reset_email(
|
||||
to_email="user@example.com",
|
||||
reset_token="test_token"
|
||||
to_email="user@example.com", reset_token="test_token"
|
||||
)
|
||||
|
||||
assert result is False
|
||||
@@ -176,7 +173,7 @@ class TestEmailService:
|
||||
result = await service.send_email_verification(
|
||||
to_email="user@example.com",
|
||||
verification_token="verification_token_123",
|
||||
user_name="Jane"
|
||||
user_name="Jane",
|
||||
)
|
||||
|
||||
assert result is True
|
||||
@@ -187,8 +184,7 @@ class TestEmailService:
|
||||
service = EmailService()
|
||||
|
||||
result = await service.send_email_verification(
|
||||
to_email="user@example.com",
|
||||
verification_token="verification_token_123"
|
||||
to_email="user@example.com", verification_token="verification_token_123"
|
||||
)
|
||||
|
||||
assert result is True
|
||||
@@ -202,8 +198,7 @@ class TestEmailService:
|
||||
|
||||
token = "test_verification_token_xyz"
|
||||
await service.send_email_verification(
|
||||
to_email="user@example.com",
|
||||
verification_token=token
|
||||
to_email="user@example.com", verification_token=token
|
||||
)
|
||||
|
||||
# Verify send_email was called
|
||||
@@ -211,7 +206,7 @@ class TestEmailService:
|
||||
call_args = backend_mock.send_email.call_args
|
||||
|
||||
# Check that token is in the HTML content
|
||||
html_content = call_args.kwargs['html_content']
|
||||
html_content = call_args.kwargs["html_content"]
|
||||
assert token in html_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -222,8 +217,7 @@ class TestEmailService:
|
||||
service = EmailService(backend=backend_mock)
|
||||
|
||||
result = await service.send_email_verification(
|
||||
to_email="user@example.com",
|
||||
verification_token="test_token"
|
||||
to_email="user@example.com", verification_token="test_token"
|
||||
)
|
||||
|
||||
assert result is False
|
||||
@@ -236,14 +230,12 @@ class TestEmailService:
|
||||
service = EmailService(backend=backend_mock)
|
||||
|
||||
await service.send_password_reset_email(
|
||||
to_email="user@example.com",
|
||||
reset_token="token123",
|
||||
user_name="Test User"
|
||||
to_email="user@example.com", reset_token="token123", user_name="Test User"
|
||||
)
|
||||
|
||||
call_args = backend_mock.send_email.call_args
|
||||
html_content = call_args.kwargs['html_content']
|
||||
text_content = call_args.kwargs['text_content']
|
||||
html_content = call_args.kwargs["html_content"]
|
||||
text_content = call_args.kwargs["text_content"]
|
||||
|
||||
# Check HTML content
|
||||
assert "Password Reset" in html_content
|
||||
@@ -251,7 +243,9 @@ class TestEmailService:
|
||||
assert "Test User" in html_content
|
||||
|
||||
# Check text content
|
||||
assert "Password Reset" in text_content or "password reset" in text_content.lower()
|
||||
assert (
|
||||
"Password Reset" in text_content or "password reset" in text_content.lower()
|
||||
)
|
||||
assert "token123" in text_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -264,12 +258,12 @@ class TestEmailService:
|
||||
await service.send_email_verification(
|
||||
to_email="user@example.com",
|
||||
verification_token="verify123",
|
||||
user_name="Test User"
|
||||
user_name="Test User",
|
||||
)
|
||||
|
||||
call_args = backend_mock.send_email.call_args
|
||||
html_content = call_args.kwargs['html_content']
|
||||
text_content = call_args.kwargs['text_content']
|
||||
html_content = call_args.kwargs["html_content"]
|
||||
text_content = call_args.kwargs["text_content"]
|
||||
|
||||
# Check HTML content
|
||||
assert "Verify" in html_content
|
||||
|
||||
@@ -2,23 +2,27 @@
|
||||
"""
|
||||
Comprehensive tests for session cleanup service.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models.user_session import UserSession
|
||||
from sqlalchemy import select
|
||||
|
||||
|
||||
class TestCleanupExpiredSessions:
|
||||
"""Tests for cleanup_expired_sessions function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_sessions_success(self, async_test_db, async_test_user):
|
||||
async def test_cleanup_expired_sessions_success(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test successful cleanup of expired sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create mix of sessions
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -30,9 +34,9 @@ class TestCleanupExpiredSessions:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
created_at=datetime.now(UTC) - timedelta(days=1),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
# 2. Inactive, expired, old (SHOULD be deleted)
|
||||
@@ -43,9 +47,9 @@ class TestCleanupExpiredSessions:
|
||||
ip_address="192.168.1.2",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=10),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=40),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) - timedelta(days=10),
|
||||
created_at=datetime.now(UTC) - timedelta(days=40),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
# 3. Inactive, expired, recent (should NOT be deleted - within keep_days)
|
||||
@@ -56,17 +60,23 @@ class TestCleanupExpiredSessions:
|
||||
ip_address="192.168.1.3",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=5),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) - timedelta(days=1),
|
||||
created_at=datetime.now(UTC) - timedelta(days=5),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
session.add_all([active_session, old_expired_session, recent_expired_session])
|
||||
session.add_all(
|
||||
[active_session, old_expired_session, recent_expired_session]
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# Mock SessionLocal to return our test session
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
with patch(
|
||||
"app.services.session_cleanup.SessionLocal",
|
||||
return_value=AsyncTestingSessionLocal(),
|
||||
):
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
|
||||
deleted_count = await cleanup_expired_sessions(keep_days=30)
|
||||
|
||||
# Should only delete old_expired_session
|
||||
@@ -85,7 +95,7 @@ class TestCleanupExpiredSessions:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_no_sessions_to_delete(self, async_test_db, async_test_user):
|
||||
"""Test cleanup when no sessions meet deletion criteria."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
active = UserSession(
|
||||
@@ -95,15 +105,19 @@ class TestCleanupExpiredSessions:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
created_at=datetime.now(timezone.utc),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
created_at=datetime.now(UTC),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(active)
|
||||
await session.commit()
|
||||
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
with patch(
|
||||
"app.services.session_cleanup.SessionLocal",
|
||||
return_value=AsyncTestingSessionLocal(),
|
||||
):
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
|
||||
deleted_count = await cleanup_expired_sessions(keep_days=30)
|
||||
|
||||
assert deleted_count == 0
|
||||
@@ -111,10 +125,14 @@ class TestCleanupExpiredSessions:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_empty_database(self, async_test_db):
|
||||
"""Test cleanup with no sessions in database."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
with patch(
|
||||
"app.services.session_cleanup.SessionLocal",
|
||||
return_value=AsyncTestingSessionLocal(),
|
||||
):
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
|
||||
deleted_count = await cleanup_expired_sessions(keep_days=30)
|
||||
|
||||
assert deleted_count == 0
|
||||
@@ -122,7 +140,7 @@ class TestCleanupExpiredSessions:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_with_keep_days_0(self, async_test_db, async_test_user):
|
||||
"""Test cleanup with keep_days=0 deletes all inactive expired sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
today_expired = UserSession(
|
||||
@@ -132,15 +150,19 @@ class TestCleanupExpiredSessions:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(hours=2),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) - timedelta(hours=1),
|
||||
created_at=datetime.now(UTC) - timedelta(hours=2),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(today_expired)
|
||||
await session.commit()
|
||||
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
with patch(
|
||||
"app.services.session_cleanup.SessionLocal",
|
||||
return_value=AsyncTestingSessionLocal(),
|
||||
):
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
|
||||
deleted_count = await cleanup_expired_sessions(keep_days=0)
|
||||
|
||||
assert deleted_count == 1
|
||||
@@ -148,7 +170,7 @@ class TestCleanupExpiredSessions:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_bulk_delete_efficiency(self, async_test_db, async_test_user):
|
||||
"""Test that cleanup uses bulk DELETE for many sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create 50 expired sessions
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -161,16 +183,20 @@ class TestCleanupExpiredSessions:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=10),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=40),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) - timedelta(days=10),
|
||||
created_at=datetime.now(UTC) - timedelta(days=40),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
sessions_to_add.append(expired)
|
||||
session.add_all(sessions_to_add)
|
||||
await session.commit()
|
||||
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
with patch(
|
||||
"app.services.session_cleanup.SessionLocal",
|
||||
return_value=AsyncTestingSessionLocal(),
|
||||
):
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
|
||||
deleted_count = await cleanup_expired_sessions(keep_days=30)
|
||||
|
||||
assert deleted_count == 50
|
||||
@@ -178,14 +204,20 @@ class TestCleanupExpiredSessions:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_database_error_returns_zero(self, async_test_db):
|
||||
"""Test cleanup returns 0 on database errors (doesn't crash)."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Mock session_crud.cleanup_expired to raise error
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
with patch('app.services.session_cleanup.session_crud.cleanup_expired') as mock_cleanup:
|
||||
with patch(
|
||||
"app.services.session_cleanup.SessionLocal",
|
||||
return_value=AsyncTestingSessionLocal(),
|
||||
):
|
||||
with patch(
|
||||
"app.services.session_cleanup.session_crud.cleanup_expired"
|
||||
) as mock_cleanup:
|
||||
mock_cleanup.side_effect = Exception("Database connection lost")
|
||||
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
|
||||
# Should not crash, should return 0
|
||||
deleted_count = await cleanup_expired_sessions(keep_days=30)
|
||||
|
||||
@@ -198,7 +230,7 @@ class TestGetSessionStatistics:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_statistics_with_sessions(self, async_test_db, async_test_user):
|
||||
"""Test getting session statistics with various session types."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# 2 active, not expired
|
||||
@@ -210,9 +242,9 @@ class TestGetSessionStatistics:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
created_at=datetime.now(timezone.utc),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
created_at=datetime.now(UTC),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(active)
|
||||
|
||||
@@ -225,9 +257,9 @@ class TestGetSessionStatistics:
|
||||
ip_address="192.168.1.2",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=2),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) - timedelta(days=1),
|
||||
created_at=datetime.now(UTC) - timedelta(days=2),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(inactive)
|
||||
|
||||
@@ -239,16 +271,20 @@ class TestGetSessionStatistics:
|
||||
ip_address="192.168.1.3",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) - timedelta(hours=1),
|
||||
created_at=datetime.now(UTC) - timedelta(days=1),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(expired_active)
|
||||
|
||||
await session.commit()
|
||||
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
with patch(
|
||||
"app.services.session_cleanup.SessionLocal",
|
||||
return_value=AsyncTestingSessionLocal(),
|
||||
):
|
||||
from app.services.session_cleanup import get_session_statistics
|
||||
|
||||
stats = await get_session_statistics()
|
||||
|
||||
assert stats["total"] == 6
|
||||
@@ -259,10 +295,14 @@ class TestGetSessionStatistics:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_statistics_empty_database(self, async_test_db):
|
||||
"""Test getting statistics with no sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
with patch(
|
||||
"app.services.session_cleanup.SessionLocal",
|
||||
return_value=AsyncTestingSessionLocal(),
|
||||
):
|
||||
from app.services.session_cleanup import get_session_statistics
|
||||
|
||||
stats = await get_session_statistics()
|
||||
|
||||
assert stats["total"] == 0
|
||||
@@ -271,9 +311,11 @@ class TestGetSessionStatistics:
|
||||
assert stats["expired"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_statistics_database_error_returns_empty_dict(self, async_test_db):
|
||||
async def test_get_statistics_database_error_returns_empty_dict(
|
||||
self, async_test_db
|
||||
):
|
||||
"""Test statistics returns empty dict on database errors."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, _AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create a mock that raises on execute
|
||||
mock_session = AsyncMock()
|
||||
@@ -283,8 +325,12 @@ class TestGetSessionStatistics:
|
||||
async def mock_session_local():
|
||||
yield mock_session
|
||||
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=mock_session_local()):
|
||||
with patch(
|
||||
"app.services.session_cleanup.SessionLocal",
|
||||
return_value=mock_session_local(),
|
||||
):
|
||||
from app.services.session_cleanup import get_session_statistics
|
||||
|
||||
stats = await get_session_statistics()
|
||||
|
||||
assert stats == {}
|
||||
@@ -294,9 +340,11 @@ class TestConcurrentCleanup:
|
||||
"""Tests for concurrent cleanup scenarios."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_cleanup_no_duplicate_deletes(self, async_test_db, async_test_user):
|
||||
async def test_concurrent_cleanup_no_duplicate_deletes(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test concurrent cleanups don't cause race conditions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create 10 expired sessions
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -308,20 +356,24 @@ class TestConcurrentCleanup:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=10),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=40),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) - timedelta(days=10),
|
||||
created_at=datetime.now(UTC) - timedelta(days=40),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(expired)
|
||||
await session.commit()
|
||||
|
||||
# Run two cleanups concurrently
|
||||
# Use side_effect to return fresh session instances for each call
|
||||
with patch('app.services.session_cleanup.SessionLocal', side_effect=lambda: AsyncTestingSessionLocal()):
|
||||
with patch(
|
||||
"app.services.session_cleanup.SessionLocal",
|
||||
side_effect=lambda: AsyncTestingSessionLocal(),
|
||||
):
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
|
||||
results = await asyncio.gather(
|
||||
cleanup_expired_sessions(keep_days=30),
|
||||
cleanup_expired_sessions(keep_days=30)
|
||||
cleanup_expired_sessions(keep_days=30),
|
||||
)
|
||||
|
||||
# Both should report deleting sessions (may overlap due to transaction timing)
|
||||
|
||||
@@ -2,12 +2,13 @@
|
||||
"""
|
||||
Tests for database initialization script.
|
||||
"""
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from app.init_db import init_db
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.core.config import settings
|
||||
from app.init_db import init_db
|
||||
|
||||
|
||||
class TestInitDb:
|
||||
@@ -16,69 +17,86 @@ class TestInitDb:
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_db_creates_superuser_when_not_exists(self, async_test_db):
|
||||
"""Test that init_db creates a superuser when one doesn't exist."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Mock the SessionLocal to use our test database
|
||||
with patch('app.init_db.SessionLocal', SessionLocal):
|
||||
with patch("app.init_db.SessionLocal", SessionLocal):
|
||||
# Mock settings to provide test credentials
|
||||
with patch.object(settings, 'FIRST_SUPERUSER_EMAIL', 'test_admin@example.com'):
|
||||
with patch.object(settings, 'FIRST_SUPERUSER_PASSWORD', 'TestAdmin123!'):
|
||||
with patch.object(
|
||||
settings, "FIRST_SUPERUSER_EMAIL", "test_admin@example.com"
|
||||
):
|
||||
with patch.object(
|
||||
settings, "FIRST_SUPERUSER_PASSWORD", "TestAdmin123!"
|
||||
):
|
||||
# Run init_db
|
||||
user = await init_db()
|
||||
|
||||
# Verify superuser was created
|
||||
assert user is not None
|
||||
assert user.email == 'test_admin@example.com'
|
||||
assert user.email == "test_admin@example.com"
|
||||
assert user.is_superuser is True
|
||||
assert user.first_name == 'Admin'
|
||||
assert user.last_name == 'User'
|
||||
assert user.first_name == "Admin"
|
||||
assert user.last_name == "User"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_db_returns_existing_superuser(self, async_test_db, async_test_user):
|
||||
async def test_init_db_returns_existing_superuser(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test that init_db returns existing superuser instead of creating duplicate."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Mock the SessionLocal to use our test database
|
||||
with patch('app.init_db.SessionLocal', SessionLocal):
|
||||
with patch("app.init_db.SessionLocal", SessionLocal):
|
||||
# Mock settings to match async_test_user's email
|
||||
with patch.object(settings, 'FIRST_SUPERUSER_EMAIL', 'testuser@example.com'):
|
||||
with patch.object(settings, 'FIRST_SUPERUSER_PASSWORD', 'TestPassword123!'):
|
||||
with patch.object(
|
||||
settings, "FIRST_SUPERUSER_EMAIL", "testuser@example.com"
|
||||
):
|
||||
with patch.object(
|
||||
settings, "FIRST_SUPERUSER_PASSWORD", "TestPassword123!"
|
||||
):
|
||||
# Run init_db
|
||||
user = await init_db()
|
||||
|
||||
# Verify it returns the existing user
|
||||
assert user is not None
|
||||
assert user.id == async_test_user.id
|
||||
assert user.email == 'testuser@example.com'
|
||||
assert user.email == "testuser@example.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_db_uses_default_credentials(self, async_test_db):
|
||||
"""Test that init_db uses default credentials when env vars not set."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Mock the SessionLocal to use our test database
|
||||
with patch('app.init_db.SessionLocal', SessionLocal):
|
||||
with patch("app.init_db.SessionLocal", SessionLocal):
|
||||
# Mock settings to have None values (not configured)
|
||||
with patch.object(settings, 'FIRST_SUPERUSER_EMAIL', None):
|
||||
with patch.object(settings, 'FIRST_SUPERUSER_PASSWORD', None):
|
||||
with patch.object(settings, "FIRST_SUPERUSER_EMAIL", None):
|
||||
with patch.object(settings, "FIRST_SUPERUSER_PASSWORD", None):
|
||||
# Run init_db
|
||||
user = await init_db()
|
||||
|
||||
# Verify superuser was created with defaults
|
||||
assert user is not None
|
||||
assert user.email == 'admin@example.com'
|
||||
assert user.email == "admin@example.com"
|
||||
assert user.is_superuser is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_db_handles_database_errors(self, async_test_db):
|
||||
"""Test that init_db handles database errors gracefully."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Mock user_crud.get_by_email to raise an exception
|
||||
with patch('app.init_db.user_crud.get_by_email', side_effect=Exception("Database error")):
|
||||
with patch('app.init_db.SessionLocal', SessionLocal):
|
||||
with patch.object(settings, 'FIRST_SUPERUSER_EMAIL', 'test@example.com'):
|
||||
with patch.object(settings, 'FIRST_SUPERUSER_PASSWORD', 'TestPassword123!'):
|
||||
with patch(
|
||||
"app.init_db.user_crud.get_by_email",
|
||||
side_effect=Exception("Database error"),
|
||||
):
|
||||
with patch("app.init_db.SessionLocal", SessionLocal):
|
||||
with patch.object(
|
||||
settings, "FIRST_SUPERUSER_EMAIL", "test@example.com"
|
||||
):
|
||||
with patch.object(
|
||||
settings, "FIRST_SUPERUSER_PASSWORD", "TestPassword123!"
|
||||
):
|
||||
# Run init_db and expect it to raise
|
||||
with pytest.raises(Exception, match="Database error"):
|
||||
await init_db()
|
||||
|
||||
@@ -2,18 +2,18 @@
|
||||
"""
|
||||
Comprehensive tests for device utility functions.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from app.utils.device import (
|
||||
extract_device_info,
|
||||
parse_device_name,
|
||||
extract_browser,
|
||||
extract_device_info,
|
||||
get_client_ip,
|
||||
get_device_type,
|
||||
is_mobile_device,
|
||||
get_device_type
|
||||
parse_device_name,
|
||||
)
|
||||
|
||||
|
||||
@@ -138,7 +138,9 @@ class TestExtractBrowser:
|
||||
|
||||
def test_extract_browser_edge_legacy(self):
|
||||
"""Test extracting legacy Edge browser."""
|
||||
ua = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 Edge/18.19582"
|
||||
ua = (
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 Edge/18.19582"
|
||||
)
|
||||
result = extract_browser(ua)
|
||||
assert result == "Edge"
|
||||
|
||||
@@ -249,7 +251,7 @@ class TestGetClientIp:
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {
|
||||
"x-forwarded-for": "192.168.1.100",
|
||||
"x-real-ip": "192.168.1.200"
|
||||
"x-real-ip": "192.168.1.200",
|
||||
}
|
||||
request.client = Mock()
|
||||
request.client.host = "192.168.1.50"
|
||||
@@ -385,7 +387,7 @@ class TestExtractDeviceInfo:
|
||||
request.headers = {
|
||||
"user-agent": "Mozilla/5.0 (iPhone; CPU iPhone OS 15_0 like Mac OS X)",
|
||||
"x-device-id": "device-123-456",
|
||||
"x-forwarded-for": "192.168.1.100"
|
||||
"x-forwarded-for": "192.168.1.100",
|
||||
}
|
||||
request.client = None
|
||||
|
||||
|
||||
@@ -2,19 +2,21 @@
|
||||
"""
|
||||
Tests for security utility functions.
|
||||
"""
|
||||
import time
|
||||
|
||||
import base64
|
||||
import json
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from app.utils.security import (
|
||||
create_upload_token,
|
||||
verify_upload_token,
|
||||
create_password_reset_token,
|
||||
verify_password_reset_token,
|
||||
create_email_verification_token,
|
||||
verify_email_verification_token
|
||||
create_password_reset_token,
|
||||
create_upload_token,
|
||||
verify_email_verification_token,
|
||||
verify_password_reset_token,
|
||||
verify_upload_token,
|
||||
)
|
||||
|
||||
|
||||
@@ -31,7 +33,7 @@ class TestCreateUploadToken:
|
||||
|
||||
# Token should be base64 encoded
|
||||
try:
|
||||
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
|
||||
decoded = base64.urlsafe_b64decode(token.encode("utf-8"))
|
||||
token_data = json.loads(decoded)
|
||||
assert "payload" in token_data
|
||||
assert "signature" in token_data
|
||||
@@ -46,7 +48,7 @@ class TestCreateUploadToken:
|
||||
token = create_upload_token(file_path, content_type)
|
||||
|
||||
# Decode and verify payload
|
||||
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
|
||||
decoded = base64.urlsafe_b64decode(token.encode("utf-8"))
|
||||
token_data = json.loads(decoded)
|
||||
payload = token_data["payload"]
|
||||
|
||||
@@ -62,7 +64,7 @@ class TestCreateUploadToken:
|
||||
after = int(time.time())
|
||||
|
||||
# Decode token
|
||||
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
|
||||
decoded = base64.urlsafe_b64decode(token.encode("utf-8"))
|
||||
token_data = json.loads(decoded)
|
||||
payload = token_data["payload"]
|
||||
|
||||
@@ -74,11 +76,13 @@ class TestCreateUploadToken:
|
||||
"""Test token creation with custom expiration time."""
|
||||
custom_exp = 600 # 10 minutes
|
||||
before = int(time.time())
|
||||
token = create_upload_token("/uploads/test.jpg", "image/jpeg", expires_in=custom_exp)
|
||||
token = create_upload_token(
|
||||
"/uploads/test.jpg", "image/jpeg", expires_in=custom_exp
|
||||
)
|
||||
after = int(time.time())
|
||||
|
||||
# Decode token
|
||||
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
|
||||
decoded = base64.urlsafe_b64decode(token.encode("utf-8"))
|
||||
token_data = json.loads(decoded)
|
||||
payload = token_data["payload"]
|
||||
|
||||
@@ -92,11 +96,11 @@ class TestCreateUploadToken:
|
||||
token2 = create_upload_token("/uploads/test.jpg", "image/jpeg")
|
||||
|
||||
# Decode both tokens
|
||||
decoded1 = base64.urlsafe_b64decode(token1.encode('utf-8'))
|
||||
decoded1 = base64.urlsafe_b64decode(token1.encode("utf-8"))
|
||||
token_data1 = json.loads(decoded1)
|
||||
nonce1 = token_data1["payload"]["nonce"]
|
||||
|
||||
decoded2 = base64.urlsafe_b64decode(token2.encode('utf-8'))
|
||||
decoded2 = base64.urlsafe_b64decode(token2.encode("utf-8"))
|
||||
token_data2 = json.loads(decoded2)
|
||||
nonce2 = token_data2["payload"]["nonce"]
|
||||
|
||||
@@ -133,7 +137,7 @@ class TestVerifyUploadToken:
|
||||
current_time = 1000000
|
||||
mock_time.time = MagicMock(return_value=current_time)
|
||||
|
||||
with patch('app.utils.security.time', mock_time):
|
||||
with patch("app.utils.security.time", mock_time):
|
||||
# Create token that "expires" at current_time + 1
|
||||
token = create_upload_token("/uploads/test.jpg", "image/jpeg", expires_in=1)
|
||||
|
||||
@@ -149,13 +153,15 @@ class TestVerifyUploadToken:
|
||||
token = create_upload_token("/uploads/test.jpg", "image/jpeg")
|
||||
|
||||
# Decode, modify, and re-encode
|
||||
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
|
||||
decoded = base64.urlsafe_b64decode(token.encode("utf-8"))
|
||||
token_data = json.loads(decoded)
|
||||
token_data["signature"] = "invalid_signature"
|
||||
|
||||
# Re-encode the tampered token
|
||||
tampered_json = json.dumps(token_data)
|
||||
tampered_token = base64.urlsafe_b64encode(tampered_json.encode('utf-8')).decode('utf-8')
|
||||
tampered_token = base64.urlsafe_b64encode(tampered_json.encode("utf-8")).decode(
|
||||
"utf-8"
|
||||
)
|
||||
|
||||
payload = verify_upload_token(tampered_token)
|
||||
assert payload is None
|
||||
@@ -165,13 +171,15 @@ class TestVerifyUploadToken:
|
||||
token = create_upload_token("/uploads/test.jpg", "image/jpeg")
|
||||
|
||||
# Decode, modify payload, and re-encode
|
||||
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
|
||||
decoded = base64.urlsafe_b64decode(token.encode("utf-8"))
|
||||
token_data = json.loads(decoded)
|
||||
token_data["payload"]["path"] = "/uploads/hacked.exe"
|
||||
|
||||
# Re-encode the tampered token (signature won't match)
|
||||
tampered_json = json.dumps(token_data)
|
||||
tampered_token = base64.urlsafe_b64encode(tampered_json.encode('utf-8')).decode('utf-8')
|
||||
tampered_token = base64.urlsafe_b64encode(tampered_json.encode("utf-8")).decode(
|
||||
"utf-8"
|
||||
)
|
||||
|
||||
payload = verify_upload_token(tampered_token)
|
||||
assert payload is None
|
||||
@@ -194,7 +202,9 @@ class TestVerifyUploadToken:
|
||||
"""Test that tokens with invalid JSON are rejected."""
|
||||
# Create a base64 string that decodes to invalid JSON
|
||||
invalid_json = "not valid json"
|
||||
invalid_token = base64.urlsafe_b64encode(invalid_json.encode('utf-8')).decode('utf-8')
|
||||
invalid_token = base64.urlsafe_b64encode(invalid_json.encode("utf-8")).decode(
|
||||
"utf-8"
|
||||
)
|
||||
|
||||
payload = verify_upload_token(invalid_token)
|
||||
assert payload is None
|
||||
@@ -207,11 +217,13 @@ class TestVerifyUploadToken:
|
||||
"path": "/uploads/test.jpg"
|
||||
# Missing content_type, exp, nonce
|
||||
},
|
||||
"signature": "some_signature"
|
||||
"signature": "some_signature",
|
||||
}
|
||||
|
||||
incomplete_json = json.dumps(incomplete_data)
|
||||
incomplete_token = base64.urlsafe_b64encode(incomplete_json.encode('utf-8')).decode('utf-8')
|
||||
incomplete_token = base64.urlsafe_b64encode(
|
||||
incomplete_json.encode("utf-8")
|
||||
).decode("utf-8")
|
||||
|
||||
payload = verify_upload_token(incomplete_token)
|
||||
assert payload is None
|
||||
@@ -266,7 +278,7 @@ class TestPasswordResetTokens:
|
||||
email = "user@example.com"
|
||||
|
||||
# Create token that expires in 1 second
|
||||
with patch('app.utils.security.time') as mock_time:
|
||||
with patch("app.utils.security.time") as mock_time:
|
||||
mock_time.time = MagicMock(return_value=1000000)
|
||||
token = create_password_reset_token(email, expires_in=1)
|
||||
|
||||
@@ -287,12 +299,14 @@ class TestPasswordResetTokens:
|
||||
token = create_password_reset_token(email)
|
||||
|
||||
# Decode and tamper
|
||||
decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
|
||||
decoded = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
|
||||
token_data = json.loads(decoded)
|
||||
token_data["payload"]["email"] = "hacker@example.com"
|
||||
|
||||
# Re-encode
|
||||
tampered = base64.urlsafe_b64encode(json.dumps(token_data).encode('utf-8')).decode('utf-8')
|
||||
tampered = base64.urlsafe_b64encode(
|
||||
json.dumps(token_data).encode("utf-8")
|
||||
).decode("utf-8")
|
||||
|
||||
verified_email = verify_password_reset_token(tampered)
|
||||
assert verified_email is None
|
||||
@@ -312,14 +326,14 @@ class TestPasswordResetTokens:
|
||||
email = "user@example.com"
|
||||
custom_exp = 7200 # 2 hours
|
||||
|
||||
with patch('app.utils.security.time') as mock_time:
|
||||
with patch("app.utils.security.time") as mock_time:
|
||||
current_time = 1000000
|
||||
mock_time.time = MagicMock(return_value=current_time)
|
||||
|
||||
token = create_password_reset_token(email, expires_in=custom_exp)
|
||||
|
||||
# Decode to check expiration
|
||||
decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
|
||||
decoded = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
|
||||
token_data = json.loads(decoded)
|
||||
|
||||
assert token_data["payload"]["exp"] == current_time + custom_exp
|
||||
@@ -350,7 +364,7 @@ class TestEmailVerificationTokens:
|
||||
"""Test that expired verification tokens are rejected."""
|
||||
email = "user@example.com"
|
||||
|
||||
with patch('app.utils.security.time') as mock_time:
|
||||
with patch("app.utils.security.time") as mock_time:
|
||||
mock_time.time = MagicMock(return_value=1000000)
|
||||
token = create_email_verification_token(email, expires_in=1)
|
||||
|
||||
@@ -371,12 +385,14 @@ class TestEmailVerificationTokens:
|
||||
token = create_email_verification_token(email)
|
||||
|
||||
# Decode and tamper
|
||||
decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
|
||||
decoded = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
|
||||
token_data = json.loads(decoded)
|
||||
token_data["payload"]["email"] = "hacker@example.com"
|
||||
|
||||
# Re-encode
|
||||
tampered = base64.urlsafe_b64encode(json.dumps(token_data).encode('utf-8')).decode('utf-8')
|
||||
tampered = base64.urlsafe_b64encode(
|
||||
json.dumps(token_data).encode("utf-8")
|
||||
).decode("utf-8")
|
||||
|
||||
verified_email = verify_email_verification_token(tampered)
|
||||
assert verified_email is None
|
||||
@@ -395,14 +411,14 @@ class TestEmailVerificationTokens:
|
||||
"""Test email verification token with default 24-hour expiration."""
|
||||
email = "user@example.com"
|
||||
|
||||
with patch('app.utils.security.time') as mock_time:
|
||||
with patch("app.utils.security.time") as mock_time:
|
||||
current_time = 1000000
|
||||
mock_time.time = MagicMock(return_value=current_time)
|
||||
|
||||
token = create_email_verification_token(email)
|
||||
|
||||
# Decode to check expiration (should be 86400 seconds = 24 hours)
|
||||
decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
|
||||
decoded = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
|
||||
token_data = json.loads(decoded)
|
||||
|
||||
assert token_data["payload"]["exp"] == current_time + 86400
|
||||
|
||||
Reference in New Issue
Block a user