Add foundational user authentication and registration system

Introduces schemas for user management, token handling, and password hashing. Implements routes for user registration, login, token refresh, and user info retrieval. Sets up authentication dependencies and integrates the API router with the application.
This commit is contained in:
2025-02-28 16:18:03 +01:00
parent 290d91d395
commit 43df9d73b0
11 changed files with 467 additions and 1 deletions

View File

6
backend/app/api/main.py Normal file
View File

@@ -0,0 +1,6 @@
from fastapi import APIRouter
from app.api.routes import auth
api_router = APIRouter()
api_router.include_router(auth.router, tags=["auth"])

View File

View File

@@ -0,0 +1,182 @@
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError
from sqlalchemy.orm import Session
from app.auth.security import (
verify_password,
get_password_hash,
create_tokens,
decode_token,
)
from app.core.database import get_db
from app.models.user import User
from app.schemas.token import TokenResponse, RefreshToken
from app.schemas.user import UserCreate, UserResponse
router = APIRouter()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
async def get_current_user(
token: Annotated[str, Depends(oauth2_scheme)],
db: Session = Depends(get_db)
) -> User:
"""
Get the current user based on the JWT token.
Args:
token: JWT token from authorization header
db: Database session
Returns:
User object if valid token
Raises:
HTTPException: If token is invalid or user not found
"""
try:
payload = decode_token(token)
if payload.type != "access":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid access token type"
)
user = db.query(User).filter(User.id == payload.sub).first()
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found"
)
return user
except JWTError as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"Could not validate credentials: {str(e)}"
)
async def authenticate_user(
email: str,
password: str,
db: Session
) -> User:
"""
Authenticate a user by email and password.
Args:
email: User's email
password: User's password
db: Database session
Returns:
User object if authentication successful, None otherwise
"""
user = db.query(User).filter(User.email == email).first()
if not user or not verify_password(password, user.password_hash):
return None
return user
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register(
user_create: UserCreate,
db: Session = Depends(get_db)
):
"""
Register a new user.
"""
# Check if user already exists
if db.query(User).filter(User.email == user_create.email).first():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email already registered"
)
# Create new user
user = User(
email=user_create.email,
password_hash=get_password_hash(user_create.password),
first_name=user_create.first_name,
last_name=user_create.last_name,
phone_number=user_create.phone_number
)
db.add(user)
db.commit()
db.refresh(user)
return user
@router.post("/login", response_model=TokenResponse)
async def login(
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
db: Session = Depends(get_db)
):
"""
OAuth2 compatible token login, get an access token for future requests.
"""
user = await authenticate_user(form_data.username, form_data.password, db)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect email or password",
headers={"WWW-Authenticate": "Bearer"},
)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Inactive user"
)
return create_tokens(str(user.id))
@router.post("/refresh", response_model=TokenResponse)
async def refresh_token(
refresh_token: RefreshToken,
db: Session = Depends(get_db)
):
"""
Refresh access token using refresh token.
"""
try:
payload = decode_token(refresh_token.refresh_token)
# Validate token type
if payload.type != "refresh":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token type"
)
# Verify user still exists and is active
user = db.query(User).filter(User.id == payload.sub).first()
if not user or not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found or inactive"
)
return create_tokens(str(user.id))
except JWTError as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"Invalid refresh token: {str(e)}"
)
@router.get("/me", response_model=UserResponse)
async def read_users_me(
current_user: Annotated[User, Depends(get_current_user)]
):
"""
Get current user information.
"""
return current_user

View File

View File

@@ -0,0 +1,45 @@
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from sqlalchemy.ext.asyncio import AsyncSession
from auth.security import SECRET_KEY, ALGORITHM
from app.core.database import get_db
from models.user import User
from app.schemas.token import TokenData
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token")
async def get_current_user(
token: str = Depends(oauth2_scheme),
db: AsyncSession = Depends(get_db)
):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
user_id: str = payload.get("sub")
token_type: str = payload.get("type")
if user_id is None or token_type != "access":
raise credentials_exception
except JWTError:
raise credentials_exception
user = await db.get(User, user_id)
if user is None:
raise credentials_exception
return user
async def get_current_active_user(
current_user: User = Depends(get_current_user),
):
if not current_user.is_active:
raise HTTPException(status_code=400, detail="Inactive user")
return current_user

View File

