Refactor AuthService for better error handling and CRUD usage
Replaced raw database queries with CRUD operations for consistency and modularity. Enhanced error handling by adding detailed exception messages and logging for failed actions, such as authentication and registration. Updated tests to reflect new exception-based error handling approach.
This commit is contained in:
@@ -15,13 +15,18 @@ from app.core.auth import (
|
|||||||
)
|
)
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.users import Token, UserCreate
|
from app.schemas.users import Token, UserCreate
|
||||||
|
from app.crud.user import user as crud_user
|
||||||
|
from app.core.auth import decode_token, get_token_data
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AuthenticationError(Exception):
|
class AuthenticationError(Exception):
|
||||||
"""Exception raised for authentication errors"""
|
"""Raised when authentication fails"""
|
||||||
pass
|
|
||||||
|
def __init__(self, message: str = "Authentication failed"):
|
||||||
|
self.message = message
|
||||||
|
super().__init__(self.message)
|
||||||
|
|
||||||
|
|
||||||
class AuthService:
|
class AuthService:
|
||||||
@@ -40,19 +45,22 @@ class AuthService:
|
|||||||
Returns:
|
Returns:
|
||||||
User if authenticated, None otherwise
|
User if authenticated, None otherwise
|
||||||
"""
|
"""
|
||||||
user = db.query(User).filter(User.email == email).first()
|
user = crud_user.get_by_email(db, email=email)
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
return None
|
logger.warning(f"Login attempt failed: user not found for email {email}")
|
||||||
|
raise AuthenticationError("Invalid email or password")
|
||||||
|
|
||||||
if not verify_password(password, user.password_hash):
|
if not verify_password(password, user.password_hash):
|
||||||
return None
|
logger.warning(f"Login attempt failed: invalid password for user {email}")
|
||||||
|
raise AuthenticationError("Invalid email or password")
|
||||||
|
|
||||||
if not user.is_active:
|
if not crud_user.is_active(user):
|
||||||
raise AuthenticationError("User account is inactive")
|
logger.warning(f"Login attempt failed: inactive user {email}")
|
||||||
|
raise AuthenticationError("Inactive user")
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_user(db: Session, user_data: UserCreate) -> User:
|
def create_user(db: Session, user_data: UserCreate) -> User:
|
||||||
"""
|
"""
|
||||||
@@ -66,29 +74,20 @@ class AuthService:
|
|||||||
Created user
|
Created user
|
||||||
"""
|
"""
|
||||||
# Check if user already exists
|
# Check if user already exists
|
||||||
existing_user = db.query(User).filter(User.email == user_data.email).first()
|
existing_user = crud_user.get_by_email(db, email=user_data.email)
|
||||||
if existing_user:
|
if existing_user:
|
||||||
raise AuthenticationError("User with this email already exists")
|
logger.warning(f"Registration failed: email already registered {user_data.email}")
|
||||||
|
raise AuthenticationError("Email already registered")
|
||||||
|
|
||||||
# Create new user
|
try:
|
||||||
hashed_password = get_password_hash(user_data.password)
|
# Create new user using CRUD user
|
||||||
|
user = crud_user.create(db, obj_in=user_data)
|
||||||
|
logger.info(f"New user created: {user.email}")
|
||||||
|
return user
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"User creation failed: {str(e)}")
|
||||||
|
raise AuthenticationError("Could not create user")
|
||||||
|
|
||||||
# Create user object from model
|
|
||||||
user = User(
|
|
||||||
email=user_data.email,
|
|
||||||
password_hash=hashed_password,
|
|
||||||
first_name=user_data.first_name,
|
|
||||||
last_name=user_data.last_name,
|
|
||||||
phone_number=user_data.phone_number,
|
|
||||||
is_active=True,
|
|
||||||
is_superuser=False
|
|
||||||
)
|
|
||||||
|
|
||||||
db.add(user)
|
|
||||||
db.commit()
|
|
||||||
db.refresh(user)
|
|
||||||
|
|
||||||
return user
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_tokens(user: User) -> Token:
|
def create_tokens(user: User) -> Token:
|
||||||
@@ -139,7 +138,6 @@ class AuthService:
|
|||||||
TokenExpiredError: If refresh token has expired
|
TokenExpiredError: If refresh token has expired
|
||||||
TokenInvalidError: If refresh token is invalid
|
TokenInvalidError: If refresh token is invalid
|
||||||
"""
|
"""
|
||||||
from app.core.auth import decode_token, get_token_data
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Verify token is a refresh token
|
# Verify token is a refresh token
|
||||||
@@ -148,12 +146,17 @@ class AuthService:
|
|||||||
# Get user ID from token
|
# Get user ID from token
|
||||||
token_data = get_token_data(refresh_token)
|
token_data = get_token_data(refresh_token)
|
||||||
user_id = token_data.user_id
|
user_id = token_data.user_id
|
||||||
|
if not user_id:
|
||||||
|
raise AuthenticationError("Invalid token")
|
||||||
|
|
||||||
|
|
||||||
# Get user from database
|
# Get user from database
|
||||||
user = db.query(User).filter(User.id == user_id).first()
|
user: User | None = crud_user.get(db, id=UUID(str(user_id)))
|
||||||
|
|
||||||
if not user or not user.is_active:
|
if not user or not user.is_active:
|
||||||
raise TokenInvalidError("Invalid user or inactive account")
|
raise TokenInvalidError("Invalid user or inactive account")
|
||||||
|
|
||||||
|
user: User
|
||||||
# Generate new tokens
|
# Generate new tokens
|
||||||
return AuthService.create_tokens(user)
|
return AuthService.create_tokens(user)
|
||||||
|
|
||||||
@@ -178,7 +181,7 @@ class AuthService:
|
|||||||
Raises:
|
Raises:
|
||||||
AuthenticationError: If current password is incorrect
|
AuthenticationError: If current password is incorrect
|
||||||
"""
|
"""
|
||||||
user = db.query(User).filter(User.id == user_id).first()
|
user = crud_user.get(db, id=user_id)
|
||||||
if not user:
|
if not user:
|
||||||
raise AuthenticationError("User not found")
|
raise AuthenticationError("User not found")
|
||||||
|
|
||||||
|
|||||||
@@ -32,13 +32,14 @@ class TestAuthServiceAuthentication:
|
|||||||
|
|
||||||
def test_authenticate_nonexistent_user(self, db_session):
|
def test_authenticate_nonexistent_user(self, db_session):
|
||||||
"""Test authenticating with an email that doesn't exist"""
|
"""Test authenticating with an email that doesn't exist"""
|
||||||
user = AuthService.authenticate_user(
|
with pytest.raises(AuthenticationError):
|
||||||
db=db_session,
|
user = AuthService.authenticate_user(
|
||||||
email="nonexistent@example.com",
|
db=db_session,
|
||||||
password="password"
|
email="nonexistent@example.com",
|
||||||
)
|
password="password"
|
||||||
|
)
|
||||||
|
|
||||||
assert user is None
|
assert user is None
|
||||||
|
|
||||||
def test_authenticate_with_wrong_password(self, db_session, mock_user):
|
def test_authenticate_with_wrong_password(self, db_session, mock_user):
|
||||||
"""Test authenticating with the wrong password"""
|
"""Test authenticating with the wrong password"""
|
||||||
@@ -48,13 +49,14 @@ class TestAuthServiceAuthentication:
|
|||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
|
||||||
# Authenticate with wrong password
|
# Authenticate with wrong password
|
||||||
user = AuthService.authenticate_user(
|
with pytest.raises(AuthenticationError):
|
||||||
db=db_session,
|
user = AuthService.authenticate_user(
|
||||||
email=mock_user.email,
|
db=db_session,
|
||||||
password="WrongPassword123"
|
email=mock_user.email,
|
||||||
)
|
password="WrongPassword123"
|
||||||
|
)
|
||||||
|
|
||||||
assert user is None
|
assert user is None
|
||||||
|
|
||||||
def test_authenticate_inactive_user(self, db_session, mock_user):
|
def test_authenticate_inactive_user(self, db_session, mock_user):
|
||||||
"""Test authenticating an inactive user"""
|
"""Test authenticating an inactive user"""
|
||||||
|
|||||||
Reference in New Issue
Block a user