Remove legacy test files for auth, rate limiting, and users
- Deleted outdated backend test cases (`test_auth.py`, `test_rate_limiting.py`, `test_users.py`) to clean up deprecated test structure. - These tests are now redundant with newer async test implementations and improved coverage.
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
[pytest]
|
[pytest]
|
||||||
testpaths = tests
|
testpaths = tests
|
||||||
python_files = test_*.py
|
python_files = test_*.py
|
||||||
addopts = --disable-warnings
|
addopts = --disable-warnings -n auto
|
||||||
markers =
|
markers =
|
||||||
sqlite: marks tests that should run on SQLite (mocked).
|
sqlite: marks tests that should run on SQLite (mocked).
|
||||||
postgres: marks tests that require a real PostgreSQL database.
|
postgres: marks tests that require a real PostgreSQL database.
|
||||||
|
|||||||
@@ -1,401 +0,0 @@
|
|||||||
# tests/api/routes/test_auth.py
|
|
||||||
import json
|
|
||||||
import uuid
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from unittest.mock import patch, MagicMock, Mock
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from fastapi import FastAPI, Depends
|
|
||||||
from fastapi.testclient import TestClient
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from app.api.routes.auth import router as auth_router
|
|
||||||
from app.api.routes.users import router as users_router
|
|
||||||
from app.core.auth import get_password_hash
|
|
||||||
from app.core.database import get_db
|
|
||||||
from app.models.user import User
|
|
||||||
from app.services.auth_service import AuthService, AuthenticationError
|
|
||||||
from app.core.auth import TokenExpiredError, TokenInvalidError
|
|
||||||
|
|
||||||
|
|
||||||
# Mock the get_db dependency
|
|
||||||
@pytest.fixture
|
|
||||||
def override_get_db(db_session):
|
|
||||||
"""Override get_db dependency for testing."""
|
|
||||||
return db_session
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def app(override_get_db):
|
|
||||||
"""Create a FastAPI test application with overridden dependencies."""
|
|
||||||
app = FastAPI()
|
|
||||||
app.include_router(auth_router, prefix="/auth", tags=["auth"])
|
|
||||||
app.include_router(users_router, prefix="/api/v1/users", tags=["users"])
|
|
||||||
|
|
||||||
# Override the get_db dependency
|
|
||||||
app.dependency_overrides[get_db] = lambda: override_get_db
|
|
||||||
|
|
||||||
return app
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def client(app):
|
|
||||||
"""Create a FastAPI test client."""
|
|
||||||
return TestClient(app)
|
|
||||||
|
|
||||||
|
|
||||||
class TestRegisterUser:
|
|
||||||
"""Tests for the register_user endpoint."""
|
|
||||||
|
|
||||||
def test_register_user_success(self, client, monkeypatch, db_session):
|
|
||||||
"""Test successful user registration."""
|
|
||||||
# Mock the service method with a valid complete User object
|
|
||||||
mock_user = User(
|
|
||||||
id=uuid.uuid4(),
|
|
||||||
email="newuser@example.com",
|
|
||||||
password_hash="hashed_password",
|
|
||||||
first_name="New",
|
|
||||||
last_name="User",
|
|
||||||
is_active=True,
|
|
||||||
is_superuser=False,
|
|
||||||
created_at=datetime.now(timezone.utc),
|
|
||||||
updated_at=datetime.now(timezone.utc)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use patch for mocking
|
|
||||||
with patch.object(AuthService, 'create_user', return_value=mock_user):
|
|
||||||
# Test request
|
|
||||||
response = client.post(
|
|
||||||
"/auth/register",
|
|
||||||
json={
|
|
||||||
"email": "newuser@example.com",
|
|
||||||
"password": "Password123",
|
|
||||||
"first_name": "New",
|
|
||||||
"last_name": "User"
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assertions
|
|
||||||
assert response.status_code == 201
|
|
||||||
data = response.json()
|
|
||||||
assert data["email"] == "newuser@example.com"
|
|
||||||
assert data["first_name"] == "New"
|
|
||||||
assert data["last_name"] == "User"
|
|
||||||
assert "password" not in data
|
|
||||||
|
|
||||||
def test_register_user_duplicate_email(self, client, db_session):
|
|
||||||
"""Test registration with duplicate email."""
|
|
||||||
# Use patch for mocking with a side effect
|
|
||||||
with patch.object(AuthService, 'create_user',
|
|
||||||
side_effect=AuthenticationError("User with this email already exists")):
|
|
||||||
# Test request
|
|
||||||
response = client.post(
|
|
||||||
"/auth/register",
|
|
||||||
json={
|
|
||||||
"email": "existing@example.com",
|
|
||||||
"password": "Password123",
|
|
||||||
"first_name": "Existing",
|
|
||||||
"last_name": "User"
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assertions
|
|
||||||
assert response.status_code == 409
|
|
||||||
assert "already exists" in response.json()["detail"]
|
|
||||||
|
|
||||||
|
|
||||||
class TestLogin:
|
|
||||||
"""Tests for the login endpoint."""
|
|
||||||
|
|
||||||
def test_login_success(self, client, mock_user, db_session):
|
|
||||||
"""Test successful login."""
|
|
||||||
# Ensure mock_user has required attributes
|
|
||||||
if not hasattr(mock_user, 'created_at') or mock_user.created_at is None:
|
|
||||||
mock_user.created_at = datetime.now(timezone.utc)
|
|
||||||
if not hasattr(mock_user, 'updated_at') or mock_user.updated_at is None:
|
|
||||||
mock_user.updated_at = datetime.now(timezone.utc)
|
|
||||||
|
|
||||||
# Create mock tokens
|
|
||||||
mock_tokens = MagicMock(
|
|
||||||
access_token="mock_access_token",
|
|
||||||
refresh_token="mock_refresh_token",
|
|
||||||
token_type="bearer"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use context managers for patching
|
|
||||||
with patch.object(AuthService, 'authenticate_user', return_value=mock_user), \
|
|
||||||
patch.object(AuthService, 'create_tokens', return_value=mock_tokens):
|
|
||||||
|
|
||||||
# Test request
|
|
||||||
response = client.post(
|
|
||||||
"/auth/login",
|
|
||||||
json={
|
|
||||||
"email": "user@example.com",
|
|
||||||
"password": "Password123"
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assertions
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert data["access_token"] == "mock_access_token"
|
|
||||||
assert data["refresh_token"] == "mock_refresh_token"
|
|
||||||
assert data["token_type"] == "bearer"
|
|
||||||
|
|
||||||
|
|
||||||
def test_login_invalid_credentials_debug(self, client, app):
|
|
||||||
"""Improved test for login with invalid credentials."""
|
|
||||||
# Print response for debugging
|
|
||||||
from app.services.auth_service import AuthService
|
|
||||||
|
|
||||||
# Create a complete mock for AuthService
|
|
||||||
class MockAuthService:
|
|
||||||
@staticmethod
|
|
||||||
def authenticate_user(db, email, password):
|
|
||||||
print(f"Mock called with: {email}, {password}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Replace the entire class with our mock
|
|
||||||
original_service = AuthService
|
|
||||||
try:
|
|
||||||
# Replace with our mock
|
|
||||||
import sys
|
|
||||||
sys.modules['app.services.auth_service'].AuthService = MockAuthService
|
|
||||||
|
|
||||||
# Make the request
|
|
||||||
response = client.post(
|
|
||||||
"/auth/login",
|
|
||||||
json={
|
|
||||||
"email": "user@example.com",
|
|
||||||
"password": "WrongPassword"
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Print response details for debugging
|
|
||||||
print(f"Response status: {response.status_code}")
|
|
||||||
print(f"Response body: {response.text}")
|
|
||||||
|
|
||||||
# Assertions
|
|
||||||
assert response.status_code == 401
|
|
||||||
assert "Invalid email or password" in response.json()["detail"]
|
|
||||||
finally:
|
|
||||||
# Restore original service
|
|
||||||
sys.modules['app.services.auth_service'].AuthService = original_service
|
|
||||||
|
|
||||||
|
|
||||||
def test_login_inactive_user(self, client, db_session):
|
|
||||||
"""Test login with inactive user."""
|
|
||||||
# Mock authentication to raise an error
|
|
||||||
with patch.object(AuthService, 'authenticate_user',
|
|
||||||
side_effect=AuthenticationError("User account is inactive")):
|
|
||||||
# Test request
|
|
||||||
response = client.post(
|
|
||||||
"/auth/login",
|
|
||||||
json={
|
|
||||||
"email": "inactive@example.com",
|
|
||||||
"password": "Password123"
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assertions
|
|
||||||
assert response.status_code == 401
|
|
||||||
assert "inactive" in response.json()["detail"]
|
|
||||||
|
|
||||||
|
|
||||||
class TestRefreshToken:
|
|
||||||
"""Tests for the refresh_token endpoint."""
|
|
||||||
|
|
||||||
def test_refresh_token_success(self, client, db_session):
|
|
||||||
"""Test successful token refresh."""
|
|
||||||
from app.models.user import User
|
|
||||||
from app.core.auth import get_password_hash
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
# Create a test user
|
|
||||||
test_user = User(
|
|
||||||
id=uuid.uuid4(),
|
|
||||||
email="refreshtest@example.com",
|
|
||||||
password_hash=get_password_hash("TestPassword123"),
|
|
||||||
first_name="Refresh",
|
|
||||||
last_name="Test",
|
|
||||||
is_active=True
|
|
||||||
)
|
|
||||||
db_session.add(test_user)
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
# Login to get real tokens with a session
|
|
||||||
login_response = client.post(
|
|
||||||
"/auth/login",
|
|
||||||
json={
|
|
||||||
"email": "refreshtest@example.com",
|
|
||||||
"password": "TestPassword123"
|
|
||||||
}
|
|
||||||
)
|
|
||||||
assert login_response.status_code == 200
|
|
||||||
tokens = login_response.json()
|
|
||||||
|
|
||||||
# Test refresh with real token
|
|
||||||
response = client.post(
|
|
||||||
"/auth/refresh",
|
|
||||||
json={
|
|
||||||
"refresh_token": tokens["refresh_token"]
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assertions
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert "access_token" in data
|
|
||||||
assert "refresh_token" in data
|
|
||||||
assert data["token_type"] == "bearer"
|
|
||||||
|
|
||||||
def test_refresh_token_expired(self, client, db_session):
|
|
||||||
"""Test refresh with expired token."""
|
|
||||||
from app.api.routes import auth as auth_routes
|
|
||||||
|
|
||||||
# Mock decode_token to raise expired token error
|
|
||||||
with patch.object(auth_routes, 'decode_token',
|
|
||||||
side_effect=TokenExpiredError("Token expired")):
|
|
||||||
# Test request
|
|
||||||
response = client.post(
|
|
||||||
"/auth/refresh",
|
|
||||||
json={
|
|
||||||
"refresh_token": "expired_refresh_token"
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assertions
|
|
||||||
assert response.status_code == 401
|
|
||||||
# Check if it's in the new error format or old detail format
|
|
||||||
response_data = response.json()
|
|
||||||
if "errors" in response_data:
|
|
||||||
assert "expired" in response_data["errors"][0]["message"].lower()
|
|
||||||
else:
|
|
||||||
assert "detail" in response_data
|
|
||||||
assert "expired" in response_data["detail"].lower()
|
|
||||||
|
|
||||||
def test_refresh_token_invalid(self, client, db_session):
|
|
||||||
"""Test refresh with invalid token."""
|
|
||||||
# Mock refresh to raise invalid token error
|
|
||||||
with patch.object(AuthService, 'refresh_tokens',
|
|
||||||
side_effect=TokenInvalidError("Invalid token")):
|
|
||||||
# Test request
|
|
||||||
response = client.post(
|
|
||||||
"/auth/refresh",
|
|
||||||
json={
|
|
||||||
"refresh_token": "invalid_refresh_token"
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assertions
|
|
||||||
assert response.status_code == 401
|
|
||||||
assert "Invalid" in response.json()["detail"]
|
|
||||||
|
|
||||||
|
|
||||||
class TestChangePassword:
|
|
||||||
"""Tests for the change_password endpoint."""
|
|
||||||
|
|
||||||
def test_change_password_success(self, client, mock_user, db_session, app):
|
|
||||||
"""Test successful password change."""
|
|
||||||
# Ensure mock_user has required attributes
|
|
||||||
if not hasattr(mock_user, 'created_at') or mock_user.created_at is None:
|
|
||||||
mock_user.created_at = datetime.now(timezone.utc)
|
|
||||||
if not hasattr(mock_user, 'updated_at') or mock_user.updated_at is None:
|
|
||||||
mock_user.updated_at = datetime.now(timezone.utc)
|
|
||||||
|
|
||||||
# Override get_current_user dependency
|
|
||||||
from app.api.dependencies.auth import get_current_user
|
|
||||||
app.dependency_overrides[get_current_user] = lambda: mock_user
|
|
||||||
|
|
||||||
# Mock password change to return success
|
|
||||||
with patch.object(AuthService, 'change_password', return_value=True):
|
|
||||||
# Test request (new endpoint)
|
|
||||||
response = client.patch(
|
|
||||||
"/api/v1/users/me/password",
|
|
||||||
json={
|
|
||||||
"current_password": "OldPassword123",
|
|
||||||
"new_password": "NewPassword123"
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assertions
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["success"] is True
|
|
||||||
assert "message" in response.json()
|
|
||||||
|
|
||||||
# Clean up override
|
|
||||||
if get_current_user in app.dependency_overrides:
|
|
||||||
del app.dependency_overrides[get_current_user]
|
|
||||||
|
|
||||||
def test_change_password_incorrect_current_password(self, client, mock_user, db_session, app):
|
|
||||||
"""Test change password with incorrect current password."""
|
|
||||||
# Ensure mock_user has required attributes
|
|
||||||
if not hasattr(mock_user, 'created_at') or mock_user.created_at is None:
|
|
||||||
mock_user.created_at = datetime.now(timezone.utc)
|
|
||||||
if not hasattr(mock_user, 'updated_at') or mock_user.updated_at is None:
|
|
||||||
mock_user.updated_at = datetime.now(timezone.utc)
|
|
||||||
|
|
||||||
# Override get_current_user dependency
|
|
||||||
from app.api.dependencies.auth import get_current_user
|
|
||||||
app.dependency_overrides[get_current_user] = lambda: mock_user
|
|
||||||
|
|
||||||
# Mock password change to raise error
|
|
||||||
with patch.object(AuthService, 'change_password',
|
|
||||||
side_effect=AuthenticationError("Current password is incorrect")):
|
|
||||||
# Test request (new endpoint)
|
|
||||||
response = client.patch(
|
|
||||||
"/api/v1/users/me/password",
|
|
||||||
json={
|
|
||||||
"current_password": "WrongPassword",
|
|
||||||
"new_password": "NewPassword123"
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assertions - Now returns standardized error response
|
|
||||||
assert response.status_code == 403
|
|
||||||
# The response has standardized error format
|
|
||||||
data = response.json()
|
|
||||||
assert "detail" in data or "errors" in data
|
|
||||||
|
|
||||||
# Clean up override
|
|
||||||
if get_current_user in app.dependency_overrides:
|
|
||||||
del app.dependency_overrides[get_current_user]
|
|
||||||
|
|
||||||
|
|
||||||
class TestGetCurrentUserInfo:
|
|
||||||
"""Tests for the get_current_user_info endpoint."""
|
|
||||||
|
|
||||||
def test_get_current_user_info(self, client, mock_user, app):
|
|
||||||
"""Test getting current user info."""
|
|
||||||
# Ensure mock_user has required attributes
|
|
||||||
if not hasattr(mock_user, 'created_at') or mock_user.created_at is None:
|
|
||||||
mock_user.created_at = datetime.now(timezone.utc)
|
|
||||||
if not hasattr(mock_user, 'updated_at') or mock_user.updated_at is None:
|
|
||||||
mock_user.updated_at = datetime.now(timezone.utc)
|
|
||||||
|
|
||||||
# Override get_current_user dependency
|
|
||||||
from app.api.dependencies.auth import get_current_user
|
|
||||||
app.dependency_overrides[get_current_user] = lambda: mock_user
|
|
||||||
|
|
||||||
# Test request
|
|
||||||
response = client.get("/auth/me")
|
|
||||||
|
|
||||||
# Assertions
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert data["email"] == mock_user.email
|
|
||||||
assert data["first_name"] == mock_user.first_name
|
|
||||||
assert data["last_name"] == mock_user.last_name
|
|
||||||
assert "password" not in data
|
|
||||||
|
|
||||||
# Clean up override
|
|
||||||
if get_current_user in app.dependency_overrides:
|
|
||||||
del app.dependency_overrides[get_current_user]
|
|
||||||
|
|
||||||
def test_get_current_user_info_unauthorized(self, client):
|
|
||||||
"""Test getting user info without authentication."""
|
|
||||||
# Test request without authentication
|
|
||||||
response = client.get("/auth/me")
|
|
||||||
|
|
||||||
# Assertions
|
|
||||||
assert response.status_code == 401
|
|
||||||
@@ -1,203 +0,0 @@
|
|||||||
# tests/api/routes/test_rate_limiting.py
|
|
||||||
import os
|
|
||||||
import pytest
|
|
||||||
from fastapi import FastAPI, status
|
|
||||||
from fastapi.testclient import TestClient
|
|
||||||
from unittest.mock import patch, MagicMock
|
|
||||||
|
|
||||||
from app.api.routes.auth import router as auth_router, limiter
|
|
||||||
from app.api.routes.users import router as users_router
|
|
||||||
from app.core.database import get_db
|
|
||||||
|
|
||||||
# Skip all rate limiting tests when IS_TEST=True (rate limits are disabled in test mode)
|
|
||||||
pytestmark = pytest.mark.skipif(
|
|
||||||
os.getenv("IS_TEST", "False") == "True",
|
|
||||||
reason="Rate limits are disabled in test mode (RATE_MULTIPLIER=100)"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Mock the get_db dependency
|
|
||||||
@pytest.fixture
|
|
||||||
def override_get_db():
|
|
||||||
"""Override get_db dependency for testing."""
|
|
||||||
mock_db = MagicMock()
|
|
||||||
return mock_db
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def app(override_get_db):
|
|
||||||
"""Create a FastAPI test application with rate limiting."""
|
|
||||||
from slowapi import _rate_limit_exceeded_handler
|
|
||||||
from slowapi.errors import RateLimitExceeded
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
app.state.limiter = limiter
|
|
||||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
|
||||||
app.include_router(auth_router, prefix="/auth", tags=["auth"])
|
|
||||||
app.include_router(users_router, prefix="/api/v1/users", tags=["users"])
|
|
||||||
|
|
||||||
# Override the get_db dependency
|
|
||||||
app.dependency_overrides[get_db] = lambda: override_get_db
|
|
||||||
|
|
||||||
return app
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def client(app):
|
|
||||||
"""Create a FastAPI test client."""
|
|
||||||
return TestClient(app)
|
|
||||||
|
|
||||||
|
|
||||||
class TestRegisterRateLimiting:
|
|
||||||
"""Tests for rate limiting on /register endpoint"""
|
|
||||||
|
|
||||||
def test_register_rate_limit_blocks_over_limit(self, client):
|
|
||||||
"""Test that requests over rate limit are blocked"""
|
|
||||||
from app.services.auth_service import AuthService
|
|
||||||
from app.models.user import User
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
mock_user = User(
|
|
||||||
id=uuid.uuid4(),
|
|
||||||
email="test@example.com",
|
|
||||||
password_hash="hashed",
|
|
||||||
first_name="Test",
|
|
||||||
last_name="User",
|
|
||||||
is_active=True,
|
|
||||||
is_superuser=False,
|
|
||||||
created_at=datetime.now(timezone.utc),
|
|
||||||
updated_at=datetime.now(timezone.utc)
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch.object(AuthService, 'create_user', return_value=mock_user):
|
|
||||||
user_data = {
|
|
||||||
"email": f"test{uuid.uuid4()}@example.com",
|
|
||||||
"password": "TestPassword123!",
|
|
||||||
"first_name": "Test",
|
|
||||||
"last_name": "User"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Make 6 requests (limit is 5/minute)
|
|
||||||
responses = []
|
|
||||||
for i in range(6):
|
|
||||||
response = client.post("/auth/register", json=user_data)
|
|
||||||
responses.append(response)
|
|
||||||
|
|
||||||
# Last request should be rate limited
|
|
||||||
assert responses[-1].status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
|
||||||
|
|
||||||
|
|
||||||
class TestLoginRateLimiting:
|
|
||||||
"""Tests for rate limiting on /login endpoint"""
|
|
||||||
|
|
||||||
def test_login_rate_limit_blocks_over_limit(self, client):
|
|
||||||
"""Test that login requests over rate limit are blocked"""
|
|
||||||
from app.services.auth_service import AuthService
|
|
||||||
|
|
||||||
with patch.object(AuthService, 'authenticate_user', return_value=None):
|
|
||||||
login_data = {
|
|
||||||
"email": "test@example.com",
|
|
||||||
"password": "wrong_password"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Make 11 requests (limit is 10/minute)
|
|
||||||
responses = []
|
|
||||||
for i in range(11):
|
|
||||||
response = client.post("/auth/login", json=login_data)
|
|
||||||
responses.append(response)
|
|
||||||
|
|
||||||
# Last request should be rate limited
|
|
||||||
assert responses[-1].status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
|
||||||
|
|
||||||
|
|
||||||
class TestRefreshTokenRateLimiting:
|
|
||||||
"""Tests for rate limiting on /refresh endpoint"""
|
|
||||||
|
|
||||||
def test_refresh_rate_limit_blocks_over_limit(self, client):
|
|
||||||
"""Test that refresh requests over rate limit are blocked"""
|
|
||||||
from app.services.auth_service import AuthService
|
|
||||||
from app.core.auth import TokenInvalidError
|
|
||||||
|
|
||||||
with patch.object(AuthService, 'refresh_tokens', side_effect=TokenInvalidError("Invalid")):
|
|
||||||
refresh_data = {
|
|
||||||
"refresh_token": "invalid_token"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Make 31 requests (limit is 30/minute)
|
|
||||||
responses = []
|
|
||||||
for i in range(31):
|
|
||||||
response = client.post("/auth/refresh", json=refresh_data)
|
|
||||||
responses.append(response)
|
|
||||||
|
|
||||||
# Last request should be rate limited
|
|
||||||
assert responses[-1].status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
|
||||||
|
|
||||||
|
|
||||||
class TestChangePasswordRateLimiting:
|
|
||||||
"""Tests for rate limiting on /change-password endpoint"""
|
|
||||||
|
|
||||||
def test_change_password_rate_limit_blocks_over_limit(self, client):
|
|
||||||
"""Test that change password requests over rate limit are blocked"""
|
|
||||||
from app.api.dependencies.auth import get_current_user
|
|
||||||
from app.models.user import User
|
|
||||||
from app.services.auth_service import AuthService, AuthenticationError
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
# Mock current user
|
|
||||||
mock_user = User(
|
|
||||||
id=uuid.uuid4(),
|
|
||||||
email="test@example.com",
|
|
||||||
password_hash="hashed",
|
|
||||||
first_name="Test",
|
|
||||||
last_name="User",
|
|
||||||
is_active=True,
|
|
||||||
is_superuser=False,
|
|
||||||
created_at=datetime.now(timezone.utc),
|
|
||||||
updated_at=datetime.now(timezone.utc)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Override get_current_user dependency in the app
|
|
||||||
test_app = client.app
|
|
||||||
test_app.dependency_overrides[get_current_user] = lambda: mock_user
|
|
||||||
|
|
||||||
with patch.object(AuthService, 'change_password', side_effect=AuthenticationError("Invalid password")):
|
|
||||||
password_data = {
|
|
||||||
"current_password": "wrong_password",
|
|
||||||
"new_password": "NewPassword123!"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Make 6 requests (limit is 5/minute) - using new endpoint
|
|
||||||
responses = []
|
|
||||||
for i in range(6):
|
|
||||||
response = client.patch("/api/v1/users/me/password", json=password_data)
|
|
||||||
responses.append(response)
|
|
||||||
|
|
||||||
# Last request should be rate limited
|
|
||||||
assert responses[-1].status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
|
||||||
|
|
||||||
# Clean up override
|
|
||||||
test_app.dependency_overrides.clear()
|
|
||||||
|
|
||||||
|
|
||||||
class TestRateLimitErrorResponse:
|
|
||||||
"""Tests for rate limit error response format"""
|
|
||||||
|
|
||||||
def test_rate_limit_error_response_format(self, client):
|
|
||||||
"""Test that rate limit error has correct format"""
|
|
||||||
from app.services.auth_service import AuthService
|
|
||||||
|
|
||||||
with patch.object(AuthService, 'authenticate_user', return_value=None):
|
|
||||||
login_data = {
|
|
||||||
"email": "test@example.com",
|
|
||||||
"password": "password"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Exceed rate limit
|
|
||||||
for i in range(11):
|
|
||||||
response = client.post("/auth/login", json=login_data)
|
|
||||||
|
|
||||||
# Check error response
|
|
||||||
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
|
||||||
assert "detail" in response.json() or "error" in response.json()
|
|
||||||
@@ -1,487 +0,0 @@
|
|||||||
# tests/api/routes/test_users.py
|
|
||||||
"""
|
|
||||||
Tests for user management endpoints.
|
|
||||||
"""
|
|
||||||
import uuid
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from unittest.mock import patch, MagicMock
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from fastapi import FastAPI
|
|
||||||
from fastapi.testclient import TestClient
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from app.api.routes.users import router as users_router
|
|
||||||
from app.core.database import get_db
|
|
||||||
from app.models.user import User
|
|
||||||
from app.api.dependencies.auth import get_current_user, get_current_superuser
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def override_get_db(db_session):
|
|
||||||
"""Override get_db dependency for testing."""
|
|
||||||
return db_session
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def app(override_get_db):
|
|
||||||
"""Create a FastAPI test application."""
|
|
||||||
app = FastAPI()
|
|
||||||
app.include_router(users_router, prefix="/api/v1/users", tags=["users"])
|
|
||||||
|
|
||||||
# Override the get_db dependency
|
|
||||||
app.dependency_overrides[get_db] = lambda: override_get_db
|
|
||||||
|
|
||||||
return app
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def client(app):
|
|
||||||
"""Create a FastAPI test client."""
|
|
||||||
return TestClient(app)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def regular_user():
|
|
||||||
"""Create a mock regular user."""
|
|
||||||
return User(
|
|
||||||
id=uuid.uuid4(),
|
|
||||||
email="regular@example.com",
|
|
||||||
password_hash="hashed_password",
|
|
||||||
first_name="Regular",
|
|
||||||
last_name="User",
|
|
||||||
is_active=True,
|
|
||||||
is_superuser=False,
|
|
||||||
created_at=datetime.now(timezone.utc),
|
|
||||||
updated_at=datetime.now(timezone.utc)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def super_user():
|
|
||||||
"""Create a mock superuser."""
|
|
||||||
return User(
|
|
||||||
id=uuid.uuid4(),
|
|
||||||
email="admin@example.com",
|
|
||||||
password_hash="hashed_password",
|
|
||||||
first_name="Admin",
|
|
||||||
last_name="User",
|
|
||||||
is_active=True,
|
|
||||||
is_superuser=True,
|
|
||||||
created_at=datetime.now(timezone.utc),
|
|
||||||
updated_at=datetime.now(timezone.utc)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestListUsers:
|
|
||||||
"""Tests for the list_users endpoint."""
|
|
||||||
|
|
||||||
def test_list_users_as_superuser(self, client, app, super_user, regular_user, db_session):
|
|
||||||
"""Test that superusers can list all users."""
|
|
||||||
from app.crud.user import user as user_crud
|
|
||||||
|
|
||||||
# Override auth dependency
|
|
||||||
app.dependency_overrides[get_current_superuser] = lambda: super_user
|
|
||||||
|
|
||||||
# Mock user_crud to return test data
|
|
||||||
mock_users = [regular_user for _ in range(3)]
|
|
||||||
with patch.object(user_crud, 'get_multi_with_total', return_value=(mock_users, 3)):
|
|
||||||
response = client.get("/api/v1/users?page=1&limit=20")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert "data" in data
|
|
||||||
assert "pagination" in data
|
|
||||||
assert len(data["data"]) == 3
|
|
||||||
assert data["pagination"]["total"] == 3
|
|
||||||
|
|
||||||
# Clean up
|
|
||||||
if get_current_superuser in app.dependency_overrides:
|
|
||||||
del app.dependency_overrides[get_current_superuser]
|
|
||||||
|
|
||||||
def test_list_users_pagination(self, client, app, super_user, regular_user, db_session):
|
|
||||||
"""Test pagination parameters for list users."""
|
|
||||||
from app.crud.user import user as user_crud
|
|
||||||
|
|
||||||
app.dependency_overrides[get_current_superuser] = lambda: super_user
|
|
||||||
|
|
||||||
# Mock user_crud
|
|
||||||
mock_users = [regular_user for _ in range(10)]
|
|
||||||
with patch.object(user_crud, 'get_multi_with_total', return_value=(mock_users[:5], 10)):
|
|
||||||
response = client.get("/api/v1/users?page=1&limit=5")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert data["pagination"]["page"] == 1
|
|
||||||
assert data["pagination"]["page_size"] == 5
|
|
||||||
assert data["pagination"]["total"] == 10
|
|
||||||
assert data["pagination"]["total_pages"] == 2
|
|
||||||
|
|
||||||
# Clean up
|
|
||||||
if get_current_superuser in app.dependency_overrides:
|
|
||||||
del app.dependency_overrides[get_current_superuser]
|
|
||||||
|
|
||||||
|
|
||||||
class TestGetCurrentUserProfile:
|
|
||||||
"""Tests for the get_current_user_profile endpoint."""
|
|
||||||
|
|
||||||
def test_get_current_user_profile(self, client, app, regular_user):
|
|
||||||
"""Test getting current user's profile."""
|
|
||||||
app.dependency_overrides[get_current_user] = lambda: regular_user
|
|
||||||
|
|
||||||
response = client.get("/api/v1/users/me")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert data["email"] == regular_user.email
|
|
||||||
assert data["first_name"] == regular_user.first_name
|
|
||||||
assert data["last_name"] == regular_user.last_name
|
|
||||||
assert "password" not in data
|
|
||||||
|
|
||||||
# Clean up
|
|
||||||
if get_current_user in app.dependency_overrides:
|
|
||||||
del app.dependency_overrides[get_current_user]
|
|
||||||
|
|
||||||
|
|
||||||
class TestUpdateCurrentUser:
|
|
||||||
"""Tests for the update_current_user endpoint."""
|
|
||||||
|
|
||||||
def test_update_current_user_success(self, client, app, regular_user, db_session):
|
|
||||||
"""Test successful profile update."""
|
|
||||||
from app.crud.user import user as user_crud
|
|
||||||
|
|
||||||
app.dependency_overrides[get_current_user] = lambda: regular_user
|
|
||||||
|
|
||||||
updated_user = User(
|
|
||||||
id=regular_user.id,
|
|
||||||
email=regular_user.email,
|
|
||||||
password_hash=regular_user.password_hash,
|
|
||||||
first_name="Updated",
|
|
||||||
last_name="Name",
|
|
||||||
is_active=True,
|
|
||||||
is_superuser=False,
|
|
||||||
created_at=regular_user.created_at,
|
|
||||||
updated_at=datetime.now(timezone.utc)
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch.object(user_crud, 'update', return_value=updated_user):
|
|
||||||
response = client.patch(
|
|
||||||
"/api/v1/users/me",
|
|
||||||
json={"first_name": "Updated", "last_name": "Name"}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert data["first_name"] == "Updated"
|
|
||||||
assert data["last_name"] == "Name"
|
|
||||||
|
|
||||||
# Clean up
|
|
||||||
if get_current_user in app.dependency_overrides:
|
|
||||||
del app.dependency_overrides[get_current_user]
|
|
||||||
|
|
||||||
def test_update_current_user_extra_fields_ignored(self, client, app, regular_user, db_session):
|
|
||||||
"""Test that extra fields like is_superuser are ignored by schema validation."""
|
|
||||||
from app.crud.user import user as user_crud
|
|
||||||
|
|
||||||
app.dependency_overrides[get_current_user] = lambda: regular_user
|
|
||||||
|
|
||||||
# Create updated user without is_superuser changed
|
|
||||||
updated_user = User(
|
|
||||||
id=regular_user.id,
|
|
||||||
email=regular_user.email,
|
|
||||||
password_hash=regular_user.password_hash,
|
|
||||||
first_name="Updated",
|
|
||||||
last_name=regular_user.last_name,
|
|
||||||
is_active=True,
|
|
||||||
is_superuser=False, # Should remain False
|
|
||||||
created_at=regular_user.created_at,
|
|
||||||
updated_at=datetime.now(timezone.utc)
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch.object(user_crud, 'update', return_value=updated_user):
|
|
||||||
response = client.patch(
|
|
||||||
"/api/v1/users/me",
|
|
||||||
json={"first_name": "Updated", "is_superuser": True} # is_superuser will be ignored
|
|
||||||
)
|
|
||||||
|
|
||||||
# Request should succeed but is_superuser should be unchanged
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert data["is_superuser"] is False
|
|
||||||
|
|
||||||
# Clean up
|
|
||||||
if get_current_user in app.dependency_overrides:
|
|
||||||
del app.dependency_overrides[get_current_user]
|
|
||||||
|
|
||||||
|
|
||||||
class TestGetUserById:
|
|
||||||
"""Tests for the get_user_by_id endpoint."""
|
|
||||||
|
|
||||||
def test_get_own_profile(self, client, app, regular_user, db_session):
|
|
||||||
"""Test that users can get their own profile."""
|
|
||||||
from app.crud.user import user as user_crud
|
|
||||||
|
|
||||||
app.dependency_overrides[get_current_user] = lambda: regular_user
|
|
||||||
|
|
||||||
with patch.object(user_crud, 'get', return_value=regular_user):
|
|
||||||
response = client.get(f"/api/v1/users/{regular_user.id}")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert data["email"] == regular_user.email
|
|
||||||
|
|
||||||
# Clean up
|
|
||||||
if get_current_user in app.dependency_overrides:
|
|
||||||
del app.dependency_overrides[get_current_user]
|
|
||||||
|
|
||||||
def test_get_other_user_as_regular_user(self, client, app, regular_user):
|
|
||||||
"""Test that regular users cannot view other users."""
|
|
||||||
app.dependency_overrides[get_current_user] = lambda: regular_user
|
|
||||||
|
|
||||||
other_user_id = uuid.uuid4()
|
|
||||||
response = client.get(f"/api/v1/users/{other_user_id}")
|
|
||||||
|
|
||||||
assert response.status_code == 403
|
|
||||||
|
|
||||||
# Clean up
|
|
||||||
if get_current_user in app.dependency_overrides:
|
|
||||||
del app.dependency_overrides[get_current_user]
|
|
||||||
|
|
||||||
def test_get_other_user_as_superuser(self, client, app, super_user, db_session):
|
|
||||||
"""Test that superusers can view any user."""
|
|
||||||
from app.crud.user import user as user_crud
|
|
||||||
|
|
||||||
app.dependency_overrides[get_current_user] = lambda: super_user
|
|
||||||
|
|
||||||
other_user = User(
|
|
||||||
id=uuid.uuid4(),
|
|
||||||
email="other@example.com",
|
|
||||||
password_hash="hashed",
|
|
||||||
first_name="Other",
|
|
||||||
last_name="User",
|
|
||||||
is_active=True,
|
|
||||||
is_superuser=False,
|
|
||||||
created_at=datetime.now(timezone.utc),
|
|
||||||
updated_at=datetime.now(timezone.utc)
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch.object(user_crud, 'get', return_value=other_user):
|
|
||||||
response = client.get(f"/api/v1/users/{other_user.id}")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert data["email"] == other_user.email
|
|
||||||
|
|
||||||
# Clean up
|
|
||||||
if get_current_user in app.dependency_overrides:
|
|
||||||
del app.dependency_overrides[get_current_user]
|
|
||||||
|
|
||||||
def test_get_nonexistent_user(self, client, app, super_user, db_session):
|
|
||||||
"""Test getting a user that doesn't exist."""
|
|
||||||
from app.crud.user import user as user_crud
|
|
||||||
|
|
||||||
app.dependency_overrides[get_current_user] = lambda: super_user
|
|
||||||
|
|
||||||
with patch.object(user_crud, 'get', return_value=None):
|
|
||||||
response = client.get(f"/api/v1/users/{uuid.uuid4()}")
|
|
||||||
|
|
||||||
assert response.status_code == 404
|
|
||||||
|
|
||||||
# Clean up
|
|
||||||
if get_current_user in app.dependency_overrides:
|
|
||||||
del app.dependency_overrides[get_current_user]
|
|
||||||
|
|
||||||
|
|
||||||
class TestUpdateUser:
|
|
||||||
"""Tests for the update_user endpoint."""
|
|
||||||
|
|
||||||
def test_update_own_profile(self, client, app, regular_user, db_session):
|
|
||||||
"""Test that users can update their own profile."""
|
|
||||||
from app.crud.user import user as user_crud
|
|
||||||
|
|
||||||
app.dependency_overrides[get_current_user] = lambda: regular_user
|
|
||||||
|
|
||||||
updated_user = User(
|
|
||||||
id=regular_user.id,
|
|
||||||
email=regular_user.email,
|
|
||||||
password_hash=regular_user.password_hash,
|
|
||||||
first_name="NewName",
|
|
||||||
last_name=regular_user.last_name,
|
|
||||||
is_active=True,
|
|
||||||
is_superuser=False,
|
|
||||||
created_at=regular_user.created_at,
|
|
||||||
updated_at=datetime.now(timezone.utc)
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch.object(user_crud, 'get', return_value=regular_user), \
|
|
||||||
patch.object(user_crud, 'update', return_value=updated_user):
|
|
||||||
response = client.patch(
|
|
||||||
f"/api/v1/users/{regular_user.id}",
|
|
||||||
json={"first_name": "NewName"}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert data["first_name"] == "NewName"
|
|
||||||
|
|
||||||
# Clean up
|
|
||||||
if get_current_user in app.dependency_overrides:
|
|
||||||
del app.dependency_overrides[get_current_user]
|
|
||||||
|
|
||||||
def test_update_other_user_as_regular_user(self, client, app, regular_user):
|
|
||||||
"""Test that regular users cannot update other users."""
|
|
||||||
app.dependency_overrides[get_current_user] = lambda: regular_user
|
|
||||||
|
|
||||||
other_user_id = uuid.uuid4()
|
|
||||||
response = client.patch(
|
|
||||||
f"/api/v1/users/{other_user_id}",
|
|
||||||
json={"first_name": "NewName"}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 403
|
|
||||||
|
|
||||||
# Clean up
|
|
||||||
if get_current_user in app.dependency_overrides:
|
|
||||||
del app.dependency_overrides[get_current_user]
|
|
||||||
|
|
||||||
def test_user_schema_ignores_extra_fields(self, client, app, regular_user, db_session):
|
|
||||||
"""Test that UserUpdate schema ignores extra fields like is_superuser."""
|
|
||||||
from app.crud.user import user as user_crud
|
|
||||||
|
|
||||||
app.dependency_overrides[get_current_user] = lambda: regular_user
|
|
||||||
|
|
||||||
# Updated user with is_superuser unchanged
|
|
||||||
updated_user = User(
|
|
||||||
id=regular_user.id,
|
|
||||||
email=regular_user.email,
|
|
||||||
password_hash=regular_user.password_hash,
|
|
||||||
first_name="Changed",
|
|
||||||
last_name=regular_user.last_name,
|
|
||||||
is_active=True,
|
|
||||||
is_superuser=False, # Should remain False
|
|
||||||
created_at=regular_user.created_at,
|
|
||||||
updated_at=datetime.now(timezone.utc)
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch.object(user_crud, 'get', return_value=regular_user), \
|
|
||||||
patch.object(user_crud, 'update', return_value=updated_user):
|
|
||||||
response = client.patch(
|
|
||||||
f"/api/v1/users/{regular_user.id}",
|
|
||||||
json={"first_name": "Changed", "is_superuser": True} # is_superuser ignored
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should succeed, extra field is ignored
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert data["is_superuser"] is False
|
|
||||||
|
|
||||||
# Clean up
|
|
||||||
if get_current_user in app.dependency_overrides:
|
|
||||||
del app.dependency_overrides[get_current_user]
|
|
||||||
|
|
||||||
def test_superuser_can_update_any_user(self, client, app, super_user, db_session):
|
|
||||||
"""Test that superusers can update any user."""
|
|
||||||
from app.crud.user import user as user_crud
|
|
||||||
|
|
||||||
app.dependency_overrides[get_current_user] = lambda: super_user
|
|
||||||
|
|
||||||
target_user = User(
|
|
||||||
id=uuid.uuid4(),
|
|
||||||
email="target@example.com",
|
|
||||||
password_hash="hashed",
|
|
||||||
first_name="Target",
|
|
||||||
last_name="User",
|
|
||||||
is_active=True,
|
|
||||||
is_superuser=False,
|
|
||||||
created_at=datetime.now(timezone.utc),
|
|
||||||
updated_at=datetime.now(timezone.utc)
|
|
||||||
)
|
|
||||||
|
|
||||||
updated_user = User(
|
|
||||||
id=target_user.id,
|
|
||||||
email=target_user.email,
|
|
||||||
password_hash=target_user.password_hash,
|
|
||||||
first_name="Updated",
|
|
||||||
last_name=target_user.last_name,
|
|
||||||
is_active=True,
|
|
||||||
is_superuser=False,
|
|
||||||
created_at=target_user.created_at,
|
|
||||||
updated_at=datetime.now(timezone.utc)
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch.object(user_crud, 'get', return_value=target_user), \
|
|
||||||
patch.object(user_crud, 'update', return_value=updated_user):
|
|
||||||
response = client.patch(
|
|
||||||
f"/api/v1/users/{target_user.id}",
|
|
||||||
json={"first_name": "Updated"}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert data["first_name"] == "Updated"
|
|
||||||
|
|
||||||
# Clean up
|
|
||||||
if get_current_user in app.dependency_overrides:
|
|
||||||
del app.dependency_overrides[get_current_user]
|
|
||||||
|
|
||||||
|
|
||||||
class TestDeleteUser:
|
|
||||||
"""Tests for the delete_user endpoint."""
|
|
||||||
|
|
||||||
def test_delete_user_as_superuser(self, client, app, super_user, db_session):
|
|
||||||
"""Test that superusers can delete users."""
|
|
||||||
from app.crud.user import user as user_crud
|
|
||||||
|
|
||||||
app.dependency_overrides[get_current_superuser] = lambda: super_user
|
|
||||||
|
|
||||||
target_user = User(
|
|
||||||
id=uuid.uuid4(),
|
|
||||||
email="target@example.com",
|
|
||||||
password_hash="hashed",
|
|
||||||
first_name="Target",
|
|
||||||
last_name="User",
|
|
||||||
is_active=True,
|
|
||||||
is_superuser=False,
|
|
||||||
created_at=datetime.now(timezone.utc),
|
|
||||||
updated_at=datetime.now(timezone.utc)
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch.object(user_crud, 'get', return_value=target_user), \
|
|
||||||
patch.object(user_crud, 'remove', return_value=target_user):
|
|
||||||
response = client.delete(f"/api/v1/users/{target_user.id}")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert "deleted successfully" in data["message"]
|
|
||||||
|
|
||||||
# Clean up
|
|
||||||
if get_current_superuser in app.dependency_overrides:
|
|
||||||
del app.dependency_overrides[get_current_superuser]
|
|
||||||
|
|
||||||
def test_delete_nonexistent_user(self, client, app, super_user, db_session):
|
|
||||||
"""Test deleting a user that doesn't exist."""
|
|
||||||
from app.crud.user import user as user_crud
|
|
||||||
|
|
||||||
app.dependency_overrides[get_current_superuser] = lambda: super_user
|
|
||||||
|
|
||||||
with patch.object(user_crud, 'get', return_value=None):
|
|
||||||
response = client.delete(f"/api/v1/users/{uuid.uuid4()}")
|
|
||||||
|
|
||||||
assert response.status_code == 404
|
|
||||||
|
|
||||||
# Clean up
|
|
||||||
if get_current_superuser in app.dependency_overrides:
|
|
||||||
del app.dependency_overrides[get_current_superuser]
|
|
||||||
|
|
||||||
def test_cannot_delete_self(self, client, app, super_user, db_session):
|
|
||||||
"""Test that users cannot delete their own account."""
|
|
||||||
app.dependency_overrides[get_current_superuser] = lambda: super_user
|
|
||||||
|
|
||||||
response = client.delete(f"/api/v1/users/{super_user.id}")
|
|
||||||
|
|
||||||
assert response.status_code == 403
|
|
||||||
|
|
||||||
# Clean up
|
|
||||||
if get_current_superuser in app.dependency_overrides:
|
|
||||||
del app.dependency_overrides[get_current_superuser]
|
|
||||||
@@ -332,9 +332,9 @@ class TestPasswordResetConfirm:
|
|||||||
"""Test password reset confirmation with database error."""
|
"""Test password reset confirmation with database error."""
|
||||||
token = create_password_reset_token(async_test_user.email)
|
token = create_password_reset_token(async_test_user.email)
|
||||||
|
|
||||||
# Mock the password update to raise an exception
|
# Mock the database commit to raise an exception
|
||||||
with patch('app.api.routes.auth.user_crud.update') as mock_update:
|
with patch('app.api.routes.auth.user_crud.get_by_email') as mock_get:
|
||||||
mock_update.side_effect = Exception("Database error")
|
mock_get.side_effect = Exception("Database error")
|
||||||
|
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
"/api/v1/auth/password-reset/confirm",
|
"/api/v1/auth/password-reset/confirm",
|
||||||
|
|||||||
@@ -9,13 +9,13 @@ from app.main import app
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def client():
|
def client():
|
||||||
"""Create a FastAPI test client for the main app."""
|
"""Create a FastAPI test client for the main app."""
|
||||||
# Mock get_db to avoid database connection issues
|
# Mock get_async_db to avoid database connection issues
|
||||||
with patch("app.main.get_db") as mock_get_db:
|
with patch("app.core.database_async.get_async_db") as mock_get_db:
|
||||||
def mock_session_generator():
|
async def mock_session_generator():
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock, AsyncMock
|
||||||
mock_session = MagicMock()
|
mock_session = MagicMock()
|
||||||
mock_session.execute.return_value = None
|
mock_session.execute = AsyncMock(return_value=None)
|
||||||
mock_session.close.return_value = None
|
mock_session.close = AsyncMock(return_value=None)
|
||||||
yield mock_session
|
yield mock_session
|
||||||
|
|
||||||
mock_get_db.side_effect = lambda: mock_session_generator()
|
mock_get_db.side_effect = lambda: mock_session_generator()
|
||||||
|
|||||||
@@ -1,421 +0,0 @@
|
|||||||
"""
|
|
||||||
Integration tests for session management.
|
|
||||||
|
|
||||||
Tests the critical per-device logout functionality.
|
|
||||||
"""
|
|
||||||
import pytest
|
|
||||||
from fastapi.testclient import TestClient
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from app.main import app
|
|
||||||
from app.core.database import get_db
|
|
||||||
from app.models.user import User
|
|
||||||
from app.core.auth import get_password_hash
|
|
||||||
from app.utils.test_utils import setup_test_db, teardown_test_db
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
|
||||||
def test_db_session():
|
|
||||||
"""Create test database session."""
|
|
||||||
test_engine, TestingSessionLocal = setup_test_db()
|
|
||||||
with TestingSessionLocal() as session:
|
|
||||||
yield session
|
|
||||||
teardown_test_db(test_engine)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
|
||||||
def client(test_db_session):
|
|
||||||
"""Create test client with test database."""
|
|
||||||
def override_get_db():
|
|
||||||
try:
|
|
||||||
yield test_db_session
|
|
||||||
finally:
|
|
||||||
pass
|
|
||||||
|
|
||||||
app.dependency_overrides[get_db] = override_get_db
|
|
||||||
with TestClient(app) as test_client:
|
|
||||||
yield test_client
|
|
||||||
app.dependency_overrides.clear()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def test_user(test_db_session):
|
|
||||||
"""Create a test user."""
|
|
||||||
user = User(
|
|
||||||
id=uuid.uuid4(),
|
|
||||||
email="sessiontest@example.com",
|
|
||||||
password_hash=get_password_hash("TestPassword123"),
|
|
||||||
first_name="Session",
|
|
||||||
last_name="Test",
|
|
||||||
phone_number="+1234567890",
|
|
||||||
is_active=True,
|
|
||||||
is_superuser=False,
|
|
||||||
preferences=None,
|
|
||||||
)
|
|
||||||
test_db_session.add(user)
|
|
||||||
test_db_session.commit()
|
|
||||||
test_db_session.refresh(user)
|
|
||||||
return user
|
|
||||||
|
|
||||||
|
|
||||||
class TestMultiDeviceLogin:
|
|
||||||
"""Test multi-device login scenarios."""
|
|
||||||
|
|
||||||
def test_login_from_multiple_devices(self, client, test_user):
|
|
||||||
"""Test that user can login from multiple devices simultaneously."""
|
|
||||||
# Login from PC
|
|
||||||
pc_response = client.post(
|
|
||||||
"/api/v1/auth/login",
|
|
||||||
json={
|
|
||||||
"email": "sessiontest@example.com",
|
|
||||||
"password": "TestPassword123"
|
|
||||||
},
|
|
||||||
headers={"X-Device-Id": "pc-device-001"}
|
|
||||||
)
|
|
||||||
assert pc_response.status_code == 200
|
|
||||||
pc_tokens = pc_response.json()
|
|
||||||
assert "access_token" in pc_tokens
|
|
||||||
assert "refresh_token" in pc_tokens
|
|
||||||
pc_refresh = pc_tokens["refresh_token"]
|
|
||||||
|
|
||||||
# Login from Phone
|
|
||||||
phone_response = client.post(
|
|
||||||
"/api/v1/auth/login",
|
|
||||||
json={
|
|
||||||
"email": "sessiontest@example.com",
|
|
||||||
"password": "TestPassword123"
|
|
||||||
},
|
|
||||||
headers={"X-Device-Id": "phone-device-001"}
|
|
||||||
)
|
|
||||||
assert phone_response.status_code == 200
|
|
||||||
phone_tokens = phone_response.json()
|
|
||||||
assert "access_token" in phone_tokens
|
|
||||||
assert "refresh_token" in phone_tokens
|
|
||||||
phone_refresh = phone_tokens["refresh_token"]
|
|
||||||
|
|
||||||
# Verify both tokens are different
|
|
||||||
assert pc_refresh != phone_refresh
|
|
||||||
|
|
||||||
# Both should be able to access protected endpoints
|
|
||||||
pc_me = client.get(
|
|
||||||
"/api/v1/auth/me",
|
|
||||||
headers={"Authorization": f"Bearer {pc_tokens['access_token']}"}
|
|
||||||
)
|
|
||||||
assert pc_me.status_code == 200
|
|
||||||
|
|
||||||
phone_me = client.get(
|
|
||||||
"/api/v1/auth/me",
|
|
||||||
headers={"Authorization": f"Bearer {phone_tokens['access_token']}"}
|
|
||||||
)
|
|
||||||
assert phone_me.status_code == 200
|
|
||||||
|
|
||||||
def test_logout_from_one_device_does_not_affect_other(self, client, test_user):
|
|
||||||
"""
|
|
||||||
CRITICAL TEST: Logout from PC should NOT logout from Phone.
|
|
||||||
|
|
||||||
This is the main requirement for session management.
|
|
||||||
"""
|
|
||||||
# Login from PC
|
|
||||||
pc_response = client.post(
|
|
||||||
"/api/v1/auth/login",
|
|
||||||
json={
|
|
||||||
"email": "sessiontest@example.com",
|
|
||||||
"password": "TestPassword123"
|
|
||||||
},
|
|
||||||
headers={"X-Device-Id": "pc-device-001"}
|
|
||||||
)
|
|
||||||
assert pc_response.status_code == 200
|
|
||||||
pc_tokens = pc_response.json()
|
|
||||||
pc_access = pc_tokens["access_token"]
|
|
||||||
pc_refresh = pc_tokens["refresh_token"]
|
|
||||||
|
|
||||||
# Login from Phone
|
|
||||||
phone_response = client.post(
|
|
||||||
"/api/v1/auth/login",
|
|
||||||
json={
|
|
||||||
"email": "sessiontest@example.com",
|
|
||||||
"password": "TestPassword123"
|
|
||||||
},
|
|
||||||
headers={"X-Device-Id": "phone-device-001"}
|
|
||||||
)
|
|
||||||
assert phone_response.status_code == 200
|
|
||||||
phone_tokens = phone_response.json()
|
|
||||||
phone_access = phone_tokens["access_token"]
|
|
||||||
phone_refresh = phone_tokens["refresh_token"]
|
|
||||||
|
|
||||||
# Logout from PC
|
|
||||||
logout_response = client.post(
|
|
||||||
"/api/v1/auth/logout",
|
|
||||||
json={"refresh_token": pc_refresh},
|
|
||||||
headers={"Authorization": f"Bearer {pc_access}"}
|
|
||||||
)
|
|
||||||
assert logout_response.status_code == 200
|
|
||||||
assert logout_response.json()["success"] == True
|
|
||||||
|
|
||||||
# PC refresh should fail (logged out)
|
|
||||||
pc_refresh_response = client.post(
|
|
||||||
"/api/v1/auth/refresh",
|
|
||||||
json={"refresh_token": pc_refresh}
|
|
||||||
)
|
|
||||||
assert pc_refresh_response.status_code == 401
|
|
||||||
response_data = pc_refresh_response.json()
|
|
||||||
assert "revoked" in response_data["errors"][0]["message"].lower()
|
|
||||||
|
|
||||||
# Phone refresh should still work ✅ THIS IS THE CRITICAL ASSERTION
|
|
||||||
phone_refresh_response = client.post(
|
|
||||||
"/api/v1/auth/refresh",
|
|
||||||
json={"refresh_token": phone_refresh}
|
|
||||||
)
|
|
||||||
assert phone_refresh_response.status_code == 200
|
|
||||||
new_phone_tokens = phone_refresh_response.json()
|
|
||||||
assert "access_token" in new_phone_tokens
|
|
||||||
|
|
||||||
# Phone can still access protected endpoints
|
|
||||||
phone_me = client.get(
|
|
||||||
"/api/v1/auth/me",
|
|
||||||
headers={"Authorization": f"Bearer {new_phone_tokens['access_token']}"}
|
|
||||||
)
|
|
||||||
assert phone_me.status_code == 200
|
|
||||||
assert phone_me.json()["email"] == "sessiontest@example.com"
|
|
||||||
|
|
||||||
def test_logout_all_devices(self, client, test_user):
|
|
||||||
"""Test logging out from all devices simultaneously."""
|
|
||||||
# Login from 3 devices
|
|
||||||
devices = []
|
|
||||||
for i, device_name in enumerate(["pc", "phone", "tablet"]):
|
|
||||||
response = client.post(
|
|
||||||
"/api/v1/auth/login",
|
|
||||||
json={
|
|
||||||
"email": "sessiontest@example.com",
|
|
||||||
"password": "TestPassword123"
|
|
||||||
},
|
|
||||||
headers={"X-Device-Id": f"{device_name}-device-00{i}"}
|
|
||||||
)
|
|
||||||
assert response.status_code == 200
|
|
||||||
tokens = response.json()
|
|
||||||
devices.append({
|
|
||||||
"name": device_name,
|
|
||||||
"access": tokens["access_token"],
|
|
||||||
"refresh": tokens["refresh_token"]
|
|
||||||
})
|
|
||||||
|
|
||||||
# Logout from all devices using first device's access token
|
|
||||||
logout_all_response = client.post(
|
|
||||||
"/api/v1/auth/logout-all",
|
|
||||||
headers={"Authorization": f"Bearer {devices[0]['access']}"}
|
|
||||||
)
|
|
||||||
assert logout_all_response.status_code == 200
|
|
||||||
assert "3" in logout_all_response.json()["message"] # 3 sessions terminated
|
|
||||||
|
|
||||||
# All refresh tokens should now fail
|
|
||||||
for device in devices:
|
|
||||||
refresh_response = client.post(
|
|
||||||
"/api/v1/auth/refresh",
|
|
||||||
json={"refresh_token": device["refresh"]}
|
|
||||||
)
|
|
||||||
assert refresh_response.status_code == 401
|
|
||||||
|
|
||||||
def test_list_active_sessions(self, client, test_user):
|
|
||||||
"""Test listing active sessions."""
|
|
||||||
# Login from 2 devices
|
|
||||||
pc_response = client.post(
|
|
||||||
"/api/v1/auth/login",
|
|
||||||
json={
|
|
||||||
"email": "sessiontest@example.com",
|
|
||||||
"password": "TestPassword123"
|
|
||||||
},
|
|
||||||
headers={"X-Device-Id": "pc-device-001"}
|
|
||||||
)
|
|
||||||
pc_tokens = pc_response.json()
|
|
||||||
|
|
||||||
phone_response = client.post(
|
|
||||||
"/api/v1/auth/login",
|
|
||||||
json={
|
|
||||||
"email": "sessiontest@example.com",
|
|
||||||
"password": "TestPassword123"
|
|
||||||
},
|
|
||||||
headers={"X-Device-Id": "phone-device-001"}
|
|
||||||
)
|
|
||||||
|
|
||||||
# List sessions
|
|
||||||
sessions_response = client.get(
|
|
||||||
"/api/v1/sessions/me",
|
|
||||||
headers={"Authorization": f"Bearer {pc_tokens['access_token']}"}
|
|
||||||
)
|
|
||||||
assert sessions_response.status_code == 200
|
|
||||||
sessions_data = sessions_response.json()
|
|
||||||
assert sessions_data["total"] == 2
|
|
||||||
assert len(sessions_data["sessions"]) == 2
|
|
||||||
|
|
||||||
# Check session details
|
|
||||||
session = sessions_data["sessions"][0]
|
|
||||||
assert "device_name" in session
|
|
||||||
assert "ip_address" in session
|
|
||||||
assert "last_used_at" in session
|
|
||||||
assert "created_at" in session
|
|
||||||
|
|
||||||
def test_revoke_specific_session(self, client, test_user):
|
|
||||||
"""Test revoking a specific session by ID."""
|
|
||||||
# Login from 2 devices
|
|
||||||
pc_response = client.post(
|
|
||||||
"/api/v1/auth/login",
|
|
||||||
json={
|
|
||||||
"email": "sessiontest@example.com",
|
|
||||||
"password": "TestPassword123"
|
|
||||||
},
|
|
||||||
headers={"X-Device-Id": "pc-device-001"}
|
|
||||||
)
|
|
||||||
pc_tokens = pc_response.json()
|
|
||||||
|
|
||||||
phone_response = client.post(
|
|
||||||
"/api/v1/auth/login",
|
|
||||||
json={
|
|
||||||
"email": "sessiontest@example.com",
|
|
||||||
"password": "TestPassword123"
|
|
||||||
},
|
|
||||||
headers={"X-Device-Id": "phone-device-001"}
|
|
||||||
)
|
|
||||||
phone_tokens = phone_response.json()
|
|
||||||
|
|
||||||
# List sessions to get IDs
|
|
||||||
sessions_response = client.get(
|
|
||||||
"/api/v1/sessions/me",
|
|
||||||
headers={"Authorization": f"Bearer {pc_tokens['access_token']}"}
|
|
||||||
)
|
|
||||||
sessions = sessions_response.json()["sessions"]
|
|
||||||
|
|
||||||
# Find the phone session by device_id
|
|
||||||
phone_session = next((s for s in sessions if s["device_id"] == "phone-device-001"), None)
|
|
||||||
assert phone_session is not None, "Phone session not found in session list"
|
|
||||||
session_id_to_revoke = phone_session["id"]
|
|
||||||
revoke_response = client.delete(
|
|
||||||
f"/api/v1/sessions/{session_id_to_revoke}",
|
|
||||||
headers={"Authorization": f"Bearer {pc_tokens['access_token']}"}
|
|
||||||
)
|
|
||||||
assert revoke_response.status_code == 200
|
|
||||||
|
|
||||||
# Phone refresh should fail
|
|
||||||
phone_refresh_response = client.post(
|
|
||||||
"/api/v1/auth/refresh",
|
|
||||||
json={"refresh_token": phone_tokens["refresh_token"]}
|
|
||||||
)
|
|
||||||
assert phone_refresh_response.status_code == 401
|
|
||||||
|
|
||||||
# PC refresh should still work
|
|
||||||
pc_refresh_response = client.post(
|
|
||||||
"/api/v1/auth/refresh",
|
|
||||||
json={"refresh_token": pc_tokens["refresh_token"]}
|
|
||||||
)
|
|
||||||
assert pc_refresh_response.status_code == 200
|
|
||||||
|
|
||||||
|
|
||||||
class TestSessionEdgeCases:
|
|
||||||
"""Test edge cases and error scenarios."""
|
|
||||||
|
|
||||||
def test_logout_with_invalid_refresh_token(self, client, test_user):
|
|
||||||
"""Test logout with invalid refresh token."""
|
|
||||||
# Login first
|
|
||||||
login_response = client.post(
|
|
||||||
"/api/v1/auth/login",
|
|
||||||
json={
|
|
||||||
"email": "sessiontest@example.com",
|
|
||||||
"password": "TestPassword123"
|
|
||||||
}
|
|
||||||
)
|
|
||||||
tokens = login_response.json()
|
|
||||||
|
|
||||||
# Try to logout with invalid refresh token
|
|
||||||
logout_response = client.post(
|
|
||||||
"/api/v1/auth/logout",
|
|
||||||
json={"refresh_token": "invalid_token"},
|
|
||||||
headers={"Authorization": f"Bearer {tokens['access_token']}"}
|
|
||||||
)
|
|
||||||
# Should still return success (idempotent)
|
|
||||||
assert logout_response.status_code == 200
|
|
||||||
|
|
||||||
def test_refresh_with_deactivated_session(self, client, test_user):
|
|
||||||
"""Test refresh after session has been deactivated."""
|
|
||||||
# Login
|
|
||||||
login_response = client.post(
|
|
||||||
"/api/v1/auth/login",
|
|
||||||
json={
|
|
||||||
"email": "sessiontest@example.com",
|
|
||||||
"password": "TestPassword123"
|
|
||||||
}
|
|
||||||
)
|
|
||||||
tokens = login_response.json()
|
|
||||||
|
|
||||||
# Logout
|
|
||||||
client.post(
|
|
||||||
"/api/v1/auth/logout",
|
|
||||||
json={"refresh_token": tokens["refresh_token"]},
|
|
||||||
headers={"Authorization": f"Bearer {tokens['access_token']}"}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Try to refresh with deactivated session
|
|
||||||
refresh_response = client.post(
|
|
||||||
"/api/v1/auth/refresh",
|
|
||||||
json={"refresh_token": tokens["refresh_token"]}
|
|
||||||
)
|
|
||||||
assert refresh_response.status_code == 401
|
|
||||||
response_data = refresh_response.json()
|
|
||||||
assert "revoked" in response_data["errors"][0]["message"].lower()
|
|
||||||
|
|
||||||
def test_cannot_revoke_other_users_session(self, client, test_db_session):
|
|
||||||
"""Test that users cannot revoke other users' sessions."""
|
|
||||||
# Create two users
|
|
||||||
user1 = User(
|
|
||||||
id=uuid.uuid4(),
|
|
||||||
email="user1@example.com",
|
|
||||||
password_hash=get_password_hash("TestPassword123"),
|
|
||||||
first_name="User",
|
|
||||||
last_name="One",
|
|
||||||
is_active=True,
|
|
||||||
is_superuser=False,
|
|
||||||
)
|
|
||||||
user2 = User(
|
|
||||||
id=uuid.uuid4(),
|
|
||||||
email="user2@example.com",
|
|
||||||
password_hash=get_password_hash("TestPassword123"),
|
|
||||||
first_name="User",
|
|
||||||
last_name="Two",
|
|
||||||
is_active=True,
|
|
||||||
is_superuser=False,
|
|
||||||
)
|
|
||||||
test_db_session.add_all([user1, user2])
|
|
||||||
test_db_session.commit()
|
|
||||||
|
|
||||||
# User1 login
|
|
||||||
user1_login = client.post(
|
|
||||||
"/api/v1/auth/login",
|
|
||||||
json={"email": "user1@example.com", "password": "TestPassword123"}
|
|
||||||
)
|
|
||||||
user1_tokens = user1_login.json()
|
|
||||||
|
|
||||||
# User2 login
|
|
||||||
user2_login = client.post(
|
|
||||||
"/api/v1/auth/login",
|
|
||||||
json={"email": "user2@example.com", "password": "TestPassword123"}
|
|
||||||
)
|
|
||||||
|
|
||||||
# User1 gets their sessions
|
|
||||||
user1_sessions = client.get(
|
|
||||||
"/api/v1/sessions/me",
|
|
||||||
headers={"Authorization": f"Bearer {user1_tokens['access_token']}"}
|
|
||||||
)
|
|
||||||
user1_session_id = user1_sessions.json()["sessions"][0]["id"]
|
|
||||||
|
|
||||||
# User2 lists their sessions
|
|
||||||
user2_sessions = client.get(
|
|
||||||
"/api/v1/sessions/me",
|
|
||||||
headers={"Authorization": f"Bearer {user2_login.json()['access_token']}"}
|
|
||||||
)
|
|
||||||
user2_session_id = user2_sessions.json()["sessions"][0]["id"]
|
|
||||||
|
|
||||||
# User1 tries to revoke User2's session (should fail)
|
|
||||||
revoke_response = client.delete(
|
|
||||||
f"/api/v1/sessions/{user2_session_id}",
|
|
||||||
headers={"Authorization": f"Bearer {user1_tokens['access_token']}"}
|
|
||||||
)
|
|
||||||
assert revoke_response.status_code == 403
|
|
||||||
@@ -60,19 +60,22 @@ class TestListUsers:
|
|||||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_users_pagination(self, client, async_test_superuser, test_db):
|
async def test_list_users_pagination(self, client, async_test_superuser, async_test_db):
|
||||||
"""Test pagination works correctly."""
|
"""Test pagination works correctly."""
|
||||||
|
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
# Create multiple users
|
# Create multiple users
|
||||||
for i in range(15):
|
async with AsyncTestingSessionLocal() as session:
|
||||||
user = User(
|
for i in range(15):
|
||||||
email=f"paguser{i}@example.com",
|
user = User(
|
||||||
password_hash="hash",
|
email=f"paguser{i}@example.com",
|
||||||
first_name=f"PagUser{i}",
|
password_hash="hash",
|
||||||
is_active=True,
|
first_name=f"PagUser{i}",
|
||||||
is_superuser=False
|
is_active=True,
|
||||||
)
|
is_superuser=False
|
||||||
test_db.add(user)
|
)
|
||||||
test_db.commit()
|
session.add(user)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123")
|
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123")
|
||||||
|
|
||||||
@@ -85,25 +88,28 @@ class TestListUsers:
|
|||||||
assert data["pagination"]["total"] >= 15
|
assert data["pagination"]["total"] >= 15
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_users_filter_active(self, client, async_test_superuser, test_db):
|
async def test_list_users_filter_active(self, client, async_test_superuser, async_test_db):
|
||||||
"""Test filtering by active status."""
|
"""Test filtering by active status."""
|
||||||
|
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
# Create active and inactive users
|
# Create active and inactive users
|
||||||
active_user = User(
|
async with AsyncTestingSessionLocal() as session:
|
||||||
email="activefilter@example.com",
|
active_user = User(
|
||||||
password_hash="hash",
|
email="activefilter@example.com",
|
||||||
first_name="Active",
|
password_hash="hash",
|
||||||
is_active=True,
|
first_name="Active",
|
||||||
is_superuser=False
|
is_active=True,
|
||||||
)
|
is_superuser=False
|
||||||
inactive_user = User(
|
)
|
||||||
email="inactivefilter@example.com",
|
inactive_user = User(
|
||||||
password_hash="hash",
|
email="inactivefilter@example.com",
|
||||||
first_name="Inactive",
|
password_hash="hash",
|
||||||
is_active=False,
|
first_name="Inactive",
|
||||||
is_superuser=False
|
is_active=False,
|
||||||
)
|
is_superuser=False
|
||||||
test_db.add_all([active_user, inactive_user])
|
)
|
||||||
test_db.commit()
|
session.add_all([active_user, inactive_user])
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123")
|
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123")
|
||||||
|
|
||||||
@@ -168,7 +174,7 @@ class TestUpdateCurrentUser:
|
|||||||
"""Tests for PATCH /users/me endpoint."""
|
"""Tests for PATCH /users/me endpoint."""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_own_profile(self, client, async_test_user, test_db):
|
async def test_update_own_profile(self, client, async_test_user):
|
||||||
"""Test updating own profile."""
|
"""Test updating own profile."""
|
||||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123")
|
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123")
|
||||||
|
|
||||||
@@ -183,10 +189,6 @@ class TestUpdateCurrentUser:
|
|||||||
assert data["first_name"] == "Updated"
|
assert data["first_name"] == "Updated"
|
||||||
assert data["last_name"] == "Name"
|
assert data["last_name"] == "Name"
|
||||||
|
|
||||||
# Verify in database
|
|
||||||
test_db.refresh(async_test_user)
|
|
||||||
assert async_test_user.first_name == "Updated"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_profile_phone_number(self, client, async_test_user, test_db):
|
async def test_update_profile_phone_number(self, client, async_test_user, test_db):
|
||||||
"""Test updating phone number with validation."""
|
"""Test updating phone number with validation."""
|
||||||
@@ -507,31 +509,38 @@ class TestDeleteUser:
|
|||||||
"""Tests for DELETE /users/{user_id} endpoint."""
|
"""Tests for DELETE /users/{user_id} endpoint."""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_delete_user_as_superuser(self, client, async_test_superuser, test_db):
|
async def test_delete_user_as_superuser(self, client, async_test_superuser, async_test_db):
|
||||||
"""Test deleting a user as superuser."""
|
"""Test deleting a user as superuser."""
|
||||||
|
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
# Create a user to delete
|
# Create a user to delete
|
||||||
user_to_delete = User(
|
async with AsyncTestingSessionLocal() as session:
|
||||||
email="deleteme@example.com",
|
user_to_delete = User(
|
||||||
password_hash="hash",
|
email="deleteme@example.com",
|
||||||
first_name="Delete",
|
password_hash="hash",
|
||||||
is_active=True,
|
first_name="Delete",
|
||||||
is_superuser=False
|
is_active=True,
|
||||||
)
|
is_superuser=False
|
||||||
test_db.add(user_to_delete)
|
)
|
||||||
test_db.commit()
|
session.add(user_to_delete)
|
||||||
test_db.refresh(user_to_delete)
|
await session.commit()
|
||||||
|
await session.refresh(user_to_delete)
|
||||||
|
user_id = user_to_delete.id
|
||||||
|
|
||||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123")
|
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123")
|
||||||
|
|
||||||
response = await client.delete(f"/api/v1/users/{user_to_delete.id}", headers=headers)
|
response = await client.delete(f"/api/v1/users/{user_id}", headers=headers)
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["success"] is True
|
assert data["success"] is True
|
||||||
|
|
||||||
# Verify user is soft-deleted (has deleted_at timestamp)
|
# Verify user is soft-deleted (has deleted_at timestamp)
|
||||||
test_db.refresh(user_to_delete)
|
async with AsyncTestingSessionLocal() as session:
|
||||||
assert user_to_delete.deleted_at is not None
|
from sqlalchemy import select
|
||||||
|
result = await session.execute(select(User).where(User.id == user_id))
|
||||||
|
deleted_user = result.scalar_one_or_none()
|
||||||
|
assert deleted_user.deleted_at is not None
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_cannot_delete_self(self, client, async_test_superuser):
|
async def test_cannot_delete_self(self, client, async_test_superuser):
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from datetime import datetime, timezone
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient, ASGITransport
|
||||||
|
|
||||||
# Set IS_TEST environment variable BEFORE importing app
|
# Set IS_TEST environment variable BEFORE importing app
|
||||||
# This prevents the scheduler from starting during tests
|
# This prevents the scheduler from starting during tests
|
||||||
@@ -36,10 +36,12 @@ def db_session():
|
|||||||
teardown_test_db(test_engine)
|
teardown_test_db(test_engine)
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="function") # Define a fixture
|
@pytest_asyncio.fixture(scope="function") # Function scope for isolation
|
||||||
async def async_test_db():
|
async def async_test_db():
|
||||||
"""Fixture provides new testing engine and session for each test run to improve isolation."""
|
"""Fixture provides testing engine and session for each test.
|
||||||
|
|
||||||
|
Each test gets a fresh database for complete isolation.
|
||||||
|
"""
|
||||||
test_engine, AsyncTestingSessionLocal = await setup_async_test_db()
|
test_engine, AsyncTestingSessionLocal = await setup_async_test_db()
|
||||||
yield test_engine, AsyncTestingSessionLocal
|
yield test_engine, AsyncTestingSessionLocal
|
||||||
await teardown_async_test_db(test_engine)
|
await teardown_async_test_db(test_engine)
|
||||||
@@ -111,7 +113,9 @@ async def client(async_test_db):
|
|||||||
|
|
||||||
app.dependency_overrides[get_async_db] = override_get_async_db
|
app.dependency_overrides[get_async_db] = override_get_async_db
|
||||||
|
|
||||||
async with AsyncClient(app=app, base_url="http://test") as test_client:
|
# Use ASGITransport for httpx >= 0.27
|
||||||
|
transport = ASGITransport(app=app)
|
||||||
|
async with AsyncClient(transport=transport, base_url="http://test") as test_client:
|
||||||
yield test_client
|
yield test_client
|
||||||
|
|
||||||
app.dependency_overrides.clear()
|
app.dependency_overrides.clear()
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
# tests/services/test_auth_service.py
|
# tests/services/test_auth_service.py
|
||||||
import uuid
|
import uuid
|
||||||
import pytest
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
from app.core.auth import get_password_hash, verify_password, TokenExpiredError, TokenInvalidError
|
from app.core.auth import get_password_hash, verify_password, TokenExpiredError, TokenInvalidError
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
@@ -12,72 +14,100 @@ from app.services.auth_service import AuthService, AuthenticationError
|
|||||||
class TestAuthServiceAuthentication:
|
class TestAuthServiceAuthentication:
|
||||||
"""Tests for AuthService.authenticate_user method"""
|
"""Tests for AuthService.authenticate_user method"""
|
||||||
|
|
||||||
def test_authenticate_valid_user(self, db_session, mock_user):
|
@pytest.mark.asyncio
|
||||||
|
async def test_authenticate_valid_user(self, async_test_db, async_test_user):
|
||||||
"""Test authenticating a user with valid credentials"""
|
"""Test authenticating a user with valid credentials"""
|
||||||
|
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
# Set a known password for the mock user
|
# Set a known password for the mock user
|
||||||
password = "TestPassword123"
|
password = "TestPassword123"
|
||||||
mock_user.password_hash = get_password_hash(password)
|
async with AsyncTestingSessionLocal() as session:
|
||||||
db_session.commit()
|
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
user.password_hash = get_password_hash(password)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
# Authenticate with correct credentials
|
# Authenticate with correct credentials
|
||||||
user = AuthService.authenticate_user(
|
async with AsyncTestingSessionLocal() as session:
|
||||||
db=db_session,
|
auth_user = await AuthService.authenticate_user(
|
||||||
email=mock_user.email,
|
db=session,
|
||||||
password=password
|
email=async_test_user.email,
|
||||||
)
|
|
||||||
|
|
||||||
assert user is not None
|
|
||||||
assert user.id == mock_user.id
|
|
||||||
assert user.email == mock_user.email
|
|
||||||
|
|
||||||
def test_authenticate_nonexistent_user(self, db_session):
|
|
||||||
"""Test authenticating with an email that doesn't exist"""
|
|
||||||
user = AuthService.authenticate_user(
|
|
||||||
db=db_session,
|
|
||||||
email="nonexistent@example.com",
|
|
||||||
password="password"
|
|
||||||
)
|
|
||||||
|
|
||||||
assert user is None
|
|
||||||
|
|
||||||
def test_authenticate_with_wrong_password(self, db_session, mock_user):
|
|
||||||
"""Test authenticating with the wrong password"""
|
|
||||||
# Set a known password for the mock user
|
|
||||||
password = "TestPassword123"
|
|
||||||
mock_user.password_hash = get_password_hash(password)
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
# Authenticate with wrong password
|
|
||||||
user = AuthService.authenticate_user(
|
|
||||||
db=db_session,
|
|
||||||
email=mock_user.email,
|
|
||||||
password="WrongPassword123"
|
|
||||||
)
|
|
||||||
|
|
||||||
assert user is None
|
|
||||||
|
|
||||||
def test_authenticate_inactive_user(self, db_session, mock_user):
|
|
||||||
"""Test authenticating an inactive user"""
|
|
||||||
# Set a known password and make user inactive
|
|
||||||
password = "TestPassword123"
|
|
||||||
mock_user.password_hash = get_password_hash(password)
|
|
||||||
mock_user.is_active = False
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
# Should raise AuthenticationError
|
|
||||||
with pytest.raises(AuthenticationError):
|
|
||||||
AuthService.authenticate_user(
|
|
||||||
db=db_session,
|
|
||||||
email=mock_user.email,
|
|
||||||
password=password
|
password=password
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert auth_user is not None
|
||||||
|
assert auth_user.id == async_test_user.id
|
||||||
|
assert auth_user.email == async_test_user.email
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_authenticate_nonexistent_user(self, async_test_db):
|
||||||
|
"""Test authenticating with an email that doesn't exist"""
|
||||||
|
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
user = await AuthService.authenticate_user(
|
||||||
|
db=session,
|
||||||
|
email="nonexistent@example.com",
|
||||||
|
password="password"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert user is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_authenticate_with_wrong_password(self, async_test_db, async_test_user):
|
||||||
|
"""Test authenticating with the wrong password"""
|
||||||
|
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
# Set a known password for the mock user
|
||||||
|
password = "TestPassword123"
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
user.password_hash = get_password_hash(password)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Authenticate with wrong password
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
auth_user = await AuthService.authenticate_user(
|
||||||
|
db=session,
|
||||||
|
email=async_test_user.email,
|
||||||
|
password="WrongPassword123"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert auth_user is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_authenticate_inactive_user(self, async_test_db, async_test_user):
|
||||||
|
"""Test authenticating an inactive user"""
|
||||||
|
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
# Set a known password and make user inactive
|
||||||
|
password = "TestPassword123"
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
user.password_hash = get_password_hash(password)
|
||||||
|
user.is_active = False
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Should raise AuthenticationError
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
with pytest.raises(AuthenticationError):
|
||||||
|
await AuthService.authenticate_user(
|
||||||
|
db=session,
|
||||||
|
email=async_test_user.email,
|
||||||
|
password=password
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestAuthServiceUserCreation:
|
class TestAuthServiceUserCreation:
|
||||||
"""Tests for AuthService.create_user method"""
|
"""Tests for AuthService.create_user method"""
|
||||||
|
|
||||||
def test_create_new_user(self, db_session):
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_new_user(self, async_test_db):
|
||||||
"""Test creating a new user"""
|
"""Test creating a new user"""
|
||||||
|
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
user_data = UserCreate(
|
user_data = UserCreate(
|
||||||
email="newuser@example.com",
|
email="newuser@example.com",
|
||||||
password="TestPassword123",
|
password="TestPassword123",
|
||||||
@@ -86,43 +116,49 @@ class TestAuthServiceUserCreation:
|
|||||||
phone_number="1234567890"
|
phone_number="1234567890"
|
||||||
)
|
)
|
||||||
|
|
||||||
user = AuthService.create_user(db=db_session, user_data=user_data)
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
user = await AuthService.create_user(db=session, user_data=user_data)
|
||||||
|
|
||||||
# Verify user was created with correct data
|
# Verify user was created with correct data
|
||||||
assert user is not None
|
assert user is not None
|
||||||
assert user.email == user_data.email
|
assert user.email == user_data.email
|
||||||
assert user.first_name == user_data.first_name
|
assert user.first_name == user_data.first_name
|
||||||
assert user.last_name == user_data.last_name
|
assert user.last_name == user_data.last_name
|
||||||
assert user.phone_number == user_data.phone_number
|
assert user.phone_number == user_data.phone_number
|
||||||
|
|
||||||
# Verify password was hashed
|
# Verify password was hashed
|
||||||
assert user.password_hash != user_data.password
|
assert user.password_hash != user_data.password
|
||||||
assert verify_password(user_data.password, user.password_hash)
|
assert verify_password(user_data.password, user.password_hash)
|
||||||
|
|
||||||
# Verify default values
|
# Verify default values
|
||||||
assert user.is_active is True
|
assert user.is_active is True
|
||||||
assert user.is_superuser is False
|
assert user.is_superuser is False
|
||||||
|
|
||||||
def test_create_user_with_existing_email(self, db_session, mock_user):
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_user_with_existing_email(self, async_test_db, async_test_user):
|
||||||
"""Test creating a user with an email that already exists"""
|
"""Test creating a user with an email that already exists"""
|
||||||
|
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
user_data = UserCreate(
|
user_data = UserCreate(
|
||||||
email=mock_user.email, # Use existing email
|
email=async_test_user.email, # Use existing email
|
||||||
password="TestPassword123",
|
password="TestPassword123",
|
||||||
first_name="Duplicate",
|
first_name="Duplicate",
|
||||||
last_name="User"
|
last_name="User"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Should raise AuthenticationError
|
# Should raise AuthenticationError
|
||||||
with pytest.raises(AuthenticationError):
|
async with AsyncTestingSessionLocal() as session:
|
||||||
AuthService.create_user(db=db_session, user_data=user_data)
|
with pytest.raises(AuthenticationError):
|
||||||
|
await AuthService.create_user(db=session, user_data=user_data)
|
||||||
|
|
||||||
|
|
||||||
class TestAuthServiceTokens:
|
class TestAuthServiceTokens:
|
||||||
"""Tests for AuthService token-related methods"""
|
"""Tests for AuthService token-related methods"""
|
||||||
|
|
||||||
def test_create_tokens(self, mock_user):
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_tokens(self, async_test_user):
|
||||||
"""Test creating access and refresh tokens for a user"""
|
"""Test creating access and refresh tokens for a user"""
|
||||||
tokens = AuthService.create_tokens(mock_user)
|
tokens = AuthService.create_tokens(async_test_user)
|
||||||
|
|
||||||
# Verify token structure
|
# Verify token structure
|
||||||
assert isinstance(tokens, Token)
|
assert isinstance(tokens, Token)
|
||||||
@@ -130,50 +166,62 @@ class TestAuthServiceTokens:
|
|||||||
assert tokens.refresh_token is not None
|
assert tokens.refresh_token is not None
|
||||||
assert tokens.token_type == "bearer"
|
assert tokens.token_type == "bearer"
|
||||||
|
|
||||||
# This is a more in-depth test that would decode the tokens to verify claims
|
@pytest.mark.asyncio
|
||||||
# but we'll rely on the auth module tests for token verification
|
async def test_refresh_tokens(self, async_test_db, async_test_user):
|
||||||
|
|
||||||
def test_refresh_tokens(self, db_session, mock_user):
|
|
||||||
"""Test refreshing tokens with a valid refresh token"""
|
"""Test refreshing tokens with a valid refresh token"""
|
||||||
|
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
# Create initial tokens
|
# Create initial tokens
|
||||||
initial_tokens = AuthService.create_tokens(mock_user)
|
initial_tokens = AuthService.create_tokens(async_test_user)
|
||||||
|
|
||||||
# Refresh tokens
|
# Refresh tokens
|
||||||
new_tokens = AuthService.refresh_tokens(
|
async with AsyncTestingSessionLocal() as session:
|
||||||
db=db_session,
|
new_tokens = await AuthService.refresh_tokens(
|
||||||
refresh_token=initial_tokens.refresh_token
|
db=session,
|
||||||
)
|
refresh_token=initial_tokens.refresh_token
|
||||||
|
)
|
||||||
|
|
||||||
# Verify new tokens are different from old ones
|
# Verify new tokens are different from old ones
|
||||||
assert new_tokens.access_token != initial_tokens.access_token
|
assert new_tokens.access_token != initial_tokens.access_token
|
||||||
assert new_tokens.refresh_token != initial_tokens.refresh_token
|
assert new_tokens.refresh_token != initial_tokens.refresh_token
|
||||||
|
|
||||||
def test_refresh_tokens_with_invalid_token(self, db_session):
|
@pytest.mark.asyncio
|
||||||
|
async def test_refresh_tokens_with_invalid_token(self, async_test_db):
|
||||||
"""Test refreshing tokens with an invalid token"""
|
"""Test refreshing tokens with an invalid token"""
|
||||||
|
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
# Create an invalid token
|
# Create an invalid token
|
||||||
invalid_token = "invalid.token.string"
|
invalid_token = "invalid.token.string"
|
||||||
|
|
||||||
# Should raise TokenInvalidError
|
# Should raise TokenInvalidError
|
||||||
with pytest.raises(TokenInvalidError):
|
async with AsyncTestingSessionLocal() as session:
|
||||||
AuthService.refresh_tokens(
|
with pytest.raises(TokenInvalidError):
|
||||||
db=db_session,
|
await AuthService.refresh_tokens(
|
||||||
refresh_token=invalid_token
|
db=session,
|
||||||
)
|
refresh_token=invalid_token
|
||||||
|
)
|
||||||
|
|
||||||
def test_refresh_tokens_with_access_token(self, db_session, mock_user):
|
@pytest.mark.asyncio
|
||||||
|
async def test_refresh_tokens_with_access_token(self, async_test_db, async_test_user):
|
||||||
"""Test refreshing tokens with an access token instead of refresh token"""
|
"""Test refreshing tokens with an access token instead of refresh token"""
|
||||||
|
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
# Create tokens
|
# Create tokens
|
||||||
tokens = AuthService.create_tokens(mock_user)
|
tokens = AuthService.create_tokens(async_test_user)
|
||||||
|
|
||||||
# Try to refresh with access token
|
# Try to refresh with access token
|
||||||
with pytest.raises(TokenInvalidError):
|
async with AsyncTestingSessionLocal() as session:
|
||||||
AuthService.refresh_tokens(
|
with pytest.raises(TokenInvalidError):
|
||||||
db=db_session,
|
await AuthService.refresh_tokens(
|
||||||
refresh_token=tokens.access_token
|
db=session,
|
||||||
)
|
refresh_token=tokens.access_token
|
||||||
|
)
|
||||||
|
|
||||||
def test_refresh_tokens_with_nonexistent_user(self, db_session):
|
@pytest.mark.asyncio
|
||||||
|
async def test_refresh_tokens_with_nonexistent_user(self, async_test_db):
|
||||||
"""Test refreshing tokens for a user that doesn't exist in the database"""
|
"""Test refreshing tokens for a user that doesn't exist in the database"""
|
||||||
|
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
# Create a token for a non-existent user
|
# Create a token for a non-existent user
|
||||||
non_existent_id = str(uuid.uuid4())
|
non_existent_id = str(uuid.uuid4())
|
||||||
with patch('app.core.auth.decode_token'), patch('app.core.auth.get_token_data') as mock_get_data:
|
with patch('app.core.auth.decode_token'), patch('app.core.auth.get_token_data') as mock_get_data:
|
||||||
@@ -181,72 +229,96 @@ class TestAuthServiceTokens:
|
|||||||
mock_get_data.return_value.user_id = uuid.UUID(non_existent_id)
|
mock_get_data.return_value.user_id = uuid.UUID(non_existent_id)
|
||||||
|
|
||||||
# Should raise TokenInvalidError
|
# Should raise TokenInvalidError
|
||||||
with pytest.raises(TokenInvalidError):
|
async with AsyncTestingSessionLocal() as session:
|
||||||
AuthService.refresh_tokens(
|
with pytest.raises(TokenInvalidError):
|
||||||
db=db_session,
|
await AuthService.refresh_tokens(
|
||||||
refresh_token="some.refresh.token"
|
db=session,
|
||||||
)
|
refresh_token="some.refresh.token"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestAuthServicePasswordChange:
|
class TestAuthServicePasswordChange:
|
||||||
"""Tests for AuthService.change_password method"""
|
"""Tests for AuthService.change_password method"""
|
||||||
|
|
||||||
def test_change_password(self, db_session, mock_user):
|
@pytest.mark.asyncio
|
||||||
|
async def test_change_password(self, async_test_db, async_test_user):
|
||||||
"""Test changing a user's password"""
|
"""Test changing a user's password"""
|
||||||
|
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
# Set a known password for the mock user
|
# Set a known password for the mock user
|
||||||
current_password = "CurrentPassword123"
|
current_password = "CurrentPassword123"
|
||||||
mock_user.password_hash = get_password_hash(current_password)
|
async with AsyncTestingSessionLocal() as session:
|
||||||
db_session.commit()
|
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
user.password_hash = get_password_hash(current_password)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
# Change password
|
# Change password
|
||||||
new_password = "NewPassword456"
|
new_password = "NewPassword456"
|
||||||
result = AuthService.change_password(
|
async with AsyncTestingSessionLocal() as session:
|
||||||
db=db_session,
|
result = await AuthService.change_password(
|
||||||
user_id=mock_user.id,
|
db=session,
|
||||||
current_password=current_password,
|
user_id=async_test_user.id,
|
||||||
new_password=new_password
|
current_password=current_password,
|
||||||
)
|
new_password=new_password
|
||||||
|
)
|
||||||
|
|
||||||
# Verify operation was successful
|
# Verify operation was successful
|
||||||
assert result is True
|
assert result is True
|
||||||
|
|
||||||
# Refresh user from DB
|
# Verify password was changed
|
||||||
db_session.refresh(mock_user)
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||||
|
updated_user = result.scalar_one_or_none()
|
||||||
|
|
||||||
# Verify old password no longer works
|
# Verify old password no longer works
|
||||||
assert not verify_password(current_password, mock_user.password_hash)
|
assert not verify_password(current_password, updated_user.password_hash)
|
||||||
|
|
||||||
# Verify new password works
|
# Verify new password works
|
||||||
assert verify_password(new_password, mock_user.password_hash)
|
assert verify_password(new_password, updated_user.password_hash)
|
||||||
|
|
||||||
def test_change_password_wrong_current_password(self, db_session, mock_user):
|
@pytest.mark.asyncio
|
||||||
|
async def test_change_password_wrong_current_password(self, async_test_db, async_test_user):
|
||||||
"""Test changing password with incorrect current password"""
|
"""Test changing password with incorrect current password"""
|
||||||
|
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
# Set a known password for the mock user
|
# Set a known password for the mock user
|
||||||
current_password = "CurrentPassword123"
|
current_password = "CurrentPassword123"
|
||||||
mock_user.password_hash = get_password_hash(current_password)
|
async with AsyncTestingSessionLocal() as session:
|
||||||
db_session.commit()
|
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
user.password_hash = get_password_hash(current_password)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
# Try to change password with wrong current password
|
# Try to change password with wrong current password
|
||||||
wrong_password = "WrongPassword123"
|
wrong_password = "WrongPassword123"
|
||||||
with pytest.raises(AuthenticationError):
|
async with AsyncTestingSessionLocal() as session:
|
||||||
AuthService.change_password(
|
with pytest.raises(AuthenticationError):
|
||||||
db=db_session,
|
await AuthService.change_password(
|
||||||
user_id=mock_user.id,
|
db=session,
|
||||||
current_password=wrong_password,
|
user_id=async_test_user.id,
|
||||||
new_password="NewPassword456"
|
current_password=wrong_password,
|
||||||
)
|
new_password="NewPassword456"
|
||||||
|
)
|
||||||
|
|
||||||
# Verify password was not changed
|
# Verify password was not changed
|
||||||
assert verify_password(current_password, mock_user.password_hash)
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
assert verify_password(current_password, user.password_hash)
|
||||||
|
|
||||||
def test_change_password_nonexistent_user(self, db_session):
|
@pytest.mark.asyncio
|
||||||
|
async def test_change_password_nonexistent_user(self, async_test_db):
|
||||||
"""Test changing password for a user that doesn't exist"""
|
"""Test changing password for a user that doesn't exist"""
|
||||||
|
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
non_existent_id = uuid.uuid4()
|
non_existent_id = uuid.uuid4()
|
||||||
|
|
||||||
with pytest.raises(AuthenticationError):
|
async with AsyncTestingSessionLocal() as session:
|
||||||
AuthService.change_password(
|
with pytest.raises(AuthenticationError):
|
||||||
db=db_session,
|
await AuthService.change_password(
|
||||||
user_id=non_existent_id,
|
db=session,
|
||||||
current_password="CurrentPassword123",
|
user_id=non_existent_id,
|
||||||
new_password="NewPassword456"
|
current_password="CurrentPassword123",
|
||||||
)
|
new_password="NewPassword456"
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user