Add rate-limiting for authentication endpoints and health check feature

- Introduced rate-limiting to `/auth/*` routes with configurable limits using `SlowAPI`.
- Added `/health` endpoint for service monitoring and load balancer health checks.
- Updated `requirements.txt` to include `SlowAPI` for rate limiting.
- Implemented tests for rate-limiting and health check functionality.
- Enhanced configuration and security with updated environment variables, pinned dependencies, and validation adjustments.
- Provided example usage and extended coverage in testing.
This commit is contained in:
Felipe Cardoso
2025-10-29 23:59:29 +01:00
parent f163ffbb83
commit 5bed14b6b0
6 changed files with 492 additions and 11 deletions

View File

@@ -2,8 +2,10 @@
import logging
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, status, Body
from fastapi import APIRouter, Depends, HTTPException, status, Body, Request
from fastapi.security import OAuth2PasswordRequestForm
from slowapi import Limiter
from slowapi.util import get_remote_address
from sqlalchemy.orm import Session
from app.api.dependencies.auth import get_current_user
@@ -22,9 +24,14 @@ from app.services.auth_service import AuthService, AuthenticationError
router = APIRouter()
logger = logging.getLogger(__name__)
# Initialize limiter for this router
limiter = Limiter(key_func=get_remote_address)
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED, operation_id="register")
@limiter.limit("5/minute")
async def register_user(
request: Request,
user_data: UserCreate,
db: Session = Depends(get_db)
) -> Any:
@@ -52,7 +59,9 @@ async def register_user(
@router.post("/login", response_model=Token, operation_id="login")
@limiter.limit("10/minute")
async def login(
request: Request,
login_data: LoginRequest,
db: Session = Depends(get_db)
) -> Any:
@@ -101,7 +110,9 @@ async def login(
@router.post("/login/oauth", response_model=Token, operation_id='login_oauth')
@limiter.limit("10/minute")
async def login_oauth(
request: Request,
form_data: OAuth2PasswordRequestForm = Depends(),
db: Session = Depends(get_db)
) -> Any:
@@ -148,7 +159,9 @@ async def login_oauth(
@router.post("/refresh", response_model=Token, operation_id="refresh_token")
@limiter.limit("30/minute")
async def refresh_token(
request: Request,
refresh_data: RefreshTokenRequest,
db: Session = Depends(get_db)
) -> Any:
@@ -184,7 +197,9 @@ async def refresh_token(
@router.post("/change-password", status_code=status.HTTP_200_OK, operation_id="change_password")
@limiter.limit("5/minute")
async def change_password(
request: Request,
current_password: str = Body(..., embed=True),
new_password: str = Body(..., embed=True),
current_user: User = Depends(get_current_user),
@@ -220,7 +235,9 @@ async def change_password(
@router.get("/me", response_model=UserResponse, operation_id="get_current_user_info")
@limiter.limit("60/minute")
async def get_current_user_info(
request: Request,
current_user: User = Depends(get_current_user)
) -> Any:
"""

View File

@@ -22,7 +22,6 @@ class Settings(BaseSettings):
POSTGRES_PORT: str = "5432"
POSTGRES_DB: str = "app"
DATABASE_URL: Optional[str] = None
REFRESH_TOKEN_EXPIRE_DAYS: int = 60
db_pool_size: int = 20 # Default connection pool size
db_max_overflow: int = 50 # Maximum overflow connections
db_pool_timeout: int = 30 # Seconds to wait for a connection
@@ -48,7 +47,7 @@ class Settings(BaseSettings):
# JWT configuration
SECRET_KEY: str = Field(
default="your_secret_key_here",
default="dev_only_insecure_key_change_in_production_32chars_min",
min_length=32,
description="JWT signing key. MUST be changed in production. Generate with: python -c 'import secrets; print(secrets.token_urlsafe(32))'"
)

View File

@@ -1,17 +1,27 @@
import logging
from datetime import datetime
from typing import Dict, Any
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from fastapi import FastAPI
from fastapi import FastAPI, status, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.responses import HTMLResponse, JSONResponse
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from sqlalchemy import text
from app.api.main import api_router
from app.core.config import settings
from app.core.database import get_db
scheduler = AsyncIOScheduler()
logger = logging.getLogger(__name__)
# Initialize rate limiter
limiter = Limiter(key_func=get_remote_address)
logger.info(f"Starting app!!!")
app = FastAPI(
title=settings.PROJECT_NAME,
@@ -19,6 +29,10 @@ app = FastAPI(
openapi_url=f"{settings.API_V1_STR}/openapi.json"
)
# Add rate limiter state to app
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# Set up CORS middleware
app.add_middleware(
CORSMiddleware,
@@ -45,4 +59,58 @@ async def root():
"""
@app.get(
"/health",
summary="Health Check",
description="Check the health status of the API and its dependencies",
response_description="Health status information",
tags=["Health"],
operation_id="health_check"
)
async def health_check() -> JSONResponse:
"""
Health check endpoint for monitoring and load balancers.
Returns:
JSONResponse: Health status with the following information:
- status: Overall health status ("healthy" or "unhealthy")
- timestamp: Current server timestamp (ISO 8601 format)
- version: API version
- environment: Current environment (development, staging, production)
- database: Database connectivity status
"""
health_status: Dict[str, Any] = {
"status": "healthy",
"timestamp": datetime.utcnow().isoformat() + "Z",
"version": settings.VERSION,
"environment": settings.ENVIRONMENT,
"checks": {}
}
response_status = status.HTTP_200_OK
# Database health check
try:
db = next(get_db())
db.execute(text("SELECT 1"))
health_status["checks"]["database"] = {
"status": "healthy",
"message": "Database connection successful"
}
db.close()
except Exception as e:
health_status["status"] = "unhealthy"
health_status["checks"]["database"] = {
"status": "unhealthy",
"message": f"Database connection failed: {str(e)}"
}
response_status = status.HTTP_503_SERVICE_UNAVAILABLE
logger.error(f"Health check failed - database error: {e}")
return JSONResponse(
status_code=response_status,
content=health_status
)
app.include_router(api_router, prefix=settings.API_V1_STR)

View File

@@ -12,10 +12,8 @@ alembic>=1.14.1
psycopg2-binary>=2.9.9
asyncpg>=0.29.0
aiosqlite==0.21.0
# Security and authentication
python-jose>=3.4.0
passlib>=1.7.4
bcrypt>=4.1.2
# Environment configuration
python-dotenv>=1.0.1
# API documentation
@@ -26,6 +24,9 @@ ujson>=5.9.0
starlette>=0.40.0
starlette-csrf>=1.4.5
# Rate limiting
slowapi>=0.1.9
# Utilities
httpx>=0.27.0
tenacity>=8.2.3
@@ -44,9 +45,11 @@ isort>=5.13.2
flake8>=7.0.0
mypy>=1.8.0
# Security
# Security and authentication (pinned for reproducibility)
python-jose==3.4.0
passlib==1.7.4
bcrypt==4.2.1
cryptography==44.0.1
passlib==1.7.4
# Testing utilities
freezegun~=1.5.1

View File

@@ -0,0 +1,200 @@
# tests/api/routes/test_health.py
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from fastapi import status
from fastapi.testclient import TestClient
from datetime import datetime
from sqlalchemy.exc import OperationalError
from app.main import app
from app.core.database import get_db
@pytest.fixture
def client():
"""Create a FastAPI test client for the main app with mocked database."""
# Mock get_db to avoid connecting to the actual database
with patch("app.main.get_db") as mock_get_db:
def mock_session_generator():
mock_session = MagicMock()
# Mock the execute method to return successfully
mock_session.execute.return_value = None
mock_session.close.return_value = None
yield mock_session
# Return a new generator each time get_db is called
mock_get_db.side_effect = lambda: mock_session_generator()
yield TestClient(app)
class TestHealthEndpoint:
"""Tests for the /health endpoint"""
def test_health_check_healthy(self, client):
"""Test that health check returns healthy when database is accessible"""
response = client.get("/health")
assert response.status_code == status.HTTP_200_OK
data = response.json()
# Check required fields
assert "status" in data
assert data["status"] == "healthy"
assert "timestamp" in data
assert "version" in data
assert "environment" in data
assert "checks" in data
# Verify timestamp format (ISO 8601)
assert data["timestamp"].endswith("Z")
# Verify it's a valid datetime
datetime.fromisoformat(data["timestamp"].replace("Z", "+00:00"))
# Check database health
assert "database" in data["checks"]
assert data["checks"]["database"]["status"] == "healthy"
assert "message" in data["checks"]["database"]
def test_health_check_response_structure(self, client):
"""Test that health check response has correct structure"""
response = client.get("/health")
data = response.json()
# Verify top-level structure
assert isinstance(data["status"], str)
assert isinstance(data["timestamp"], str)
assert isinstance(data["version"], str)
assert isinstance(data["environment"], str)
assert isinstance(data["checks"], dict)
# Verify database check structure
db_check = data["checks"]["database"]
assert isinstance(db_check["status"], str)
assert isinstance(db_check["message"], str)
def test_health_check_version_matches_settings(self, client):
"""Test that health check returns correct version from settings"""
from app.core.config import settings
response = client.get("/health")
data = response.json()
assert data["version"] == settings.VERSION
def test_health_check_environment_matches_settings(self, client):
"""Test that health check returns correct environment from settings"""
from app.core.config import settings
response = client.get("/health")
data = response.json()
assert data["environment"] == settings.ENVIRONMENT
def test_health_check_database_connection_failure(self, client):
"""Test that health check returns unhealthy when database is not accessible"""
# Mock the database session to raise an exception
with patch("app.main.get_db") as mock_get_db:
def mock_session():
from unittest.mock import MagicMock
mock = MagicMock()
mock.execute.side_effect = OperationalError(
"Connection refused",
params=None,
orig=Exception("Connection refused")
)
yield mock
mock_get_db.return_value = mock_session()
response = client.get("/health")
assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
data = response.json()
# Check overall status
assert data["status"] == "unhealthy"
# Check database status
assert "database" in data["checks"]
assert data["checks"]["database"]["status"] == "unhealthy"
assert "failed" in data["checks"]["database"]["message"].lower()
def test_health_check_timestamp_recent(self, client):
"""Test that health check timestamp is recent (within last minute)"""
before = datetime.utcnow()
response = client.get("/health")
after = datetime.utcnow()
data = response.json()
timestamp = datetime.fromisoformat(data["timestamp"].replace("Z", "+00:00"))
# Timestamp should be between before and after
assert before <= timestamp.replace(tzinfo=None) <= after
def test_health_check_no_authentication_required(self, client):
"""Test that health check does not require authentication"""
# Make request without any authentication headers
response = client.get("/health")
# Should succeed without authentication
assert response.status_code in [status.HTTP_200_OK, status.HTTP_503_SERVICE_UNAVAILABLE]
def test_health_check_idempotent(self, client):
"""Test that multiple health checks return consistent results"""
response1 = client.get("/health")
response2 = client.get("/health")
# Both should have same status code (either both healthy or both unhealthy)
assert response1.status_code == response2.status_code
data1 = response1.json()
data2 = response2.json()
# Same overall health status
assert data1["status"] == data2["status"]
# Same version and environment
assert data1["version"] == data2["version"]
assert data1["environment"] == data2["environment"]
# Same database check status
assert data1["checks"]["database"]["status"] == data2["checks"]["database"]["status"]
def test_health_check_content_type(self, client):
"""Test that health check returns JSON content type"""
response = client.get("/health")
assert "application/json" in response.headers["content-type"]
class TestHealthEndpointEdgeCases:
"""Edge case tests for the /health endpoint"""
def test_health_check_with_query_parameters(self, client):
"""Test that health check ignores query parameters"""
response = client.get("/health?foo=bar&baz=qux")
# Should still work with query params
assert response.status_code == status.HTTP_200_OK
def test_health_check_method_not_allowed(self, client):
"""Test that POST/PUT/DELETE are not allowed on health endpoint"""
# POST should not be allowed
response = client.post("/health")
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
# PUT should not be allowed
response = client.put("/health")
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
# DELETE should not be allowed
response = client.delete("/health")
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
def test_health_check_with_accept_header(self, client):
"""Test that health check works with different Accept headers"""
response = client.get("/health", headers={"Accept": "application/json"})
assert response.status_code == status.HTTP_200_OK
response = client.get("/health", headers={"Accept": "*/*"})
assert response.status_code == status.HTTP_200_OK

View File

@@ -0,0 +1,194 @@
# tests/api/routes/test_rate_limiting.py
import pytest
from fastapi import FastAPI, status
from fastapi.testclient import TestClient
from unittest.mock import patch, MagicMock
from app.api.routes.auth import router as auth_router, limiter
from app.core.database import get_db
# Mock the get_db dependency
@pytest.fixture
def override_get_db():
"""Override get_db dependency for testing."""
mock_db = MagicMock()
return mock_db
@pytest.fixture
def app(override_get_db):
"""Create a FastAPI test application with rate limiting."""
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
app = FastAPI()
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
app.include_router(auth_router, prefix="/auth", tags=["auth"])
# Override the get_db dependency
app.dependency_overrides[get_db] = lambda: override_get_db
return app
@pytest.fixture
def client(app):
"""Create a FastAPI test client."""
return TestClient(app)
class TestRegisterRateLimiting:
"""Tests for rate limiting on /register endpoint"""
def test_register_rate_limit_blocks_over_limit(self, client):
"""Test that requests over rate limit are blocked"""
from app.services.auth_service import AuthService
from app.models.user import User
from datetime import datetime, timezone
import uuid
mock_user = User(
id=uuid.uuid4(),
email="test@example.com",
password_hash="hashed",
first_name="Test",
last_name="User",
is_active=True,
is_superuser=False,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc)
)
with patch.object(AuthService, 'create_user', return_value=mock_user):
user_data = {
"email": f"test{uuid.uuid4()}@example.com",
"password": "TestPassword123!",
"first_name": "Test",
"last_name": "User"
}
# Make 6 requests (limit is 5/minute)
responses = []
for i in range(6):
response = client.post("/auth/register", json=user_data)
responses.append(response)
# Last request should be rate limited
assert responses[-1].status_code == status.HTTP_429_TOO_MANY_REQUESTS
class TestLoginRateLimiting:
"""Tests for rate limiting on /login endpoint"""
def test_login_rate_limit_blocks_over_limit(self, client):
"""Test that login requests over rate limit are blocked"""
from app.services.auth_service import AuthService
with patch.object(AuthService, 'authenticate_user', return_value=None):
login_data = {
"email": "test@example.com",
"password": "wrong_password"
}
# Make 11 requests (limit is 10/minute)
responses = []
for i in range(11):
response = client.post("/auth/login", json=login_data)
responses.append(response)
# Last request should be rate limited
assert responses[-1].status_code == status.HTTP_429_TOO_MANY_REQUESTS
class TestRefreshTokenRateLimiting:
"""Tests for rate limiting on /refresh endpoint"""
def test_refresh_rate_limit_blocks_over_limit(self, client):
"""Test that refresh requests over rate limit are blocked"""
from app.services.auth_service import AuthService
from app.core.auth import TokenInvalidError
with patch.object(AuthService, 'refresh_tokens', side_effect=TokenInvalidError("Invalid")):
refresh_data = {
"refresh_token": "invalid_token"
}
# Make 31 requests (limit is 30/minute)
responses = []
for i in range(31):
response = client.post("/auth/refresh", json=refresh_data)
responses.append(response)
# Last request should be rate limited
assert responses[-1].status_code == status.HTTP_429_TOO_MANY_REQUESTS
class TestChangePasswordRateLimiting:
"""Tests for rate limiting on /change-password endpoint"""
def test_change_password_rate_limit_blocks_over_limit(self, client):
"""Test that change password requests over rate limit are blocked"""
from app.api.dependencies.auth import get_current_user
from app.models.user import User
from app.services.auth_service import AuthService, AuthenticationError
from datetime import datetime, timezone
import uuid
# Mock current user
mock_user = User(
id=uuid.uuid4(),
email="test@example.com",
password_hash="hashed",
first_name="Test",
last_name="User",
is_active=True,
is_superuser=False,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc)
)
# Override get_current_user dependency in the app
test_app = client.app
test_app.dependency_overrides[get_current_user] = lambda: mock_user
with patch.object(AuthService, 'change_password', side_effect=AuthenticationError("Invalid password")):
password_data = {
"current_password": "wrong_password",
"new_password": "NewPassword123!"
}
# Make 6 requests (limit is 5/minute)
responses = []
for i in range(6):
response = client.post("/auth/change-password", json=password_data)
responses.append(response)
# Last request should be rate limited
assert responses[-1].status_code == status.HTTP_429_TOO_MANY_REQUESTS
# Clean up override
test_app.dependency_overrides.clear()
class TestRateLimitErrorResponse:
"""Tests for rate limit error response format"""
def test_rate_limit_error_response_format(self, client):
"""Test that rate limit error has correct format"""
from app.services.auth_service import AuthService
with patch.object(AuthService, 'authenticate_user', return_value=None):
login_data = {
"email": "test@example.com",
"password": "password"
}
# Exceed rate limit
for i in range(11):
response = client.post("/auth/login", json=login_data)
# Check error response
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
assert "detail" in response.json() or "error" in response.json()