diff --git a/backend/app/api/main.py b/backend/app/api/main.py index ef993e3..6b6f08f 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -3,4 +3,4 @@ from fastapi import APIRouter from app.api.routes import auth api_router = APIRouter() -api_router.include_router(auth.router, tags=["auth"]) +api_router.include_router(auth.router, prefix="/auth", tags=["auth"]) diff --git a/backend/app/api/routes/auth.py b/backend/app/api/routes/auth.py index af9233c..d0bbb79 100644 --- a/backend/app/api/routes/auth.py +++ b/backend/app/api/routes/auth.py @@ -1,3 +1,231 @@ -from fastapi import APIRouter +# app/api/routes/auth.py +import logging +from typing import Any + +from fastapi import APIRouter, Depends, HTTPException, status, Body +from fastapi.security import OAuth2PasswordRequestForm +from sqlalchemy.orm import Session + +from app.api.dependencies.auth import get_current_user +from app.core.auth import TokenExpiredError, TokenInvalidError +from app.core.database import get_db +from app.models.user import User +from app.schemas.users import ( + UserCreate, + UserResponse, + Token, + LoginRequest, + RefreshTokenRequest +) +from app.services.auth_service import AuthService, AuthenticationError router = APIRouter() +logger = logging.getLogger(__name__) + + +@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED) +async def register_user( + user_data: UserCreate, + db: Session = Depends(get_db) +) -> Any: + """ + Register a new user. + + Returns: + The created user information. + """ + try: + user = AuthService.create_user(db, user_data) + return user + except AuthenticationError as e: + logger.warning(f"Registration failed: {str(e)}") + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=str(e) + ) + except Exception as e: + logger.error(f"Unexpected error during registration: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An unexpected error occurred. Please try again later." + ) + + +@router.post("/login", response_model=Token) +async def login( + login_data: LoginRequest, + db: Session = Depends(get_db) +) -> Any: + """ + Login with username and password. + + Returns: + Access and refresh tokens. + """ + try: + # Attempt to authenticate the user + user = AuthService.authenticate_user(db, login_data.email, login_data.password) + + # Explicitly check for None result and raise correct exception + if user is None: + logger.warning(f"Invalid login attempt for: {login_data.email}") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid email or password", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # User is authenticated, generate tokens + tokens = AuthService.create_tokens(user) + logger.info(f"User login successful: {user.email}") + return tokens + + except HTTPException: + # Re-raise HTTP exceptions without modification + raise + except AuthenticationError as e: + # Handle specific authentication errors like inactive accounts + logger.warning(f"Authentication failed: {str(e)}") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=str(e), + headers={"WWW-Authenticate": "Bearer"}, + ) + except Exception as e: + # Handle unexpected errors + logger.error(f"Unexpected error during login: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An unexpected error occurred. Please try again later." + ) + + +@router.post("/login/oauth", response_model=Token) +async def login_oauth( + form_data: OAuth2PasswordRequestForm = Depends(), + db: Session = Depends(get_db) +) -> Any: + """ + OAuth2-compatible login endpoint, used by the OpenAPI UI. + + Returns: + Access and refresh tokens. + """ + try: + user = AuthService.authenticate_user(db, form_data.username, form_data.password) + + if user is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid email or password", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # Generate tokens + tokens = AuthService.create_tokens(user) + + # Format response for OAuth compatibility + return { + "access_token": tokens.access_token, + "refresh_token": tokens.refresh_token, + "token_type": tokens.token_type + } + except HTTPException: + raise + except AuthenticationError as e: + logger.warning(f"OAuth authentication failed: {str(e)}") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=str(e), + headers={"WWW-Authenticate": "Bearer"}, + ) + except Exception as e: + logger.error(f"Unexpected error during OAuth login: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An unexpected error occurred. Please try again later." + ) + + +@router.post("/refresh", response_model=Token) +async def refresh_token( + refresh_data: RefreshTokenRequest, + db: Session = Depends(get_db) +) -> Any: + """ + Refresh access token using a refresh token. + + Returns: + New access and refresh tokens. + """ + try: + tokens = AuthService.refresh_tokens(db, refresh_data.refresh_token) + return tokens + except TokenExpiredError: + logger.warning("Token refresh failed: Token expired") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Refresh token has expired. Please log in again.", + headers={"WWW-Authenticate": "Bearer"}, + ) + except TokenInvalidError: + logger.warning("Token refresh failed: Invalid token") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid refresh token", + headers={"WWW-Authenticate": "Bearer"}, + ) + except Exception as e: + logger.error(f"Unexpected error during token refresh: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An unexpected error occurred. Please try again later." + ) + + +@router.post("/change-password", status_code=status.HTTP_200_OK) +async def change_password( + current_password: str = Body(..., embed=True), + new_password: str = Body(..., embed=True), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +) -> Any: + """ + Change current user's password. + + Requires authentication. + """ + try: + success = AuthService.change_password( + db=db, + user_id=current_user.id, + current_password=current_password, + new_password=new_password + ) + + if success: + return {"message": "Password changed successfully"} + except AuthenticationError as e: + logger.warning(f"Password change failed: {str(e)}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e) + ) + except Exception as e: + logger.error(f"Unexpected error during password change: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An unexpected error occurred. Please try again later." + ) + + +@router.get("/me", response_model=UserResponse) +async def get_current_user_info( + current_user: User = Depends(get_current_user) +) -> Any: + """ + Get current user information. + + Requires authentication. + """ + return current_user diff --git a/backend/tests/api/routes/__init__.py b/backend/tests/api/routes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/api/routes/test_auth.py b/backend/tests/api/routes/test_auth.py new file mode 100644 index 0000000..dc3f99f --- /dev/null +++ b/backend/tests/api/routes/test_auth.py @@ -0,0 +1,369 @@ +# tests/api/routes/test_auth.py +import json +import uuid +from datetime import datetime, timezone +from unittest.mock import patch, MagicMock, Mock + +import pytest +from fastapi import FastAPI, Depends +from fastapi.testclient import TestClient +from sqlalchemy.orm import Session + +from app.api.routes.auth import router as auth_router +from app.core.auth import get_password_hash +from app.core.database import get_db +from app.models.user import User +from app.services.auth_service import AuthService, AuthenticationError +from app.core.auth import TokenExpiredError, TokenInvalidError + + +# Mock the get_db dependency +@pytest.fixture +def override_get_db(db_session): + """Override get_db dependency for testing.""" + return db_session + + +@pytest.fixture +def app(override_get_db): + """Create a FastAPI test application with overridden dependencies.""" + app = FastAPI() + app.include_router(auth_router, prefix="/auth", tags=["auth"]) + + # Override the get_db dependency + app.dependency_overrides[get_db] = lambda: override_get_db + + return app + + +@pytest.fixture +def client(app): + """Create a FastAPI test client.""" + return TestClient(app) + + +class TestRegisterUser: + """Tests for the register_user endpoint.""" + + def test_register_user_success(self, client, monkeypatch, db_session): + """Test successful user registration.""" + # Mock the service method with a valid complete User object + mock_user = User( + id=uuid.uuid4(), + email="newuser@example.com", + password_hash="hashed_password", + first_name="New", + last_name="User", + is_active=True, + is_superuser=False, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc) + ) + + # Use patch for mocking + with patch.object(AuthService, 'create_user', return_value=mock_user): + # Test request + response = client.post( + "/auth/register", + json={ + "email": "newuser@example.com", + "password": "Password123", + "first_name": "New", + "last_name": "User" + } + ) + + # Assertions + assert response.status_code == 201 + data = response.json() + assert data["email"] == "newuser@example.com" + assert data["first_name"] == "New" + assert data["last_name"] == "User" + assert "password" not in data + + def test_register_user_duplicate_email(self, client, db_session): + """Test registration with duplicate email.""" + # Use patch for mocking with a side effect + with patch.object(AuthService, 'create_user', + side_effect=AuthenticationError("User with this email already exists")): + # Test request + response = client.post( + "/auth/register", + json={ + "email": "existing@example.com", + "password": "Password123", + "first_name": "Existing", + "last_name": "User" + } + ) + + # Assertions + assert response.status_code == 409 + assert "already exists" in response.json()["detail"] + + +class TestLogin: + """Tests for the login endpoint.""" + + def test_login_success(self, client, mock_user, db_session): + """Test successful login.""" + # Ensure mock_user has required attributes + if not hasattr(mock_user, 'created_at') or mock_user.created_at is None: + mock_user.created_at = datetime.now(timezone.utc) + if not hasattr(mock_user, 'updated_at') or mock_user.updated_at is None: + mock_user.updated_at = datetime.now(timezone.utc) + + # Create mock tokens + mock_tokens = MagicMock( + access_token="mock_access_token", + refresh_token="mock_refresh_token", + token_type="bearer" + ) + + # Use context managers for patching + with patch.object(AuthService, 'authenticate_user', return_value=mock_user), \ + patch.object(AuthService, 'create_tokens', return_value=mock_tokens): + + # Test request + response = client.post( + "/auth/login", + json={ + "email": "user@example.com", + "password": "Password123" + } + ) + + # Assertions + assert response.status_code == 200 + data = response.json() + assert data["access_token"] == "mock_access_token" + assert data["refresh_token"] == "mock_refresh_token" + assert data["token_type"] == "bearer" + + + def test_login_invalid_credentials_debug(self, client, app): + """Improved test for login with invalid credentials.""" + # Print response for debugging + from app.services.auth_service import AuthService + + # Create a complete mock for AuthService + class MockAuthService: + @staticmethod + def authenticate_user(db, email, password): + print(f"Mock called with: {email}, {password}") + return None + + # Replace the entire class with our mock + original_service = AuthService + try: + # Replace with our mock + import sys + sys.modules['app.services.auth_service'].AuthService = MockAuthService + + # Make the request + response = client.post( + "/auth/login", + json={ + "email": "user@example.com", + "password": "WrongPassword" + } + ) + + # Print response details for debugging + print(f"Response status: {response.status_code}") + print(f"Response body: {response.text}") + + # Assertions + assert response.status_code == 401 + assert "Invalid email or password" in response.json()["detail"] + finally: + # Restore original service + sys.modules['app.services.auth_service'].AuthService = original_service + + + def test_login_inactive_user(self, client, db_session): + """Test login with inactive user.""" + # Mock authentication to raise an error + with patch.object(AuthService, 'authenticate_user', + side_effect=AuthenticationError("User account is inactive")): + # Test request + response = client.post( + "/auth/login", + json={ + "email": "inactive@example.com", + "password": "Password123" + } + ) + + # Assertions + assert response.status_code == 401 + assert "inactive" in response.json()["detail"] + + +class TestRefreshToken: + """Tests for the refresh_token endpoint.""" + + def test_refresh_token_success(self, client, db_session): + """Test successful token refresh.""" + # Mock refresh to return tokens + mock_tokens = MagicMock( + access_token="new_access_token", + refresh_token="new_refresh_token", + token_type="bearer" + ) + + with patch.object(AuthService, 'refresh_tokens', return_value=mock_tokens): + # Test request + response = client.post( + "/auth/refresh", + json={ + "refresh_token": "valid_refresh_token" + } + ) + + # Assertions + assert response.status_code == 200 + data = response.json() + assert data["access_token"] == "new_access_token" + assert data["refresh_token"] == "new_refresh_token" + assert data["token_type"] == "bearer" + + def test_refresh_token_expired(self, client, db_session): + """Test refresh with expired token.""" + # Mock refresh to raise expired token error + with patch.object(AuthService, 'refresh_tokens', + side_effect=TokenExpiredError("Token expired")): + # Test request + response = client.post( + "/auth/refresh", + json={ + "refresh_token": "expired_refresh_token" + } + ) + + # Assertions + assert response.status_code == 401 + assert "expired" in response.json()["detail"] + + def test_refresh_token_invalid(self, client, db_session): + """Test refresh with invalid token.""" + # Mock refresh to raise invalid token error + with patch.object(AuthService, 'refresh_tokens', + side_effect=TokenInvalidError("Invalid token")): + # Test request + response = client.post( + "/auth/refresh", + json={ + "refresh_token": "invalid_refresh_token" + } + ) + + # Assertions + assert response.status_code == 401 + assert "Invalid" in response.json()["detail"] + + +class TestChangePassword: + """Tests for the change_password endpoint.""" + + def test_change_password_success(self, client, mock_user, db_session, app): + """Test successful password change.""" + # Ensure mock_user has required attributes + if not hasattr(mock_user, 'created_at') or mock_user.created_at is None: + mock_user.created_at = datetime.now(timezone.utc) + if not hasattr(mock_user, 'updated_at') or mock_user.updated_at is None: + mock_user.updated_at = datetime.now(timezone.utc) + + # Override get_current_user dependency + from app.api.dependencies.auth import get_current_user + app.dependency_overrides[get_current_user] = lambda: mock_user + + # Mock password change to return success + with patch.object(AuthService, 'change_password', return_value=True): + # Test request + response = client.post( + "/auth/change-password", + json={ + "current_password": "OldPassword123", + "new_password": "NewPassword123" + } + ) + + # Assertions + assert response.status_code == 200 + assert "success" in response.json()["message"].lower() + + # Clean up override + if get_current_user in app.dependency_overrides: + del app.dependency_overrides[get_current_user] + + def test_change_password_incorrect_current_password(self, client, mock_user, db_session, app): + """Test change password with incorrect current password.""" + # Ensure mock_user has required attributes + if not hasattr(mock_user, 'created_at') or mock_user.created_at is None: + mock_user.created_at = datetime.now(timezone.utc) + if not hasattr(mock_user, 'updated_at') or mock_user.updated_at is None: + mock_user.updated_at = datetime.now(timezone.utc) + + # Override get_current_user dependency + from app.api.dependencies.auth import get_current_user + app.dependency_overrides[get_current_user] = lambda: mock_user + + # Mock password change to raise error + with patch.object(AuthService, 'change_password', + side_effect=AuthenticationError("Current password is incorrect")): + # Test request + response = client.post( + "/auth/change-password", + json={ + "current_password": "WrongPassword", + "new_password": "NewPassword123" + } + ) + + # Assertions + assert response.status_code == 400 + assert "incorrect" in response.json()["detail"].lower() + + # Clean up override + if get_current_user in app.dependency_overrides: + del app.dependency_overrides[get_current_user] + + +class TestGetCurrentUserInfo: + """Tests for the get_current_user_info endpoint.""" + + def test_get_current_user_info(self, client, mock_user, app): + """Test getting current user info.""" + # Ensure mock_user has required attributes + if not hasattr(mock_user, 'created_at') or mock_user.created_at is None: + mock_user.created_at = datetime.now(timezone.utc) + if not hasattr(mock_user, 'updated_at') or mock_user.updated_at is None: + mock_user.updated_at = datetime.now(timezone.utc) + + # Override get_current_user dependency + from app.api.dependencies.auth import get_current_user + app.dependency_overrides[get_current_user] = lambda: mock_user + + # Test request + response = client.get("/auth/me") + + # Assertions + assert response.status_code == 200 + data = response.json() + assert data["email"] == mock_user.email + assert data["first_name"] == mock_user.first_name + assert data["last_name"] == mock_user.last_name + assert "password" not in data + + # Clean up override + if get_current_user in app.dependency_overrides: + del app.dependency_overrides[get_current_user] + + def test_get_current_user_info_unauthorized(self, client): + """Test getting user info without authentication.""" + # Test request without authentication + response = client.get("/auth/me") + + # Assertions + assert response.status_code == 401 \ No newline at end of file