forked from cardosofelipe/fast-next-template
Add pyproject.toml for consolidated project configuration and replace Black, isort, and Flake8 with Ruff
- Introduced `pyproject.toml` to centralize backend tool configurations (e.g., Ruff, mypy, coverage, pytest). - Replaced Black, isort, and Flake8 with Ruff for linting, formatting, and import sorting. - Updated `requirements.txt` to include Ruff and remove replaced tools. - Added `Makefile` to streamline development workflows with commands for linting, formatting, type-checking, testing, and cleanup.
This commit is contained in:
@@ -1,14 +1,18 @@
|
||||
# tests/services/test_auth_service.py
|
||||
import uuid
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.auth import get_password_hash, verify_password, TokenExpiredError, TokenInvalidError
|
||||
from app.core.auth import (
|
||||
TokenInvalidError,
|
||||
get_password_hash,
|
||||
verify_password,
|
||||
)
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserCreate, Token
|
||||
from app.services.auth_service import AuthService, AuthenticationError
|
||||
from app.schemas.users import Token, UserCreate
|
||||
from app.services.auth_service import AuthenticationError, AuthService
|
||||
|
||||
|
||||
class TestAuthServiceAuthentication:
|
||||
@@ -17,12 +21,14 @@ class TestAuthServiceAuthentication:
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_valid_user(self, async_test_db, async_test_user):
|
||||
"""Test authenticating a user with valid credentials"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Set a known password for the mock user
|
||||
password = "TestPassword123!"
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
user.password_hash = get_password_hash(password)
|
||||
await session.commit()
|
||||
@@ -30,9 +36,7 @@ class TestAuthServiceAuthentication:
|
||||
# Authenticate with correct credentials
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
auth_user = await AuthService.authenticate_user(
|
||||
db=session,
|
||||
email=async_test_user.email,
|
||||
password=password
|
||||
db=session, email=async_test_user.email, password=password
|
||||
)
|
||||
|
||||
assert auth_user is not None
|
||||
@@ -42,26 +46,28 @@ class TestAuthServiceAuthentication:
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_nonexistent_user(self, async_test_db):
|
||||
"""Test authenticating with an email that doesn't exist"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await AuthService.authenticate_user(
|
||||
db=session,
|
||||
email="nonexistent@example.com",
|
||||
password="password"
|
||||
db=session, email="nonexistent@example.com", password="password"
|
||||
)
|
||||
|
||||
assert user is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_with_wrong_password(self, async_test_db, async_test_user):
|
||||
async def test_authenticate_with_wrong_password(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test authenticating with the wrong password"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Set a known password for the mock user
|
||||
password = "TestPassword123!"
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
user.password_hash = get_password_hash(password)
|
||||
await session.commit()
|
||||
@@ -69,9 +75,7 @@ class TestAuthServiceAuthentication:
|
||||
# Authenticate with wrong password
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
auth_user = await AuthService.authenticate_user(
|
||||
db=session,
|
||||
email=async_test_user.email,
|
||||
password="WrongPassword123"
|
||||
db=session, email=async_test_user.email, password="WrongPassword123"
|
||||
)
|
||||
|
||||
assert auth_user is None
|
||||
@@ -79,12 +83,14 @@ class TestAuthServiceAuthentication:
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_inactive_user(self, async_test_db, async_test_user):
|
||||
"""Test authenticating an inactive user"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Set a known password and make user inactive
|
||||
password = "TestPassword123!"
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
user.password_hash = get_password_hash(password)
|
||||
user.is_active = False
|
||||
@@ -94,9 +100,7 @@ class TestAuthServiceAuthentication:
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(AuthenticationError):
|
||||
await AuthService.authenticate_user(
|
||||
db=session,
|
||||
email=async_test_user.email,
|
||||
password=password
|
||||
db=session, email=async_test_user.email, password=password
|
||||
)
|
||||
|
||||
|
||||
@@ -106,14 +110,14 @@ class TestAuthServiceUserCreation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_new_user(self, async_test_db):
|
||||
"""Test creating a new user"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
user_data = UserCreate(
|
||||
email="newuser@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="New",
|
||||
last_name="User",
|
||||
phone_number="+1234567890"
|
||||
phone_number="+1234567890",
|
||||
)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -135,15 +139,17 @@ class TestAuthServiceUserCreation:
|
||||
assert user.is_superuser is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_with_existing_email(self, async_test_db, async_test_user):
|
||||
async def test_create_user_with_existing_email(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test creating a user with an email that already exists"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
user_data = UserCreate(
|
||||
email=async_test_user.email, # Use existing email
|
||||
password="TestPassword123!",
|
||||
first_name="Duplicate",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
|
||||
# Should raise AuthenticationError
|
||||
@@ -169,7 +175,7 @@ class TestAuthServiceTokens:
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_tokens(self, async_test_db, async_test_user):
|
||||
"""Test refreshing tokens with a valid refresh token"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create initial tokens
|
||||
initial_tokens = AuthService.create_tokens(async_test_user)
|
||||
@@ -177,8 +183,7 @@ class TestAuthServiceTokens:
|
||||
# Refresh tokens
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
new_tokens = await AuthService.refresh_tokens(
|
||||
db=session,
|
||||
refresh_token=initial_tokens.refresh_token
|
||||
db=session, refresh_token=initial_tokens.refresh_token
|
||||
)
|
||||
|
||||
# Verify new tokens are different from old ones
|
||||
@@ -188,7 +193,7 @@ class TestAuthServiceTokens:
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_tokens_with_invalid_token(self, async_test_db):
|
||||
"""Test refreshing tokens with an invalid token"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create an invalid token
|
||||
invalid_token = "invalid.token.string"
|
||||
@@ -197,14 +202,15 @@ class TestAuthServiceTokens:
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(TokenInvalidError):
|
||||
await AuthService.refresh_tokens(
|
||||
db=session,
|
||||
refresh_token=invalid_token
|
||||
db=session, refresh_token=invalid_token
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_tokens_with_access_token(self, async_test_db, async_test_user):
|
||||
async def test_refresh_tokens_with_access_token(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test refreshing tokens with an access token instead of refresh token"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create tokens
|
||||
tokens = AuthService.create_tokens(async_test_user)
|
||||
@@ -213,18 +219,20 @@ class TestAuthServiceTokens:
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(TokenInvalidError):
|
||||
await AuthService.refresh_tokens(
|
||||
db=session,
|
||||
refresh_token=tokens.access_token
|
||||
db=session, refresh_token=tokens.access_token
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_tokens_with_nonexistent_user(self, async_test_db):
|
||||
"""Test refreshing tokens for a user that doesn't exist in the database"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create a token for a non-existent user
|
||||
non_existent_id = str(uuid.uuid4())
|
||||
with patch('app.core.auth.decode_token'), patch('app.core.auth.get_token_data') as mock_get_data:
|
||||
with (
|
||||
patch("app.core.auth.decode_token"),
|
||||
patch("app.core.auth.get_token_data") as mock_get_data,
|
||||
):
|
||||
# Mock the token data to return a non-existent user ID
|
||||
mock_get_data.return_value.user_id = uuid.UUID(non_existent_id)
|
||||
|
||||
@@ -232,8 +240,7 @@ class TestAuthServiceTokens:
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(TokenInvalidError):
|
||||
await AuthService.refresh_tokens(
|
||||
db=session,
|
||||
refresh_token="some.refresh.token"
|
||||
db=session, refresh_token="some.refresh.token"
|
||||
)
|
||||
|
||||
|
||||
@@ -243,12 +250,14 @@ class TestAuthServicePasswordChange:
|
||||
@pytest.mark.asyncio
|
||||
async def test_change_password(self, async_test_db, async_test_user):
|
||||
"""Test changing a user's password"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Set a known password for the mock user
|
||||
current_password = "CurrentPassword123"
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
user.password_hash = get_password_hash(current_password)
|
||||
await session.commit()
|
||||
@@ -260,7 +269,7 @@ class TestAuthServicePasswordChange:
|
||||
db=session,
|
||||
user_id=async_test_user.id,
|
||||
current_password=current_password,
|
||||
new_password=new_password
|
||||
new_password=new_password,
|
||||
)
|
||||
|
||||
# Verify operation was successful
|
||||
@@ -268,7 +277,9 @@ class TestAuthServicePasswordChange:
|
||||
|
||||
# Verify password was changed
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
updated_user = result.scalar_one_or_none()
|
||||
|
||||
# Verify old password no longer works
|
||||
@@ -278,14 +289,18 @@ class TestAuthServicePasswordChange:
|
||||
assert verify_password(new_password, updated_user.password_hash)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_change_password_wrong_current_password(self, async_test_db, async_test_user):
|
||||
async def test_change_password_wrong_current_password(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test changing password with incorrect current password"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Set a known password for the mock user
|
||||
current_password = "CurrentPassword123"
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
user.password_hash = get_password_hash(current_password)
|
||||
await session.commit()
|
||||
@@ -298,19 +313,21 @@ class TestAuthServicePasswordChange:
|
||||
db=session,
|
||||
user_id=async_test_user.id,
|
||||
current_password=wrong_password,
|
||||
new_password="NewPassword456"
|
||||
new_password="NewPassword456",
|
||||
)
|
||||
|
||||
# Verify password was not changed
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
assert verify_password(current_password, user.password_hash)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_change_password_nonexistent_user(self, async_test_db):
|
||||
"""Test changing password for a user that doesn't exist"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
non_existent_id = uuid.uuid4()
|
||||
|
||||
@@ -320,5 +337,5 @@ class TestAuthServicePasswordChange:
|
||||
db=session,
|
||||
user_id=non_existent_id,
|
||||
current_password="CurrentPassword123",
|
||||
new_password="NewPassword456"
|
||||
new_password="NewPassword456",
|
||||
)
|
||||
|
||||
@@ -2,13 +2,15 @@
|
||||
"""
|
||||
Tests for email service functionality.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, AsyncMock, MagicMock
|
||||
|
||||
from app.services.email_service import (
|
||||
EmailService,
|
||||
ConsoleEmailBackend,
|
||||
SMTPEmailBackend
|
||||
EmailService,
|
||||
SMTPEmailBackend,
|
||||
)
|
||||
|
||||
|
||||
@@ -24,7 +26,7 @@ class TestConsoleEmailBackend:
|
||||
to=["user@example.com"],
|
||||
subject="Test Subject",
|
||||
html_content="<p>Test HTML</p>",
|
||||
text_content="Test Text"
|
||||
text_content="Test Text",
|
||||
)
|
||||
|
||||
assert result is True
|
||||
@@ -37,7 +39,7 @@ class TestConsoleEmailBackend:
|
||||
result = await backend.send_email(
|
||||
to=["user@example.com"],
|
||||
subject="Test Subject",
|
||||
html_content="<p>Test HTML</p>"
|
||||
html_content="<p>Test HTML</p>",
|
||||
)
|
||||
|
||||
assert result is True
|
||||
@@ -50,7 +52,7 @@ class TestConsoleEmailBackend:
|
||||
result = await backend.send_email(
|
||||
to=["user1@example.com", "user2@example.com"],
|
||||
subject="Test Subject",
|
||||
html_content="<p>Test HTML</p>"
|
||||
html_content="<p>Test HTML</p>",
|
||||
)
|
||||
|
||||
assert result is True
|
||||
@@ -66,7 +68,7 @@ class TestSMTPEmailBackend:
|
||||
host="smtp.example.com",
|
||||
port=587,
|
||||
username="test@example.com",
|
||||
password="password"
|
||||
password="password",
|
||||
)
|
||||
|
||||
assert backend.host == "smtp.example.com"
|
||||
@@ -81,14 +83,14 @@ class TestSMTPEmailBackend:
|
||||
host="smtp.example.com",
|
||||
port=587,
|
||||
username="test@example.com",
|
||||
password="password"
|
||||
password="password",
|
||||
)
|
||||
|
||||
# Should fall back to console backend since SMTP is not implemented
|
||||
result = await backend.send_email(
|
||||
to=["user@example.com"],
|
||||
subject="Test Subject",
|
||||
html_content="<p>Test HTML</p>"
|
||||
html_content="<p>Test HTML</p>",
|
||||
)
|
||||
|
||||
assert result is True
|
||||
@@ -114,9 +116,7 @@ class TestEmailService:
|
||||
service = EmailService()
|
||||
|
||||
result = await service.send_password_reset_email(
|
||||
to_email="user@example.com",
|
||||
reset_token="test_token_123",
|
||||
user_name="John"
|
||||
to_email="user@example.com", reset_token="test_token_123", user_name="John"
|
||||
)
|
||||
|
||||
assert result is True
|
||||
@@ -127,8 +127,7 @@ class TestEmailService:
|
||||
service = EmailService()
|
||||
|
||||
result = await service.send_password_reset_email(
|
||||
to_email="user@example.com",
|
||||
reset_token="test_token_123"
|
||||
to_email="user@example.com", reset_token="test_token_123"
|
||||
)
|
||||
|
||||
assert result is True
|
||||
@@ -142,8 +141,7 @@ class TestEmailService:
|
||||
|
||||
token = "test_reset_token_xyz"
|
||||
await service.send_password_reset_email(
|
||||
to_email="user@example.com",
|
||||
reset_token=token
|
||||
to_email="user@example.com", reset_token=token
|
||||
)
|
||||
|
||||
# Verify send_email was called
|
||||
@@ -151,7 +149,7 @@ class TestEmailService:
|
||||
call_args = backend_mock.send_email.call_args
|
||||
|
||||
# Check that token is in the HTML content
|
||||
html_content = call_args.kwargs['html_content']
|
||||
html_content = call_args.kwargs["html_content"]
|
||||
assert token in html_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -162,8 +160,7 @@ class TestEmailService:
|
||||
service = EmailService(backend=backend_mock)
|
||||
|
||||
result = await service.send_password_reset_email(
|
||||
to_email="user@example.com",
|
||||
reset_token="test_token"
|
||||
to_email="user@example.com", reset_token="test_token"
|
||||
)
|
||||
|
||||
assert result is False
|
||||
@@ -176,7 +173,7 @@ class TestEmailService:
|
||||
result = await service.send_email_verification(
|
||||
to_email="user@example.com",
|
||||
verification_token="verification_token_123",
|
||||
user_name="Jane"
|
||||
user_name="Jane",
|
||||
)
|
||||
|
||||
assert result is True
|
||||
@@ -187,8 +184,7 @@ class TestEmailService:
|
||||
service = EmailService()
|
||||
|
||||
result = await service.send_email_verification(
|
||||
to_email="user@example.com",
|
||||
verification_token="verification_token_123"
|
||||
to_email="user@example.com", verification_token="verification_token_123"
|
||||
)
|
||||
|
||||
assert result is True
|
||||
@@ -202,8 +198,7 @@ class TestEmailService:
|
||||
|
||||
token = "test_verification_token_xyz"
|
||||
await service.send_email_verification(
|
||||
to_email="user@example.com",
|
||||
verification_token=token
|
||||
to_email="user@example.com", verification_token=token
|
||||
)
|
||||
|
||||
# Verify send_email was called
|
||||
@@ -211,7 +206,7 @@ class TestEmailService:
|
||||
call_args = backend_mock.send_email.call_args
|
||||
|
||||
# Check that token is in the HTML content
|
||||
html_content = call_args.kwargs['html_content']
|
||||
html_content = call_args.kwargs["html_content"]
|
||||
assert token in html_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -222,8 +217,7 @@ class TestEmailService:
|
||||
service = EmailService(backend=backend_mock)
|
||||
|
||||
result = await service.send_email_verification(
|
||||
to_email="user@example.com",
|
||||
verification_token="test_token"
|
||||
to_email="user@example.com", verification_token="test_token"
|
||||
)
|
||||
|
||||
assert result is False
|
||||
@@ -236,14 +230,12 @@ class TestEmailService:
|
||||
service = EmailService(backend=backend_mock)
|
||||
|
||||
await service.send_password_reset_email(
|
||||
to_email="user@example.com",
|
||||
reset_token="token123",
|
||||
user_name="Test User"
|
||||
to_email="user@example.com", reset_token="token123", user_name="Test User"
|
||||
)
|
||||
|
||||
call_args = backend_mock.send_email.call_args
|
||||
html_content = call_args.kwargs['html_content']
|
||||
text_content = call_args.kwargs['text_content']
|
||||
html_content = call_args.kwargs["html_content"]
|
||||
text_content = call_args.kwargs["text_content"]
|
||||
|
||||
# Check HTML content
|
||||
assert "Password Reset" in html_content
|
||||
@@ -251,7 +243,9 @@ class TestEmailService:
|
||||
assert "Test User" in html_content
|
||||
|
||||
# Check text content
|
||||
assert "Password Reset" in text_content or "password reset" in text_content.lower()
|
||||
assert (
|
||||
"Password Reset" in text_content or "password reset" in text_content.lower()
|
||||
)
|
||||
assert "token123" in text_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -264,12 +258,12 @@ class TestEmailService:
|
||||
await service.send_email_verification(
|
||||
to_email="user@example.com",
|
||||
verification_token="verify123",
|
||||
user_name="Test User"
|
||||
user_name="Test User",
|
||||
)
|
||||
|
||||
call_args = backend_mock.send_email.call_args
|
||||
html_content = call_args.kwargs['html_content']
|
||||
text_content = call_args.kwargs['text_content']
|
||||
html_content = call_args.kwargs["html_content"]
|
||||
text_content = call_args.kwargs["text_content"]
|
||||
|
||||
# Check HTML content
|
||||
assert "Verify" in html_content
|
||||
|
||||
@@ -2,23 +2,27 @@
|
||||
"""
|
||||
Comprehensive tests for session cleanup service.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models.user_session import UserSession
|
||||
from sqlalchemy import select
|
||||
|
||||
|
||||
class TestCleanupExpiredSessions:
|
||||
"""Tests for cleanup_expired_sessions function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_sessions_success(self, async_test_db, async_test_user):
|
||||
async def test_cleanup_expired_sessions_success(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test successful cleanup of expired sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create mix of sessions
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -30,9 +34,9 @@ class TestCleanupExpiredSessions:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
created_at=datetime.now(UTC) - timedelta(days=1),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
# 2. Inactive, expired, old (SHOULD be deleted)
|
||||
@@ -43,9 +47,9 @@ class TestCleanupExpiredSessions:
|
||||
ip_address="192.168.1.2",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=10),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=40),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) - timedelta(days=10),
|
||||
created_at=datetime.now(UTC) - timedelta(days=40),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
# 3. Inactive, expired, recent (should NOT be deleted - within keep_days)
|
||||
@@ -56,17 +60,23 @@ class TestCleanupExpiredSessions:
|
||||
ip_address="192.168.1.3",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=5),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) - timedelta(days=1),
|
||||
created_at=datetime.now(UTC) - timedelta(days=5),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
session.add_all([active_session, old_expired_session, recent_expired_session])
|
||||
session.add_all(
|
||||
[active_session, old_expired_session, recent_expired_session]
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# Mock SessionLocal to return our test session
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
with patch(
|
||||
"app.services.session_cleanup.SessionLocal",
|
||||
return_value=AsyncTestingSessionLocal(),
|
||||
):
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
|
||||
deleted_count = await cleanup_expired_sessions(keep_days=30)
|
||||
|
||||
# Should only delete old_expired_session
|
||||
@@ -85,7 +95,7 @@ class TestCleanupExpiredSessions:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_no_sessions_to_delete(self, async_test_db, async_test_user):
|
||||
"""Test cleanup when no sessions meet deletion criteria."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
active = UserSession(
|
||||
@@ -95,15 +105,19 @@ class TestCleanupExpiredSessions:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
created_at=datetime.now(timezone.utc),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
created_at=datetime.now(UTC),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(active)
|
||||
await session.commit()
|
||||
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
with patch(
|
||||
"app.services.session_cleanup.SessionLocal",
|
||||
return_value=AsyncTestingSessionLocal(),
|
||||
):
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
|
||||
deleted_count = await cleanup_expired_sessions(keep_days=30)
|
||||
|
||||
assert deleted_count == 0
|
||||
@@ -111,10 +125,14 @@ class TestCleanupExpiredSessions:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_empty_database(self, async_test_db):
|
||||
"""Test cleanup with no sessions in database."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
with patch(
|
||||
"app.services.session_cleanup.SessionLocal",
|
||||
return_value=AsyncTestingSessionLocal(),
|
||||
):
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
|
||||
deleted_count = await cleanup_expired_sessions(keep_days=30)
|
||||
|
||||
assert deleted_count == 0
|
||||
@@ -122,7 +140,7 @@ class TestCleanupExpiredSessions:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_with_keep_days_0(self, async_test_db, async_test_user):
|
||||
"""Test cleanup with keep_days=0 deletes all inactive expired sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
today_expired = UserSession(
|
||||
@@ -132,15 +150,19 @@ class TestCleanupExpiredSessions:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(hours=2),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) - timedelta(hours=1),
|
||||
created_at=datetime.now(UTC) - timedelta(hours=2),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(today_expired)
|
||||
await session.commit()
|
||||
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
with patch(
|
||||
"app.services.session_cleanup.SessionLocal",
|
||||
return_value=AsyncTestingSessionLocal(),
|
||||
):
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
|
||||
deleted_count = await cleanup_expired_sessions(keep_days=0)
|
||||
|
||||
assert deleted_count == 1
|
||||
@@ -148,7 +170,7 @@ class TestCleanupExpiredSessions:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_bulk_delete_efficiency(self, async_test_db, async_test_user):
|
||||
"""Test that cleanup uses bulk DELETE for many sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create 50 expired sessions
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -161,16 +183,20 @@ class TestCleanupExpiredSessions:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=10),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=40),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) - timedelta(days=10),
|
||||
created_at=datetime.now(UTC) - timedelta(days=40),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
sessions_to_add.append(expired)
|
||||
session.add_all(sessions_to_add)
|
||||
await session.commit()
|
||||
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
with patch(
|
||||
"app.services.session_cleanup.SessionLocal",
|
||||
return_value=AsyncTestingSessionLocal(),
|
||||
):
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
|
||||
deleted_count = await cleanup_expired_sessions(keep_days=30)
|
||||
|
||||
assert deleted_count == 50
|
||||
@@ -178,14 +204,20 @@ class TestCleanupExpiredSessions:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_database_error_returns_zero(self, async_test_db):
|
||||
"""Test cleanup returns 0 on database errors (doesn't crash)."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Mock session_crud.cleanup_expired to raise error
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
with patch('app.services.session_cleanup.session_crud.cleanup_expired') as mock_cleanup:
|
||||
with patch(
|
||||
"app.services.session_cleanup.SessionLocal",
|
||||
return_value=AsyncTestingSessionLocal(),
|
||||
):
|
||||
with patch(
|
||||
"app.services.session_cleanup.session_crud.cleanup_expired"
|
||||
) as mock_cleanup:
|
||||
mock_cleanup.side_effect = Exception("Database connection lost")
|
||||
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
|
||||
# Should not crash, should return 0
|
||||
deleted_count = await cleanup_expired_sessions(keep_days=30)
|
||||
|
||||
@@ -198,7 +230,7 @@ class TestGetSessionStatistics:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_statistics_with_sessions(self, async_test_db, async_test_user):
|
||||
"""Test getting session statistics with various session types."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# 2 active, not expired
|
||||
@@ -210,9 +242,9 @@ class TestGetSessionStatistics:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
created_at=datetime.now(timezone.utc),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
created_at=datetime.now(UTC),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(active)
|
||||
|
||||
@@ -225,9 +257,9 @@ class TestGetSessionStatistics:
|
||||
ip_address="192.168.1.2",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=2),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) - timedelta(days=1),
|
||||
created_at=datetime.now(UTC) - timedelta(days=2),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(inactive)
|
||||
|
||||
@@ -239,16 +271,20 @@ class TestGetSessionStatistics:
|
||||
ip_address="192.168.1.3",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) - timedelta(hours=1),
|
||||
created_at=datetime.now(UTC) - timedelta(days=1),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(expired_active)
|
||||
|
||||
await session.commit()
|
||||
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
with patch(
|
||||
"app.services.session_cleanup.SessionLocal",
|
||||
return_value=AsyncTestingSessionLocal(),
|
||||
):
|
||||
from app.services.session_cleanup import get_session_statistics
|
||||
|
||||
stats = await get_session_statistics()
|
||||
|
||||
assert stats["total"] == 6
|
||||
@@ -259,10 +295,14 @@ class TestGetSessionStatistics:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_statistics_empty_database(self, async_test_db):
|
||||
"""Test getting statistics with no sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
with patch(
|
||||
"app.services.session_cleanup.SessionLocal",
|
||||
return_value=AsyncTestingSessionLocal(),
|
||||
):
|
||||
from app.services.session_cleanup import get_session_statistics
|
||||
|
||||
stats = await get_session_statistics()
|
||||
|
||||
assert stats["total"] == 0
|
||||
@@ -271,9 +311,11 @@ class TestGetSessionStatistics:
|
||||
assert stats["expired"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_statistics_database_error_returns_empty_dict(self, async_test_db):
|
||||
async def test_get_statistics_database_error_returns_empty_dict(
|
||||
self, async_test_db
|
||||
):
|
||||
"""Test statistics returns empty dict on database errors."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, _AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create a mock that raises on execute
|
||||
mock_session = AsyncMock()
|
||||
@@ -283,8 +325,12 @@ class TestGetSessionStatistics:
|
||||
async def mock_session_local():
|
||||
yield mock_session
|
||||
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=mock_session_local()):
|
||||
with patch(
|
||||
"app.services.session_cleanup.SessionLocal",
|
||||
return_value=mock_session_local(),
|
||||
):
|
||||
from app.services.session_cleanup import get_session_statistics
|
||||
|
||||
stats = await get_session_statistics()
|
||||
|
||||
assert stats == {}
|
||||
@@ -294,9 +340,11 @@ class TestConcurrentCleanup:
|
||||
"""Tests for concurrent cleanup scenarios."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_cleanup_no_duplicate_deletes(self, async_test_db, async_test_user):
|
||||
async def test_concurrent_cleanup_no_duplicate_deletes(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test concurrent cleanups don't cause race conditions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create 10 expired sessions
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -308,20 +356,24 @@ class TestConcurrentCleanup:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=10),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=40),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) - timedelta(days=10),
|
||||
created_at=datetime.now(UTC) - timedelta(days=40),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(expired)
|
||||
await session.commit()
|
||||
|
||||
# Run two cleanups concurrently
|
||||
# Use side_effect to return fresh session instances for each call
|
||||
with patch('app.services.session_cleanup.SessionLocal', side_effect=lambda: AsyncTestingSessionLocal()):
|
||||
with patch(
|
||||
"app.services.session_cleanup.SessionLocal",
|
||||
side_effect=lambda: AsyncTestingSessionLocal(),
|
||||
):
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
|
||||
results = await asyncio.gather(
|
||||
cleanup_expired_sessions(keep_days=30),
|
||||
cleanup_expired_sessions(keep_days=30)
|
||||
cleanup_expired_sessions(keep_days=30),
|
||||
)
|
||||
|
||||
# Both should report deleting sessions (may overlap due to transaction timing)
|
||||
|
||||
Reference in New Issue
Block a user