@@ -0,0 +1,111 @@
from datetime import datetime, timedelta
from typing import Optional, Tuple
from uuid import uuid4
from jose import jwt, JWTError
from passlib.context import CryptContext
from app.core.config import settings
from .token import TokenPayload, TokenResponse
# Configuration
SECRET_KEY = settings.SECRET_KEY
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_DAYS = 7
# Password hashing context
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a plain password against its hash."""
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password: str) -> str:
"""Generate password hash."""
return pwd_context.hash(password)
def create_tokens(user_id: str) -> TokenResponse:
"""
Create both access and refresh tokens for a user.
Args:
user_id: The user's ID
Returns:
TokenResponse containing both tokens and metadata
"""
access_token = create_access_token({"sub": user_id})
refresh_token = create_refresh_token({"sub": user_id})
return TokenResponse(
access_token=access_token,
refresh_token=refresh_token,
token_type="bearer",
expires_in=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
user_id=user_id
)
def create_token(
data: dict,
expires_delta: Optional[timedelta] = None,
token_type: str = "access"
) -> str:
"""Create a JWT token with the specified type and expiration."""
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + (
timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) if token_type == "access"
else timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
)
to_encode.update({
"exp": expire,
"type": token_type,
"iat": datetime.utcnow(),
"jti": str(uuid4())
})
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def decode_token(token: str) -> TokenPayload:
"""
Decode and validate a JWT token.
Args:
token: The JWT token to decode
Returns:
TokenPayload containing the decoded data
Raises:
JWTError: If token is invalid or expired
"""
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
return TokenPayload(
sub=payload["sub"],
type=payload["type"],
exp=datetime.fromtimestamp(payload["exp"]),
iat=datetime.fromtimestamp(payload["iat"]),
jti=payload.get("jti")
)
except JWTError as e:
raise JWTError(f"Invalid token: {str(e)}")
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
"""Create a new access token."""
return create_token(data, expires_delta, "access")
def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
"""Create a new refresh token."""
return create_token(data, expires_delta, "refresh")

View File

@@ -3,7 +3,7 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from app.core.config import settings
from app.api.main import api_router
import logging
logger = logging.getLogger(__name__)
@@ -39,3 +39,5 @@ async def root():
</body>
</html>
"""
app.include_router(api_router, prefix=settings.API_V1_STR)

View File

View File

@@ -0,0 +1,54 @@
from typing import Optional
from datetime import datetime
from pydantic import BaseModel, Field
class TokenBase(BaseModel):
"""Base token schema with common attributes."""
token_type: str = Field(default="bearer", description="Type of authentication token")
expires_in: int = Field(description="Token expiration time in seconds")
class Token(TokenBase):
"""Schema for authentication response containing both access and refresh tokens."""
access_token: str = Field(description="JWT access token")
refresh_token: str = Field(description="JWT refresh token for obtaining new access tokens")
class TokenPayload(BaseModel):
"""Schema representing the decoded JWT token payload."""
sub: str = Field(description="Subject identifier (user ID)")
type: str = Field(description="Token type (access or refresh)")
exp: datetime = Field(description="Token expiration timestamp")
iat: datetime = Field(description="Token issued at timestamp")
jti: Optional[str] = Field(None, description="JWT ID - unique identifier for the token")
class RefreshToken(BaseModel):
"""Schema for refresh token requests."""
refresh_token: str = Field(
...,
description="JWT refresh token used to obtain new access tokens"
)
class TokenResponse(BaseModel):
"""Schema for detailed token information response."""
access_token: str = Field(description="JWT access token")
refresh_token: str = Field(description="JWT refresh token")
token_type: str = Field(default="bearer")
expires_in: int = Field(description="Token expiration time in seconds")
scope: Optional[str] = Field(None, description="Token scope")
user_id: str = Field(description="ID of the authenticated user")
class Config:
json_schema_extra = {
"example": {
"access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
"refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
"token_type": "bearer",
"expires_in": 1800,
"scope": "read write",
"user_id": "123e4567-e89b-12d3-a456-426614174000"
}
}

View File

@@ -0,0 +1,66 @@
from typing import Optional
from uuid import UUID
from pydantic import BaseModel, EmailStr, Field, field_validator
from datetime import datetime
from passlib.hash import bcrypt
# Base schema with shared user attributes
class UserBase(BaseModel):
"""Base schema with common user attributes."""
email: EmailStr
first_name: str
last_name: str
# Schema for creating a new user
class UserCreate(UserBase):
"""Schema for user registration."""
password: str = Field(
...,
min_length=8,
description="Password must be at least 8 characters"
)
@field_validator('password')
def password_strength(cls, v):
# Add more complex password validation if needed
if len(v) < 8:
raise ValueError('Password must be at least 8 characters')
return v
def hash_password(self) -> str:
"""Hash the password before saving it to the database."""
return bcrypt.hash(self.password)
# Schema for updating user details
class UserUpdate(BaseModel):
"""Schema for updating user information."""
email: Optional[EmailStr] = None
first_name: Optional[str] = None
last_name: Optional[str] = None
phone_number: Optional[str] = None
is_active: Optional[bool] = None
preferences: Optional[dict] = None # Provide preferences support
# Schema for user responses (read-only fields)
class UserResponse(UserBase):
"""Schema for user responses in API."""
id: UUID
is_active: bool
is_superuser: bool # Include roles or superuser flags if needed
preferences: Optional[dict] = None # Include preferences in response
created_at: datetime
updated_at: Optional[datetime] = None
class Config:
orm_mode = True # Enable mapping SQLAlchemy models to this schema
# Schema for user authentication (e.g., login requests)
class UserAuth(BaseModel):
"""Schema for user authentication."""
email: EmailStr
password: str