Add authentication routes and tests for API
Implemented comprehensive authentication endpoints including user registration, login, token refresh, password change, and user info retrieval. Added extensive test cases for these endpoints to ensure functionality and error handling.
This commit is contained in:
@@ -3,4 +3,4 @@ from fastapi import APIRouter
|
|||||||
from app.api.routes import auth
|
from app.api.routes import auth
|
||||||
|
|
||||||
api_router = APIRouter()
|
api_router = APIRouter()
|
||||||
api_router.include_router(auth.router, tags=["auth"])
|
api_router.include_router(auth.router, prefix="/auth", tags=["auth"])
|
||||||
|
|||||||
@@ -1,3 +1,231 @@
|
|||||||
from fastapi import APIRouter
|
# app/api/routes/auth.py
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status, Body
|
||||||
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.api.dependencies.auth import get_current_user
|
||||||
|
from app.core.auth import TokenExpiredError, TokenInvalidError
|
||||||
|
from app.core.database import get_db
|
||||||
|
from app.models.user import User
|
||||||
|
from app.schemas.users import (
|
||||||
|
UserCreate,
|
||||||
|
UserResponse,
|
||||||
|
Token,
|
||||||
|
LoginRequest,
|
||||||
|
RefreshTokenRequest
|
||||||
|
)
|
||||||
|
from app.services.auth_service import AuthService, AuthenticationError
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def register_user(
|
||||||
|
user_data: UserCreate,
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Register a new user.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created user information.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
user = AuthService.create_user(db, user_data)
|
||||||
|
return user
|
||||||
|
except AuthenticationError as e:
|
||||||
|
logger.warning(f"Registration failed: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail=str(e)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error during registration: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="An unexpected error occurred. Please try again later."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/login", response_model=Token)
|
||||||
|
async def login(
|
||||||
|
login_data: LoginRequest,
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Login with username and password.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Access and refresh tokens.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Attempt to authenticate the user
|
||||||
|
user = AuthService.authenticate_user(db, login_data.email, login_data.password)
|
||||||
|
|
||||||
|
# Explicitly check for None result and raise correct exception
|
||||||
|
if user is None:
|
||||||
|
logger.warning(f"Invalid login attempt for: {login_data.email}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid email or password",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# User is authenticated, generate tokens
|
||||||
|
tokens = AuthService.create_tokens(user)
|
||||||
|
logger.info(f"User login successful: {user.email}")
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
# Re-raise HTTP exceptions without modification
|
||||||
|
raise
|
||||||
|
except AuthenticationError as e:
|
||||||
|
# Handle specific authentication errors like inactive accounts
|
||||||
|
logger.warning(f"Authentication failed: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail=str(e),
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# Handle unexpected errors
|
||||||
|
logger.error(f"Unexpected error during login: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="An unexpected error occurred. Please try again later."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/login/oauth", response_model=Token)
|
||||||
|
async def login_oauth(
|
||||||
|
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
OAuth2-compatible login endpoint, used by the OpenAPI UI.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Access and refresh tokens.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
user = AuthService.authenticate_user(db, form_data.username, form_data.password)
|
||||||
|
|
||||||
|
if user is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid email or password",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate tokens
|
||||||
|
tokens = AuthService.create_tokens(user)
|
||||||
|
|
||||||
|
# Format response for OAuth compatibility
|
||||||
|
return {
|
||||||
|
"access_token": tokens.access_token,
|
||||||
|
"refresh_token": tokens.refresh_token,
|
||||||
|
"token_type": tokens.token_type
|
||||||
|
}
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except AuthenticationError as e:
|
||||||
|
logger.warning(f"OAuth authentication failed: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail=str(e),
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error during OAuth login: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="An unexpected error occurred. Please try again later."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/refresh", response_model=Token)
|
||||||
|
async def refresh_token(
|
||||||
|
refresh_data: RefreshTokenRequest,
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Refresh access token using a refresh token.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
New access and refresh tokens.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
tokens = AuthService.refresh_tokens(db, refresh_data.refresh_token)
|
||||||
|
return tokens
|
||||||
|
except TokenExpiredError:
|
||||||
|
logger.warning("Token refresh failed: Token expired")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Refresh token has expired. Please log in again.",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
except TokenInvalidError:
|
||||||
|
logger.warning("Token refresh failed: Invalid token")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid refresh token",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error during token refresh: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="An unexpected error occurred. Please try again later."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/change-password", status_code=status.HTTP_200_OK)
|
||||||
|
async def change_password(
|
||||||
|
current_password: str = Body(..., embed=True),
|
||||||
|
new_password: str = Body(..., embed=True),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Change current user's password.
|
||||||
|
|
||||||
|
Requires authentication.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
success = AuthService.change_password(
|
||||||
|
db=db,
|
||||||
|
user_id=current_user.id,
|
||||||
|
current_password=current_password,
|
||||||
|
new_password=new_password
|
||||||
|
)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
return {"message": "Password changed successfully"}
|
||||||
|
except AuthenticationError as e:
|
||||||
|
logger.warning(f"Password change failed: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=str(e)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error during password change: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="An unexpected error occurred. Please try again later."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me", response_model=UserResponse)
|
||||||
|
async def get_current_user_info(
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Get current user information.
|
||||||
|
|
||||||
|
Requires authentication.
|
||||||
|
"""
|
||||||
|
return current_user
|
||||||
|
|||||||
0
backend/tests/api/routes/__init__.py
Normal file
0
backend/tests/api/routes/__init__.py
Normal file
369
backend/tests/api/routes/test_auth.py
Normal file
369
backend/tests/api/routes/test_auth.py
Normal file
@@ -0,0 +1,369 @@
|
|||||||
|
# tests/api/routes/test_auth.py
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import patch, MagicMock, Mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI, Depends
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.api.routes.auth import router as auth_router
|
||||||
|
from app.core.auth import get_password_hash
|
||||||
|
from app.core.database import get_db
|
||||||
|
from app.models.user import User
|
||||||
|
from app.services.auth_service import AuthService, AuthenticationError
|
||||||
|
from app.core.auth import TokenExpiredError, TokenInvalidError
|
||||||
|
|
||||||
|
|
||||||
|
# Mock the get_db dependency
|
||||||
|
@pytest.fixture
|
||||||
|
def override_get_db(db_session):
|
||||||
|
"""Override get_db dependency for testing."""
|
||||||
|
return db_session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app(override_get_db):
|
||||||
|
"""Create a FastAPI test application with overridden dependencies."""
|
||||||
|
app = FastAPI()
|
||||||
|
app.include_router(auth_router, prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
|
# Override the get_db dependency
|
||||||
|
app.dependency_overrides[get_db] = lambda: override_get_db
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(app):
|
||||||
|
"""Create a FastAPI test client."""
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRegisterUser:
|
||||||
|
"""Tests for the register_user endpoint."""
|
||||||
|
|
||||||
|
def test_register_user_success(self, client, monkeypatch, db_session):
|
||||||
|
"""Test successful user registration."""
|
||||||
|
# Mock the service method with a valid complete User object
|
||||||
|
mock_user = User(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
email="newuser@example.com",
|
||||||
|
password_hash="hashed_password",
|
||||||
|
first_name="New",
|
||||||
|
last_name="User",
|
||||||
|
is_active=True,
|
||||||
|
is_superuser=False,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
updated_at=datetime.now(timezone.utc)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use patch for mocking
|
||||||
|
with patch.object(AuthService, 'create_user', return_value=mock_user):
|
||||||
|
# Test request
|
||||||
|
response = client.post(
|
||||||
|
"/auth/register",
|
||||||
|
json={
|
||||||
|
"email": "newuser@example.com",
|
||||||
|
"password": "Password123",
|
||||||
|
"first_name": "New",
|
||||||
|
"last_name": "User"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
assert response.status_code == 201
|
||||||
|
data = response.json()
|
||||||
|
assert data["email"] == "newuser@example.com"
|
||||||
|
assert data["first_name"] == "New"
|
||||||
|
assert data["last_name"] == "User"
|
||||||
|
assert "password" not in data
|
||||||
|
|
||||||
|
def test_register_user_duplicate_email(self, client, db_session):
|
||||||
|
"""Test registration with duplicate email."""
|
||||||
|
# Use patch for mocking with a side effect
|
||||||
|
with patch.object(AuthService, 'create_user',
|
||||||
|
side_effect=AuthenticationError("User with this email already exists")):
|
||||||
|
# Test request
|
||||||
|
response = client.post(
|
||||||
|
"/auth/register",
|
||||||
|
json={
|
||||||
|
"email": "existing@example.com",
|
||||||
|
"password": "Password123",
|
||||||
|
"first_name": "Existing",
|
||||||
|
"last_name": "User"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
assert response.status_code == 409
|
||||||
|
assert "already exists" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestLogin:
|
||||||
|
"""Tests for the login endpoint."""
|
||||||
|
|
||||||
|
def test_login_success(self, client, mock_user, db_session):
|
||||||
|
"""Test successful login."""
|
||||||
|
# Ensure mock_user has required attributes
|
||||||
|
if not hasattr(mock_user, 'created_at') or mock_user.created_at is None:
|
||||||
|
mock_user.created_at = datetime.now(timezone.utc)
|
||||||
|
if not hasattr(mock_user, 'updated_at') or mock_user.updated_at is None:
|
||||||
|
mock_user.updated_at = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
# Create mock tokens
|
||||||
|
mock_tokens = MagicMock(
|
||||||
|
access_token="mock_access_token",
|
||||||
|
refresh_token="mock_refresh_token",
|
||||||
|
token_type="bearer"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use context managers for patching
|
||||||
|
with patch.object(AuthService, 'authenticate_user', return_value=mock_user), \
|
||||||
|
patch.object(AuthService, 'create_tokens', return_value=mock_tokens):
|
||||||
|
|
||||||
|
# Test request
|
||||||
|
response = client.post(
|
||||||
|
"/auth/login",
|
||||||
|
json={
|
||||||
|
"email": "user@example.com",
|
||||||
|
"password": "Password123"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["access_token"] == "mock_access_token"
|
||||||
|
assert data["refresh_token"] == "mock_refresh_token"
|
||||||
|
assert data["token_type"] == "bearer"
|
||||||
|
|
||||||
|
|
||||||
|
def test_login_invalid_credentials_debug(self, client, app):
|
||||||
|
"""Improved test for login with invalid credentials."""
|
||||||
|
# Print response for debugging
|
||||||
|
from app.services.auth_service import AuthService
|
||||||
|
|
||||||
|
# Create a complete mock for AuthService
|
||||||
|
class MockAuthService:
|
||||||
|
@staticmethod
|
||||||
|
def authenticate_user(db, email, password):
|
||||||
|
print(f"Mock called with: {email}, {password}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Replace the entire class with our mock
|
||||||
|
original_service = AuthService
|
||||||
|
try:
|
||||||
|
# Replace with our mock
|
||||||
|
import sys
|
||||||
|
sys.modules['app.services.auth_service'].AuthService = MockAuthService
|
||||||
|
|
||||||
|
# Make the request
|
||||||
|
response = client.post(
|
||||||
|
"/auth/login",
|
||||||
|
json={
|
||||||
|
"email": "user@example.com",
|
||||||
|
"password": "WrongPassword"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Print response details for debugging
|
||||||
|
print(f"Response status: {response.status_code}")
|
||||||
|
print(f"Response body: {response.text}")
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "Invalid email or password" in response.json()["detail"]
|
||||||
|
finally:
|
||||||
|
# Restore original service
|
||||||
|
sys.modules['app.services.auth_service'].AuthService = original_service
|
||||||
|
|
||||||
|
|
||||||
|
def test_login_inactive_user(self, client, db_session):
|
||||||
|
"""Test login with inactive user."""
|
||||||
|
# Mock authentication to raise an error
|
||||||
|
with patch.object(AuthService, 'authenticate_user',
|
||||||
|
side_effect=AuthenticationError("User account is inactive")):
|
||||||
|
# Test request
|
||||||
|
response = client.post(
|
||||||
|
"/auth/login",
|
||||||
|
json={
|
||||||
|
"email": "inactive@example.com",
|
||||||
|
"password": "Password123"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "inactive" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestRefreshToken:
|
||||||
|
"""Tests for the refresh_token endpoint."""
|
||||||
|
|
||||||
|
def test_refresh_token_success(self, client, db_session):
|
||||||
|
"""Test successful token refresh."""
|
||||||
|
# Mock refresh to return tokens
|
||||||
|
mock_tokens = MagicMock(
|
||||||
|
access_token="new_access_token",
|
||||||
|
refresh_token="new_refresh_token",
|
||||||
|
token_type="bearer"
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(AuthService, 'refresh_tokens', return_value=mock_tokens):
|
||||||
|
# Test request
|
||||||
|
response = client.post(
|
||||||
|
"/auth/refresh",
|
||||||
|
json={
|
||||||
|
"refresh_token": "valid_refresh_token"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["access_token"] == "new_access_token"
|
||||||
|
assert data["refresh_token"] == "new_refresh_token"
|
||||||
|
assert data["token_type"] == "bearer"
|
||||||
|
|
||||||
|
def test_refresh_token_expired(self, client, db_session):
|
||||||
|
"""Test refresh with expired token."""
|
||||||
|
# Mock refresh to raise expired token error
|
||||||
|
with patch.object(AuthService, 'refresh_tokens',
|
||||||
|
side_effect=TokenExpiredError("Token expired")):
|
||||||
|
# Test request
|
||||||
|
response = client.post(
|
||||||
|
"/auth/refresh",
|
||||||
|
json={
|
||||||
|
"refresh_token": "expired_refresh_token"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "expired" in response.json()["detail"]
|
||||||
|
|
||||||
|
def test_refresh_token_invalid(self, client, db_session):
|
||||||
|
"""Test refresh with invalid token."""
|
||||||
|
# Mock refresh to raise invalid token error
|
||||||
|
with patch.object(AuthService, 'refresh_tokens',
|
||||||
|
side_effect=TokenInvalidError("Invalid token")):
|
||||||
|
# Test request
|
||||||
|
response = client.post(
|
||||||
|
"/auth/refresh",
|
||||||
|
json={
|
||||||
|
"refresh_token": "invalid_refresh_token"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "Invalid" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestChangePassword:
|
||||||
|
"""Tests for the change_password endpoint."""
|
||||||
|
|
||||||
|
def test_change_password_success(self, client, mock_user, db_session, app):
|
||||||
|
"""Test successful password change."""
|
||||||
|
# Ensure mock_user has required attributes
|
||||||
|
if not hasattr(mock_user, 'created_at') or mock_user.created_at is None:
|
||||||
|
mock_user.created_at = datetime.now(timezone.utc)
|
||||||
|
if not hasattr(mock_user, 'updated_at') or mock_user.updated_at is None:
|
||||||
|
mock_user.updated_at = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
# Override get_current_user dependency
|
||||||
|
from app.api.dependencies.auth import get_current_user
|
||||||
|
app.dependency_overrides[get_current_user] = lambda: mock_user
|
||||||
|
|
||||||
|
# Mock password change to return success
|
||||||
|
with patch.object(AuthService, 'change_password', return_value=True):
|
||||||
|
# Test request
|
||||||
|
response = client.post(
|
||||||
|
"/auth/change-password",
|
||||||
|
json={
|
||||||
|
"current_password": "OldPassword123",
|
||||||
|
"new_password": "NewPassword123"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert "success" in response.json()["message"].lower()
|
||||||
|
|
||||||
|
# Clean up override
|
||||||
|
if get_current_user in app.dependency_overrides:
|
||||||
|
del app.dependency_overrides[get_current_user]
|
||||||
|
|
||||||
|
def test_change_password_incorrect_current_password(self, client, mock_user, db_session, app):
|
||||||
|
"""Test change password with incorrect current password."""
|
||||||
|
# Ensure mock_user has required attributes
|
||||||
|
if not hasattr(mock_user, 'created_at') or mock_user.created_at is None:
|
||||||
|
mock_user.created_at = datetime.now(timezone.utc)
|
||||||
|
if not hasattr(mock_user, 'updated_at') or mock_user.updated_at is None:
|
||||||
|
mock_user.updated_at = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
# Override get_current_user dependency
|
||||||
|
from app.api.dependencies.auth import get_current_user
|
||||||
|
app.dependency_overrides[get_current_user] = lambda: mock_user
|
||||||
|
|
||||||
|
# Mock password change to raise error
|
||||||
|
with patch.object(AuthService, 'change_password',
|
||||||
|
side_effect=AuthenticationError("Current password is incorrect")):
|
||||||
|
# Test request
|
||||||
|
response = client.post(
|
||||||
|
"/auth/change-password",
|
||||||
|
json={
|
||||||
|
"current_password": "WrongPassword",
|
||||||
|
"new_password": "NewPassword123"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert "incorrect" in response.json()["detail"].lower()
|
||||||
|
|
||||||
|
# Clean up override
|
||||||
|
if get_current_user in app.dependency_overrides:
|
||||||
|
del app.dependency_overrides[get_current_user]
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetCurrentUserInfo:
|
||||||
|
"""Tests for the get_current_user_info endpoint."""
|
||||||
|
|
||||||
|
def test_get_current_user_info(self, client, mock_user, app):
|
||||||
|
"""Test getting current user info."""
|
||||||
|
# Ensure mock_user has required attributes
|
||||||
|
if not hasattr(mock_user, 'created_at') or mock_user.created_at is None:
|
||||||
|
mock_user.created_at = datetime.now(timezone.utc)
|
||||||
|
if not hasattr(mock_user, 'updated_at') or mock_user.updated_at is None:
|
||||||
|
mock_user.updated_at = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
# Override get_current_user dependency
|
||||||
|
from app.api.dependencies.auth import get_current_user
|
||||||
|
app.dependency_overrides[get_current_user] = lambda: mock_user
|
||||||
|
|
||||||
|
# Test request
|
||||||
|
response = client.get("/auth/me")
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["email"] == mock_user.email
|
||||||
|
assert data["first_name"] == mock_user.first_name
|
||||||
|
assert data["last_name"] == mock_user.last_name
|
||||||
|
assert "password" not in data
|
||||||
|
|
||||||
|
# Clean up override
|
||||||
|
if get_current_user in app.dependency_overrides:
|
||||||
|
del app.dependency_overrides[get_current_user]
|
||||||
|
|
||||||
|
def test_get_current_user_info_unauthorized(self, client):
|
||||||
|
"""Test getting user info without authentication."""
|
||||||
|
# Test request without authentication
|
||||||
|
response = client.get("/auth/me")
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
assert response.status_code == 401
|
||||||
Reference in New Issue
Block a user