diff --git a/backend/app/api/routes/auth.py b/backend/app/api/routes/auth.py index e387081..d3b4ba6 100644 --- a/backend/app/api/routes/auth.py +++ b/backend/app/api/routes/auth.py @@ -2,8 +2,10 @@ import logging from typing import Any -from fastapi import APIRouter, Depends, HTTPException, status, Body +from fastapi import APIRouter, Depends, HTTPException, status, Body, Request from fastapi.security import OAuth2PasswordRequestForm +from slowapi import Limiter +from slowapi.util import get_remote_address from sqlalchemy.orm import Session from app.api.dependencies.auth import get_current_user @@ -22,9 +24,14 @@ from app.services.auth_service import AuthService, AuthenticationError router = APIRouter() logger = logging.getLogger(__name__) +# Initialize limiter for this router +limiter = Limiter(key_func=get_remote_address) + @router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED, operation_id="register") +@limiter.limit("5/minute") async def register_user( + request: Request, user_data: UserCreate, db: Session = Depends(get_db) ) -> Any: @@ -52,7 +59,9 @@ async def register_user( @router.post("/login", response_model=Token, operation_id="login") +@limiter.limit("10/minute") async def login( + request: Request, login_data: LoginRequest, db: Session = Depends(get_db) ) -> Any: @@ -101,7 +110,9 @@ async def login( @router.post("/login/oauth", response_model=Token, operation_id='login_oauth') +@limiter.limit("10/minute") async def login_oauth( + request: Request, form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db) ) -> Any: @@ -148,7 +159,9 @@ async def login_oauth( @router.post("/refresh", response_model=Token, operation_id="refresh_token") +@limiter.limit("30/minute") async def refresh_token( + request: Request, refresh_data: RefreshTokenRequest, db: Session = Depends(get_db) ) -> Any: @@ -184,7 +197,9 @@ async def refresh_token( @router.post("/change-password", status_code=status.HTTP_200_OK, operation_id="change_password") +@limiter.limit("5/minute") async def change_password( + request: Request, current_password: str = Body(..., embed=True), new_password: str = Body(..., embed=True), current_user: User = Depends(get_current_user), @@ -220,7 +235,9 @@ async def change_password( @router.get("/me", response_model=UserResponse, operation_id="get_current_user_info") +@limiter.limit("60/minute") async def get_current_user_info( + request: Request, current_user: User = Depends(get_current_user) ) -> Any: """ diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 261939f..06bb210 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -22,7 +22,6 @@ class Settings(BaseSettings): POSTGRES_PORT: str = "5432" POSTGRES_DB: str = "app" DATABASE_URL: Optional[str] = None - REFRESH_TOKEN_EXPIRE_DAYS: int = 60 db_pool_size: int = 20 # Default connection pool size db_max_overflow: int = 50 # Maximum overflow connections db_pool_timeout: int = 30 # Seconds to wait for a connection @@ -48,7 +47,7 @@ class Settings(BaseSettings): # JWT configuration SECRET_KEY: str = Field( - default="your_secret_key_here", + default="dev_only_insecure_key_change_in_production_32chars_min", min_length=32, description="JWT signing key. MUST be changed in production. Generate with: python -c 'import secrets; print(secrets.token_urlsafe(32))'" ) diff --git a/backend/app/main.py b/backend/app/main.py index c1ad77f..92f75e8 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -1,17 +1,27 @@ import logging +from datetime import datetime +from typing import Dict, Any from apscheduler.schedulers.asyncio import AsyncIOScheduler -from fastapi import FastAPI +from fastapi import FastAPI, status, Request from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import HTMLResponse +from fastapi.responses import HTMLResponse, JSONResponse +from slowapi import Limiter, _rate_limit_exceeded_handler +from slowapi.util import get_remote_address +from slowapi.errors import RateLimitExceeded +from sqlalchemy import text from app.api.main import api_router from app.core.config import settings +from app.core.database import get_db scheduler = AsyncIOScheduler() logger = logging.getLogger(__name__) +# Initialize rate limiter +limiter = Limiter(key_func=get_remote_address) + logger.info(f"Starting app!!!") app = FastAPI( title=settings.PROJECT_NAME, @@ -19,6 +29,10 @@ app = FastAPI( openapi_url=f"{settings.API_V1_STR}/openapi.json" ) +# Add rate limiter state to app +app.state.limiter = limiter +app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) + # Set up CORS middleware app.add_middleware( CORSMiddleware, @@ -45,4 +59,58 @@ async def root(): """ +@app.get( + "/health", + summary="Health Check", + description="Check the health status of the API and its dependencies", + response_description="Health status information", + tags=["Health"], + operation_id="health_check" +) +async def health_check() -> JSONResponse: + """ + Health check endpoint for monitoring and load balancers. + + Returns: + JSONResponse: Health status with the following information: + - status: Overall health status ("healthy" or "unhealthy") + - timestamp: Current server timestamp (ISO 8601 format) + - version: API version + - environment: Current environment (development, staging, production) + - database: Database connectivity status + """ + health_status: Dict[str, Any] = { + "status": "healthy", + "timestamp": datetime.utcnow().isoformat() + "Z", + "version": settings.VERSION, + "environment": settings.ENVIRONMENT, + "checks": {} + } + + response_status = status.HTTP_200_OK + + # Database health check + try: + db = next(get_db()) + db.execute(text("SELECT 1")) + health_status["checks"]["database"] = { + "status": "healthy", + "message": "Database connection successful" + } + db.close() + except Exception as e: + health_status["status"] = "unhealthy" + health_status["checks"]["database"] = { + "status": "unhealthy", + "message": f"Database connection failed: {str(e)}" + } + response_status = status.HTTP_503_SERVICE_UNAVAILABLE + logger.error(f"Health check failed - database error: {e}") + + return JSONResponse( + status_code=response_status, + content=health_status + ) + + app.include_router(api_router, prefix=settings.API_V1_STR) diff --git a/backend/requirements.txt b/backend/requirements.txt index ecc4682..cddd75b 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -12,10 +12,8 @@ alembic>=1.14.1 psycopg2-binary>=2.9.9 asyncpg>=0.29.0 aiosqlite==0.21.0 -# Security and authentication -python-jose>=3.4.0 -passlib>=1.7.4 -bcrypt>=4.1.2 + +# Environment configuration python-dotenv>=1.0.1 # API documentation @@ -26,6 +24,9 @@ ujson>=5.9.0 starlette>=0.40.0 starlette-csrf>=1.4.5 +# Rate limiting +slowapi>=0.1.9 + # Utilities httpx>=0.27.0 tenacity>=8.2.3 @@ -44,9 +45,11 @@ isort>=5.13.2 flake8>=7.0.0 mypy>=1.8.0 -# Security +# Security and authentication (pinned for reproducibility) python-jose==3.4.0 +passlib==1.7.4 bcrypt==4.2.1 cryptography==44.0.1 -passlib==1.7.4 + +# Testing utilities freezegun~=1.5.1 \ No newline at end of file diff --git a/backend/tests/api/routes/test_health.py b/backend/tests/api/routes/test_health.py new file mode 100644 index 0000000..d2dd18e --- /dev/null +++ b/backend/tests/api/routes/test_health.py @@ -0,0 +1,200 @@ +# tests/api/routes/test_health.py +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +from fastapi import status +from fastapi.testclient import TestClient +from datetime import datetime +from sqlalchemy.exc import OperationalError + +from app.main import app +from app.core.database import get_db + + +@pytest.fixture +def client(): + """Create a FastAPI test client for the main app with mocked database.""" + # Mock get_db to avoid connecting to the actual database + with patch("app.main.get_db") as mock_get_db: + def mock_session_generator(): + mock_session = MagicMock() + # Mock the execute method to return successfully + mock_session.execute.return_value = None + mock_session.close.return_value = None + yield mock_session + + # Return a new generator each time get_db is called + mock_get_db.side_effect = lambda: mock_session_generator() + yield TestClient(app) + + +class TestHealthEndpoint: + """Tests for the /health endpoint""" + + def test_health_check_healthy(self, client): + """Test that health check returns healthy when database is accessible""" + response = client.get("/health") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + + # Check required fields + assert "status" in data + assert data["status"] == "healthy" + assert "timestamp" in data + assert "version" in data + assert "environment" in data + assert "checks" in data + + # Verify timestamp format (ISO 8601) + assert data["timestamp"].endswith("Z") + # Verify it's a valid datetime + datetime.fromisoformat(data["timestamp"].replace("Z", "+00:00")) + + # Check database health + assert "database" in data["checks"] + assert data["checks"]["database"]["status"] == "healthy" + assert "message" in data["checks"]["database"] + + def test_health_check_response_structure(self, client): + """Test that health check response has correct structure""" + response = client.get("/health") + data = response.json() + + # Verify top-level structure + assert isinstance(data["status"], str) + assert isinstance(data["timestamp"], str) + assert isinstance(data["version"], str) + assert isinstance(data["environment"], str) + assert isinstance(data["checks"], dict) + + # Verify database check structure + db_check = data["checks"]["database"] + assert isinstance(db_check["status"], str) + assert isinstance(db_check["message"], str) + + def test_health_check_version_matches_settings(self, client): + """Test that health check returns correct version from settings""" + from app.core.config import settings + + response = client.get("/health") + data = response.json() + + assert data["version"] == settings.VERSION + + def test_health_check_environment_matches_settings(self, client): + """Test that health check returns correct environment from settings""" + from app.core.config import settings + + response = client.get("/health") + data = response.json() + + assert data["environment"] == settings.ENVIRONMENT + + def test_health_check_database_connection_failure(self, client): + """Test that health check returns unhealthy when database is not accessible""" + # Mock the database session to raise an exception + with patch("app.main.get_db") as mock_get_db: + def mock_session(): + from unittest.mock import MagicMock + mock = MagicMock() + mock.execute.side_effect = OperationalError( + "Connection refused", + params=None, + orig=Exception("Connection refused") + ) + yield mock + + mock_get_db.return_value = mock_session() + + response = client.get("/health") + + assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE + data = response.json() + + # Check overall status + assert data["status"] == "unhealthy" + + # Check database status + assert "database" in data["checks"] + assert data["checks"]["database"]["status"] == "unhealthy" + assert "failed" in data["checks"]["database"]["message"].lower() + + def test_health_check_timestamp_recent(self, client): + """Test that health check timestamp is recent (within last minute)""" + before = datetime.utcnow() + response = client.get("/health") + after = datetime.utcnow() + + data = response.json() + timestamp = datetime.fromisoformat(data["timestamp"].replace("Z", "+00:00")) + + # Timestamp should be between before and after + assert before <= timestamp.replace(tzinfo=None) <= after + + def test_health_check_no_authentication_required(self, client): + """Test that health check does not require authentication""" + # Make request without any authentication headers + response = client.get("/health") + + # Should succeed without authentication + assert response.status_code in [status.HTTP_200_OK, status.HTTP_503_SERVICE_UNAVAILABLE] + + def test_health_check_idempotent(self, client): + """Test that multiple health checks return consistent results""" + response1 = client.get("/health") + response2 = client.get("/health") + + # Both should have same status code (either both healthy or both unhealthy) + assert response1.status_code == response2.status_code + + data1 = response1.json() + data2 = response2.json() + + # Same overall health status + assert data1["status"] == data2["status"] + + # Same version and environment + assert data1["version"] == data2["version"] + assert data1["environment"] == data2["environment"] + + # Same database check status + assert data1["checks"]["database"]["status"] == data2["checks"]["database"]["status"] + + def test_health_check_content_type(self, client): + """Test that health check returns JSON content type""" + response = client.get("/health") + + assert "application/json" in response.headers["content-type"] + + +class TestHealthEndpointEdgeCases: + """Edge case tests for the /health endpoint""" + + def test_health_check_with_query_parameters(self, client): + """Test that health check ignores query parameters""" + response = client.get("/health?foo=bar&baz=qux") + + # Should still work with query params + assert response.status_code == status.HTTP_200_OK + + def test_health_check_method_not_allowed(self, client): + """Test that POST/PUT/DELETE are not allowed on health endpoint""" + # POST should not be allowed + response = client.post("/health") + assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED + + # PUT should not be allowed + response = client.put("/health") + assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED + + # DELETE should not be allowed + response = client.delete("/health") + assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED + + def test_health_check_with_accept_header(self, client): + """Test that health check works with different Accept headers""" + response = client.get("/health", headers={"Accept": "application/json"}) + assert response.status_code == status.HTTP_200_OK + + response = client.get("/health", headers={"Accept": "*/*"}) + assert response.status_code == status.HTTP_200_OK diff --git a/backend/tests/api/routes/test_rate_limiting.py b/backend/tests/api/routes/test_rate_limiting.py new file mode 100644 index 0000000..97ad318 --- /dev/null +++ b/backend/tests/api/routes/test_rate_limiting.py @@ -0,0 +1,194 @@ +# tests/api/routes/test_rate_limiting.py +import pytest +from fastapi import FastAPI, status +from fastapi.testclient import TestClient +from unittest.mock import patch, MagicMock + +from app.api.routes.auth import router as auth_router, limiter +from app.core.database import get_db + + +# Mock the get_db dependency +@pytest.fixture +def override_get_db(): + """Override get_db dependency for testing.""" + mock_db = MagicMock() + return mock_db + + +@pytest.fixture +def app(override_get_db): + """Create a FastAPI test application with rate limiting.""" + from slowapi import _rate_limit_exceeded_handler + from slowapi.errors import RateLimitExceeded + + app = FastAPI() + app.state.limiter = limiter + app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) + app.include_router(auth_router, prefix="/auth", tags=["auth"]) + + # Override the get_db dependency + app.dependency_overrides[get_db] = lambda: override_get_db + + return app + + +@pytest.fixture +def client(app): + """Create a FastAPI test client.""" + return TestClient(app) + + +class TestRegisterRateLimiting: + """Tests for rate limiting on /register endpoint""" + + def test_register_rate_limit_blocks_over_limit(self, client): + """Test that requests over rate limit are blocked""" + from app.services.auth_service import AuthService + from app.models.user import User + from datetime import datetime, timezone + import uuid + + mock_user = User( + id=uuid.uuid4(), + email="test@example.com", + password_hash="hashed", + first_name="Test", + last_name="User", + is_active=True, + is_superuser=False, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc) + ) + + with patch.object(AuthService, 'create_user', return_value=mock_user): + user_data = { + "email": f"test{uuid.uuid4()}@example.com", + "password": "TestPassword123!", + "first_name": "Test", + "last_name": "User" + } + + # Make 6 requests (limit is 5/minute) + responses = [] + for i in range(6): + response = client.post("/auth/register", json=user_data) + responses.append(response) + + # Last request should be rate limited + assert responses[-1].status_code == status.HTTP_429_TOO_MANY_REQUESTS + + +class TestLoginRateLimiting: + """Tests for rate limiting on /login endpoint""" + + def test_login_rate_limit_blocks_over_limit(self, client): + """Test that login requests over rate limit are blocked""" + from app.services.auth_service import AuthService + + with patch.object(AuthService, 'authenticate_user', return_value=None): + login_data = { + "email": "test@example.com", + "password": "wrong_password" + } + + # Make 11 requests (limit is 10/minute) + responses = [] + for i in range(11): + response = client.post("/auth/login", json=login_data) + responses.append(response) + + # Last request should be rate limited + assert responses[-1].status_code == status.HTTP_429_TOO_MANY_REQUESTS + + +class TestRefreshTokenRateLimiting: + """Tests for rate limiting on /refresh endpoint""" + + def test_refresh_rate_limit_blocks_over_limit(self, client): + """Test that refresh requests over rate limit are blocked""" + from app.services.auth_service import AuthService + from app.core.auth import TokenInvalidError + + with patch.object(AuthService, 'refresh_tokens', side_effect=TokenInvalidError("Invalid")): + refresh_data = { + "refresh_token": "invalid_token" + } + + # Make 31 requests (limit is 30/minute) + responses = [] + for i in range(31): + response = client.post("/auth/refresh", json=refresh_data) + responses.append(response) + + # Last request should be rate limited + assert responses[-1].status_code == status.HTTP_429_TOO_MANY_REQUESTS + + +class TestChangePasswordRateLimiting: + """Tests for rate limiting on /change-password endpoint""" + + def test_change_password_rate_limit_blocks_over_limit(self, client): + """Test that change password requests over rate limit are blocked""" + from app.api.dependencies.auth import get_current_user + from app.models.user import User + from app.services.auth_service import AuthService, AuthenticationError + from datetime import datetime, timezone + import uuid + + # Mock current user + mock_user = User( + id=uuid.uuid4(), + email="test@example.com", + password_hash="hashed", + first_name="Test", + last_name="User", + is_active=True, + is_superuser=False, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc) + ) + + # Override get_current_user dependency in the app + test_app = client.app + test_app.dependency_overrides[get_current_user] = lambda: mock_user + + with patch.object(AuthService, 'change_password', side_effect=AuthenticationError("Invalid password")): + password_data = { + "current_password": "wrong_password", + "new_password": "NewPassword123!" + } + + # Make 6 requests (limit is 5/minute) + responses = [] + for i in range(6): + response = client.post("/auth/change-password", json=password_data) + responses.append(response) + + # Last request should be rate limited + assert responses[-1].status_code == status.HTTP_429_TOO_MANY_REQUESTS + + # Clean up override + test_app.dependency_overrides.clear() + + +class TestRateLimitErrorResponse: + """Tests for rate limit error response format""" + + def test_rate_limit_error_response_format(self, client): + """Test that rate limit error has correct format""" + from app.services.auth_service import AuthService + + with patch.object(AuthService, 'authenticate_user', return_value=None): + login_data = { + "email": "test@example.com", + "password": "password" + } + + # Exceed rate limit + for i in range(11): + response = client.post("/auth/login", json=login_data) + + # Check error response + assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS + assert "detail" in response.json() or "error" in response.json()