Add rate-limiting for authentication endpoints and health check feature
- Introduced rate-limiting to `/auth/*` routes with configurable limits using `SlowAPI`. - Added `/health` endpoint for service monitoring and load balancer health checks. - Updated `requirements.txt` to include `SlowAPI` for rate limiting. - Implemented tests for rate-limiting and health check functionality. - Enhanced configuration and security with updated environment variables, pinned dependencies, and validation adjustments. - Provided example usage and extended coverage in testing.
This commit is contained in:
@@ -2,8 +2,10 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status, Body
|
from fastapi import APIRouter, Depends, HTTPException, status, Body, Request
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
|
from slowapi import Limiter
|
||||||
|
from slowapi.util import get_remote_address
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.api.dependencies.auth import get_current_user
|
from app.api.dependencies.auth import get_current_user
|
||||||
@@ -22,9 +24,14 @@ from app.services.auth_service import AuthService, AuthenticationError
|
|||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Initialize limiter for this router
|
||||||
|
limiter = Limiter(key_func=get_remote_address)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED, operation_id="register")
|
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED, operation_id="register")
|
||||||
|
@limiter.limit("5/minute")
|
||||||
async def register_user(
|
async def register_user(
|
||||||
|
request: Request,
|
||||||
user_data: UserCreate,
|
user_data: UserCreate,
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
@@ -52,7 +59,9 @@ async def register_user(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/login", response_model=Token, operation_id="login")
|
@router.post("/login", response_model=Token, operation_id="login")
|
||||||
|
@limiter.limit("10/minute")
|
||||||
async def login(
|
async def login(
|
||||||
|
request: Request,
|
||||||
login_data: LoginRequest,
|
login_data: LoginRequest,
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
@@ -101,7 +110,9 @@ async def login(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/login/oauth", response_model=Token, operation_id='login_oauth')
|
@router.post("/login/oauth", response_model=Token, operation_id='login_oauth')
|
||||||
|
@limiter.limit("10/minute")
|
||||||
async def login_oauth(
|
async def login_oauth(
|
||||||
|
request: Request,
|
||||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
@@ -148,7 +159,9 @@ async def login_oauth(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/refresh", response_model=Token, operation_id="refresh_token")
|
@router.post("/refresh", response_model=Token, operation_id="refresh_token")
|
||||||
|
@limiter.limit("30/minute")
|
||||||
async def refresh_token(
|
async def refresh_token(
|
||||||
|
request: Request,
|
||||||
refresh_data: RefreshTokenRequest,
|
refresh_data: RefreshTokenRequest,
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
@@ -184,7 +197,9 @@ async def refresh_token(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/change-password", status_code=status.HTTP_200_OK, operation_id="change_password")
|
@router.post("/change-password", status_code=status.HTTP_200_OK, operation_id="change_password")
|
||||||
|
@limiter.limit("5/minute")
|
||||||
async def change_password(
|
async def change_password(
|
||||||
|
request: Request,
|
||||||
current_password: str = Body(..., embed=True),
|
current_password: str = Body(..., embed=True),
|
||||||
new_password: str = Body(..., embed=True),
|
new_password: str = Body(..., embed=True),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
@@ -220,7 +235,9 @@ async def change_password(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/me", response_model=UserResponse, operation_id="get_current_user_info")
|
@router.get("/me", response_model=UserResponse, operation_id="get_current_user_info")
|
||||||
|
@limiter.limit("60/minute")
|
||||||
async def get_current_user_info(
|
async def get_current_user_info(
|
||||||
|
request: Request,
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ class Settings(BaseSettings):
|
|||||||
POSTGRES_PORT: str = "5432"
|
POSTGRES_PORT: str = "5432"
|
||||||
POSTGRES_DB: str = "app"
|
POSTGRES_DB: str = "app"
|
||||||
DATABASE_URL: Optional[str] = None
|
DATABASE_URL: Optional[str] = None
|
||||||
REFRESH_TOKEN_EXPIRE_DAYS: int = 60
|
|
||||||
db_pool_size: int = 20 # Default connection pool size
|
db_pool_size: int = 20 # Default connection pool size
|
||||||
db_max_overflow: int = 50 # Maximum overflow connections
|
db_max_overflow: int = 50 # Maximum overflow connections
|
||||||
db_pool_timeout: int = 30 # Seconds to wait for a connection
|
db_pool_timeout: int = 30 # Seconds to wait for a connection
|
||||||
@@ -48,7 +47,7 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
# JWT configuration
|
# JWT configuration
|
||||||
SECRET_KEY: str = Field(
|
SECRET_KEY: str = Field(
|
||||||
default="your_secret_key_here",
|
default="dev_only_insecure_key_change_in_production_32chars_min",
|
||||||
min_length=32,
|
min_length=32,
|
||||||
description="JWT signing key. MUST be changed in production. Generate with: python -c 'import secrets; print(secrets.token_urlsafe(32))'"
|
description="JWT signing key. MUST be changed in production. Generate with: python -c 'import secrets; print(secrets.token_urlsafe(32))'"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,17 +1,27 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI, status, Request
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import HTMLResponse
|
from fastapi.responses import HTMLResponse, JSONResponse
|
||||||
|
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||||
|
from slowapi.util import get_remote_address
|
||||||
|
from slowapi.errors import RateLimitExceeded
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
from app.api.main import api_router
|
from app.api.main import api_router
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
from app.core.database import get_db
|
||||||
|
|
||||||
scheduler = AsyncIOScheduler()
|
scheduler = AsyncIOScheduler()
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Initialize rate limiter
|
||||||
|
limiter = Limiter(key_func=get_remote_address)
|
||||||
|
|
||||||
logger.info(f"Starting app!!!")
|
logger.info(f"Starting app!!!")
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title=settings.PROJECT_NAME,
|
title=settings.PROJECT_NAME,
|
||||||
@@ -19,6 +29,10 @@ app = FastAPI(
|
|||||||
openapi_url=f"{settings.API_V1_STR}/openapi.json"
|
openapi_url=f"{settings.API_V1_STR}/openapi.json"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Add rate limiter state to app
|
||||||
|
app.state.limiter = limiter
|
||||||
|
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||||
|
|
||||||
# Set up CORS middleware
|
# Set up CORS middleware
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
@@ -45,4 +59,58 @@ async def root():
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@app.get(
|
||||||
|
"/health",
|
||||||
|
summary="Health Check",
|
||||||
|
description="Check the health status of the API and its dependencies",
|
||||||
|
response_description="Health status information",
|
||||||
|
tags=["Health"],
|
||||||
|
operation_id="health_check"
|
||||||
|
)
|
||||||
|
async def health_check() -> JSONResponse:
|
||||||
|
"""
|
||||||
|
Health check endpoint for monitoring and load balancers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSONResponse: Health status with the following information:
|
||||||
|
- status: Overall health status ("healthy" or "unhealthy")
|
||||||
|
- timestamp: Current server timestamp (ISO 8601 format)
|
||||||
|
- version: API version
|
||||||
|
- environment: Current environment (development, staging, production)
|
||||||
|
- database: Database connectivity status
|
||||||
|
"""
|
||||||
|
health_status: Dict[str, Any] = {
|
||||||
|
"status": "healthy",
|
||||||
|
"timestamp": datetime.utcnow().isoformat() + "Z",
|
||||||
|
"version": settings.VERSION,
|
||||||
|
"environment": settings.ENVIRONMENT,
|
||||||
|
"checks": {}
|
||||||
|
}
|
||||||
|
|
||||||
|
response_status = status.HTTP_200_OK
|
||||||
|
|
||||||
|
# Database health check
|
||||||
|
try:
|
||||||
|
db = next(get_db())
|
||||||
|
db.execute(text("SELECT 1"))
|
||||||
|
health_status["checks"]["database"] = {
|
||||||
|
"status": "healthy",
|
||||||
|
"message": "Database connection successful"
|
||||||
|
}
|
||||||
|
db.close()
|
||||||
|
except Exception as e:
|
||||||
|
health_status["status"] = "unhealthy"
|
||||||
|
health_status["checks"]["database"] = {
|
||||||
|
"status": "unhealthy",
|
||||||
|
"message": f"Database connection failed: {str(e)}"
|
||||||
|
}
|
||||||
|
response_status = status.HTTP_503_SERVICE_UNAVAILABLE
|
||||||
|
logger.error(f"Health check failed - database error: {e}")
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=response_status,
|
||||||
|
content=health_status
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
app.include_router(api_router, prefix=settings.API_V1_STR)
|
app.include_router(api_router, prefix=settings.API_V1_STR)
|
||||||
|
|||||||
@@ -12,10 +12,8 @@ alembic>=1.14.1
|
|||||||
psycopg2-binary>=2.9.9
|
psycopg2-binary>=2.9.9
|
||||||
asyncpg>=0.29.0
|
asyncpg>=0.29.0
|
||||||
aiosqlite==0.21.0
|
aiosqlite==0.21.0
|
||||||
# Security and authentication
|
|
||||||
python-jose>=3.4.0
|
# Environment configuration
|
||||||
passlib>=1.7.4
|
|
||||||
bcrypt>=4.1.2
|
|
||||||
python-dotenv>=1.0.1
|
python-dotenv>=1.0.1
|
||||||
|
|
||||||
# API documentation
|
# API documentation
|
||||||
@@ -26,6 +24,9 @@ ujson>=5.9.0
|
|||||||
starlette>=0.40.0
|
starlette>=0.40.0
|
||||||
starlette-csrf>=1.4.5
|
starlette-csrf>=1.4.5
|
||||||
|
|
||||||
|
# Rate limiting
|
||||||
|
slowapi>=0.1.9
|
||||||
|
|
||||||
# Utilities
|
# Utilities
|
||||||
httpx>=0.27.0
|
httpx>=0.27.0
|
||||||
tenacity>=8.2.3
|
tenacity>=8.2.3
|
||||||
@@ -44,9 +45,11 @@ isort>=5.13.2
|
|||||||
flake8>=7.0.0
|
flake8>=7.0.0
|
||||||
mypy>=1.8.0
|
mypy>=1.8.0
|
||||||
|
|
||||||
# Security
|
# Security and authentication (pinned for reproducibility)
|
||||||
python-jose==3.4.0
|
python-jose==3.4.0
|
||||||
|
passlib==1.7.4
|
||||||
bcrypt==4.2.1
|
bcrypt==4.2.1
|
||||||
cryptography==44.0.1
|
cryptography==44.0.1
|
||||||
passlib==1.7.4
|
|
||||||
|
# Testing utilities
|
||||||
freezegun~=1.5.1
|
freezegun~=1.5.1
|
||||||
200
backend/tests/api/routes/test_health.py
Normal file
200
backend/tests/api/routes/test_health.py
Normal file
@@ -0,0 +1,200 @@
|
|||||||
|
# tests/api/routes/test_health.py
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, patch, MagicMock
|
||||||
|
from fastapi import status
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from datetime import datetime
|
||||||
|
from sqlalchemy.exc import OperationalError
|
||||||
|
|
||||||
|
from app.main import app
|
||||||
|
from app.core.database import get_db
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client():
|
||||||
|
"""Create a FastAPI test client for the main app with mocked database."""
|
||||||
|
# Mock get_db to avoid connecting to the actual database
|
||||||
|
with patch("app.main.get_db") as mock_get_db:
|
||||||
|
def mock_session_generator():
|
||||||
|
mock_session = MagicMock()
|
||||||
|
# Mock the execute method to return successfully
|
||||||
|
mock_session.execute.return_value = None
|
||||||
|
mock_session.close.return_value = None
|
||||||
|
yield mock_session
|
||||||
|
|
||||||
|
# Return a new generator each time get_db is called
|
||||||
|
mock_get_db.side_effect = lambda: mock_session_generator()
|
||||||
|
yield TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
class TestHealthEndpoint:
|
||||||
|
"""Tests for the /health endpoint"""
|
||||||
|
|
||||||
|
def test_health_check_healthy(self, client):
|
||||||
|
"""Test that health check returns healthy when database is accessible"""
|
||||||
|
response = client.get("/health")
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# Check required fields
|
||||||
|
assert "status" in data
|
||||||
|
assert data["status"] == "healthy"
|
||||||
|
assert "timestamp" in data
|
||||||
|
assert "version" in data
|
||||||
|
assert "environment" in data
|
||||||
|
assert "checks" in data
|
||||||
|
|
||||||
|
# Verify timestamp format (ISO 8601)
|
||||||
|
assert data["timestamp"].endswith("Z")
|
||||||
|
# Verify it's a valid datetime
|
||||||
|
datetime.fromisoformat(data["timestamp"].replace("Z", "+00:00"))
|
||||||
|
|
||||||
|
# Check database health
|
||||||
|
assert "database" in data["checks"]
|
||||||
|
assert data["checks"]["database"]["status"] == "healthy"
|
||||||
|
assert "message" in data["checks"]["database"]
|
||||||
|
|
||||||
|
def test_health_check_response_structure(self, client):
|
||||||
|
"""Test that health check response has correct structure"""
|
||||||
|
response = client.get("/health")
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# Verify top-level structure
|
||||||
|
assert isinstance(data["status"], str)
|
||||||
|
assert isinstance(data["timestamp"], str)
|
||||||
|
assert isinstance(data["version"], str)
|
||||||
|
assert isinstance(data["environment"], str)
|
||||||
|
assert isinstance(data["checks"], dict)
|
||||||
|
|
||||||
|
# Verify database check structure
|
||||||
|
db_check = data["checks"]["database"]
|
||||||
|
assert isinstance(db_check["status"], str)
|
||||||
|
assert isinstance(db_check["message"], str)
|
||||||
|
|
||||||
|
def test_health_check_version_matches_settings(self, client):
|
||||||
|
"""Test that health check returns correct version from settings"""
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
response = client.get("/health")
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
assert data["version"] == settings.VERSION
|
||||||
|
|
||||||
|
def test_health_check_environment_matches_settings(self, client):
|
||||||
|
"""Test that health check returns correct environment from settings"""
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
response = client.get("/health")
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
assert data["environment"] == settings.ENVIRONMENT
|
||||||
|
|
||||||
|
def test_health_check_database_connection_failure(self, client):
|
||||||
|
"""Test that health check returns unhealthy when database is not accessible"""
|
||||||
|
# Mock the database session to raise an exception
|
||||||
|
with patch("app.main.get_db") as mock_get_db:
|
||||||
|
def mock_session():
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
mock = MagicMock()
|
||||||
|
mock.execute.side_effect = OperationalError(
|
||||||
|
"Connection refused",
|
||||||
|
params=None,
|
||||||
|
orig=Exception("Connection refused")
|
||||||
|
)
|
||||||
|
yield mock
|
||||||
|
|
||||||
|
mock_get_db.return_value = mock_session()
|
||||||
|
|
||||||
|
response = client.get("/health")
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# Check overall status
|
||||||
|
assert data["status"] == "unhealthy"
|
||||||
|
|
||||||
|
# Check database status
|
||||||
|
assert "database" in data["checks"]
|
||||||
|
assert data["checks"]["database"]["status"] == "unhealthy"
|
||||||
|
assert "failed" in data["checks"]["database"]["message"].lower()
|
||||||
|
|
||||||
|
def test_health_check_timestamp_recent(self, client):
|
||||||
|
"""Test that health check timestamp is recent (within last minute)"""
|
||||||
|
before = datetime.utcnow()
|
||||||
|
response = client.get("/health")
|
||||||
|
after = datetime.utcnow()
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
timestamp = datetime.fromisoformat(data["timestamp"].replace("Z", "+00:00"))
|
||||||
|
|
||||||
|
# Timestamp should be between before and after
|
||||||
|
assert before <= timestamp.replace(tzinfo=None) <= after
|
||||||
|
|
||||||
|
def test_health_check_no_authentication_required(self, client):
|
||||||
|
"""Test that health check does not require authentication"""
|
||||||
|
# Make request without any authentication headers
|
||||||
|
response = client.get("/health")
|
||||||
|
|
||||||
|
# Should succeed without authentication
|
||||||
|
assert response.status_code in [status.HTTP_200_OK, status.HTTP_503_SERVICE_UNAVAILABLE]
|
||||||
|
|
||||||
|
def test_health_check_idempotent(self, client):
|
||||||
|
"""Test that multiple health checks return consistent results"""
|
||||||
|
response1 = client.get("/health")
|
||||||
|
response2 = client.get("/health")
|
||||||
|
|
||||||
|
# Both should have same status code (either both healthy or both unhealthy)
|
||||||
|
assert response1.status_code == response2.status_code
|
||||||
|
|
||||||
|
data1 = response1.json()
|
||||||
|
data2 = response2.json()
|
||||||
|
|
||||||
|
# Same overall health status
|
||||||
|
assert data1["status"] == data2["status"]
|
||||||
|
|
||||||
|
# Same version and environment
|
||||||
|
assert data1["version"] == data2["version"]
|
||||||
|
assert data1["environment"] == data2["environment"]
|
||||||
|
|
||||||
|
# Same database check status
|
||||||
|
assert data1["checks"]["database"]["status"] == data2["checks"]["database"]["status"]
|
||||||
|
|
||||||
|
def test_health_check_content_type(self, client):
|
||||||
|
"""Test that health check returns JSON content type"""
|
||||||
|
response = client.get("/health")
|
||||||
|
|
||||||
|
assert "application/json" in response.headers["content-type"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestHealthEndpointEdgeCases:
|
||||||
|
"""Edge case tests for the /health endpoint"""
|
||||||
|
|
||||||
|
def test_health_check_with_query_parameters(self, client):
|
||||||
|
"""Test that health check ignores query parameters"""
|
||||||
|
response = client.get("/health?foo=bar&baz=qux")
|
||||||
|
|
||||||
|
# Should still work with query params
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
|
def test_health_check_method_not_allowed(self, client):
|
||||||
|
"""Test that POST/PUT/DELETE are not allowed on health endpoint"""
|
||||||
|
# POST should not be allowed
|
||||||
|
response = client.post("/health")
|
||||||
|
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
|
||||||
|
|
||||||
|
# PUT should not be allowed
|
||||||
|
response = client.put("/health")
|
||||||
|
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
|
||||||
|
|
||||||
|
# DELETE should not be allowed
|
||||||
|
response = client.delete("/health")
|
||||||
|
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
|
||||||
|
|
||||||
|
def test_health_check_with_accept_header(self, client):
|
||||||
|
"""Test that health check works with different Accept headers"""
|
||||||
|
response = client.get("/health", headers={"Accept": "application/json"})
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
|
response = client.get("/health", headers={"Accept": "*/*"})
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
194
backend/tests/api/routes/test_rate_limiting.py
Normal file
194
backend/tests/api/routes/test_rate_limiting.py
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
# tests/api/routes/test_rate_limiting.py
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI, status
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
from app.api.routes.auth import router as auth_router, limiter
|
||||||
|
from app.core.database import get_db
|
||||||
|
|
||||||
|
|
||||||
|
# Mock the get_db dependency
|
||||||
|
@pytest.fixture
|
||||||
|
def override_get_db():
|
||||||
|
"""Override get_db dependency for testing."""
|
||||||
|
mock_db = MagicMock()
|
||||||
|
return mock_db
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app(override_get_db):
|
||||||
|
"""Create a FastAPI test application with rate limiting."""
|
||||||
|
from slowapi import _rate_limit_exceeded_handler
|
||||||
|
from slowapi.errors import RateLimitExceeded
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
app.state.limiter = limiter
|
||||||
|
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||||
|
app.include_router(auth_router, prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
|
# Override the get_db dependency
|
||||||
|
app.dependency_overrides[get_db] = lambda: override_get_db
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(app):
|
||||||
|
"""Create a FastAPI test client."""
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRegisterRateLimiting:
|
||||||
|
"""Tests for rate limiting on /register endpoint"""
|
||||||
|
|
||||||
|
def test_register_rate_limit_blocks_over_limit(self, client):
|
||||||
|
"""Test that requests over rate limit are blocked"""
|
||||||
|
from app.services.auth_service import AuthService
|
||||||
|
from app.models.user import User
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
mock_user = User(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
email="test@example.com",
|
||||||
|
password_hash="hashed",
|
||||||
|
first_name="Test",
|
||||||
|
last_name="User",
|
||||||
|
is_active=True,
|
||||||
|
is_superuser=False,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
updated_at=datetime.now(timezone.utc)
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(AuthService, 'create_user', return_value=mock_user):
|
||||||
|
user_data = {
|
||||||
|
"email": f"test{uuid.uuid4()}@example.com",
|
||||||
|
"password": "TestPassword123!",
|
||||||
|
"first_name": "Test",
|
||||||
|
"last_name": "User"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Make 6 requests (limit is 5/minute)
|
||||||
|
responses = []
|
||||||
|
for i in range(6):
|
||||||
|
response = client.post("/auth/register", json=user_data)
|
||||||
|
responses.append(response)
|
||||||
|
|
||||||
|
# Last request should be rate limited
|
||||||
|
assert responses[-1].status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoginRateLimiting:
|
||||||
|
"""Tests for rate limiting on /login endpoint"""
|
||||||
|
|
||||||
|
def test_login_rate_limit_blocks_over_limit(self, client):
|
||||||
|
"""Test that login requests over rate limit are blocked"""
|
||||||
|
from app.services.auth_service import AuthService
|
||||||
|
|
||||||
|
with patch.object(AuthService, 'authenticate_user', return_value=None):
|
||||||
|
login_data = {
|
||||||
|
"email": "test@example.com",
|
||||||
|
"password": "wrong_password"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Make 11 requests (limit is 10/minute)
|
||||||
|
responses = []
|
||||||
|
for i in range(11):
|
||||||
|
response = client.post("/auth/login", json=login_data)
|
||||||
|
responses.append(response)
|
||||||
|
|
||||||
|
# Last request should be rate limited
|
||||||
|
assert responses[-1].status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||||
|
|
||||||
|
|
||||||
|
class TestRefreshTokenRateLimiting:
|
||||||
|
"""Tests for rate limiting on /refresh endpoint"""
|
||||||
|
|
||||||
|
def test_refresh_rate_limit_blocks_over_limit(self, client):
|
||||||
|
"""Test that refresh requests over rate limit are blocked"""
|
||||||
|
from app.services.auth_service import AuthService
|
||||||
|
from app.core.auth import TokenInvalidError
|
||||||
|
|
||||||
|
with patch.object(AuthService, 'refresh_tokens', side_effect=TokenInvalidError("Invalid")):
|
||||||
|
refresh_data = {
|
||||||
|
"refresh_token": "invalid_token"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Make 31 requests (limit is 30/minute)
|
||||||
|
responses = []
|
||||||
|
for i in range(31):
|
||||||
|
response = client.post("/auth/refresh", json=refresh_data)
|
||||||
|
responses.append(response)
|
||||||
|
|
||||||
|
# Last request should be rate limited
|
||||||
|
assert responses[-1].status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||||
|
|
||||||
|
|
||||||
|
class TestChangePasswordRateLimiting:
|
||||||
|
"""Tests for rate limiting on /change-password endpoint"""
|
||||||
|
|
||||||
|
def test_change_password_rate_limit_blocks_over_limit(self, client):
|
||||||
|
"""Test that change password requests over rate limit are blocked"""
|
||||||
|
from app.api.dependencies.auth import get_current_user
|
||||||
|
from app.models.user import User
|
||||||
|
from app.services.auth_service import AuthService, AuthenticationError
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
# Mock current user
|
||||||
|
mock_user = User(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
email="test@example.com",
|
||||||
|
password_hash="hashed",
|
||||||
|
first_name="Test",
|
||||||
|
last_name="User",
|
||||||
|
is_active=True,
|
||||||
|
is_superuser=False,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
updated_at=datetime.now(timezone.utc)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Override get_current_user dependency in the app
|
||||||
|
test_app = client.app
|
||||||
|
test_app.dependency_overrides[get_current_user] = lambda: mock_user
|
||||||
|
|
||||||
|
with patch.object(AuthService, 'change_password', side_effect=AuthenticationError("Invalid password")):
|
||||||
|
password_data = {
|
||||||
|
"current_password": "wrong_password",
|
||||||
|
"new_password": "NewPassword123!"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Make 6 requests (limit is 5/minute)
|
||||||
|
responses = []
|
||||||
|
for i in range(6):
|
||||||
|
response = client.post("/auth/change-password", json=password_data)
|
||||||
|
responses.append(response)
|
||||||
|
|
||||||
|
# Last request should be rate limited
|
||||||
|
assert responses[-1].status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||||
|
|
||||||
|
# Clean up override
|
||||||
|
test_app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
|
class TestRateLimitErrorResponse:
|
||||||
|
"""Tests for rate limit error response format"""
|
||||||
|
|
||||||
|
def test_rate_limit_error_response_format(self, client):
|
||||||
|
"""Test that rate limit error has correct format"""
|
||||||
|
from app.services.auth_service import AuthService
|
||||||
|
|
||||||
|
with patch.object(AuthService, 'authenticate_user', return_value=None):
|
||||||
|
login_data = {
|
||||||
|
"email": "test@example.com",
|
||||||
|
"password": "password"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Exceed rate limit
|
||||||
|
for i in range(11):
|
||||||
|
response = client.post("/auth/login", json=login_data)
|
||||||
|
|
||||||
|
# Check error response
|
||||||
|
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||||
|
assert "detail" in response.json() or "error" in response.json()
|
||||||
Reference in New Issue
Block a user