Add foundational user authentication and registration system
Introduces schemas for user management, token handling, and password hashing. Implements routes for user registration, login, token refresh, and user info retrieval. Sets up authentication dependencies and integrates the API router with the application.
This commit is contained in:
0
backend/app/api/__init__.py
Normal file
0
backend/app/api/__init__.py
Normal file
6
backend/app/api/main.py
Normal file
6
backend/app/api/main.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.routes import auth
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(auth.router, tags=["auth"])
|
||||
0
backend/app/api/routes/__init__.py
Normal file
0
backend/app/api/routes/__init__.py
Normal file
182
backend/app/api/routes/auth.py
Normal file
182
backend/app/api/routes/auth.py
Normal file
@@ -0,0 +1,182 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
from jose import JWTError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.auth.security import (
|
||||
verify_password,
|
||||
get_password_hash,
|
||||
create_tokens,
|
||||
decode_token,
|
||||
)
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.token import TokenResponse, RefreshToken
|
||||
from app.schemas.user import UserCreate, UserResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: Annotated[str, Depends(oauth2_scheme)],
|
||||
db: Session = Depends(get_db)
|
||||
) -> User:
|
||||
"""
|
||||
Get the current user based on the JWT token.
|
||||
|
||||
Args:
|
||||
token: JWT token from authorization header
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
User object if valid token
|
||||
|
||||
Raises:
|
||||
HTTPException: If token is invalid or user not found
|
||||
"""
|
||||
try:
|
||||
payload = decode_token(token)
|
||||
if payload.type != "access":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid access token type"
|
||||
)
|
||||
user = db.query(User).filter(User.id == payload.sub).first()
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User not found"
|
||||
)
|
||||
return user
|
||||
except JWTError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"Could not validate credentials: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
async def authenticate_user(
|
||||
email: str,
|
||||
password: str,
|
||||
db: Session
|
||||
) -> User:
|
||||
"""
|
||||
Authenticate a user by email and password.
|
||||
|
||||
Args:
|
||||
email: User's email
|
||||
password: User's password
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
User object if authentication successful, None otherwise
|
||||
"""
|
||||
user = db.query(User).filter(User.email == email).first()
|
||||
if not user or not verify_password(password, user.password_hash):
|
||||
return None
|
||||
return user
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def register(
|
||||
user_create: UserCreate,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Register a new user.
|
||||
"""
|
||||
# Check if user already exists
|
||||
if db.query(User).filter(User.email == user_create.email).first():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email already registered"
|
||||
)
|
||||
|
||||
# Create new user
|
||||
user = User(
|
||||
email=user_create.email,
|
||||
password_hash=get_password_hash(user_create.password),
|
||||
first_name=user_create.first_name,
|
||||
last_name=user_create.last_name,
|
||||
phone_number=user_create.phone_number
|
||||
)
|
||||
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
async def login(
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
OAuth2 compatible token login, get an access token for future requests.
|
||||
"""
|
||||
user = await authenticate_user(form_data.username, form_data.password, db)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect email or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Inactive user"
|
||||
)
|
||||
|
||||
return create_tokens(str(user.id))
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=TokenResponse)
|
||||
async def refresh_token(
|
||||
refresh_token: RefreshToken,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Refresh access token using refresh token.
|
||||
"""
|
||||
try:
|
||||
payload = decode_token(refresh_token.refresh_token)
|
||||
|
||||
# Validate token type
|
||||
if payload.type != "refresh":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid refresh token type"
|
||||
)
|
||||
|
||||
# Verify user still exists and is active
|
||||
user = db.query(User).filter(User.id == payload.sub).first()
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User not found or inactive"
|
||||
)
|
||||
|
||||
return create_tokens(str(user.id))
|
||||
|
||||
except JWTError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"Invalid refresh token: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def read_users_me(
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
):
|
||||
"""
|
||||
Get current user information.
|
||||
"""
|
||||
return current_user
|
||||
0
backend/app/auth/__init__.py
Normal file
0
backend/app/auth/__init__.py
Normal file
45
backend/app/auth/dependencies.py
Normal file
45
backend/app/auth/dependencies.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import JWTError, jwt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from auth.security import SECRET_KEY, ALGORITHM
|
||||
from app.core.database import get_db
|
||||
from models.user import User
|
||||
from app.schemas.token import TokenData
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token")
|
||||
|
||||
async def get_current_user(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
user_id: str = payload.get("sub")
|
||||
token_type: str = payload.get("type")
|
||||
|
||||
if user_id is None or token_type != "access":
|
||||
raise credentials_exception
|
||||
|
||||
except JWTError:
|
||||
raise credentials_exception
|
||||
|
||||
user = await db.get(User, user_id)
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
|
||||
return user
|
||||
|
||||
async def get_current_active_user(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(status_code=400, detail="Inactive user")
|
||||
return current_user
|
||||
111
backend/app/auth/security.py
Normal file
111
backend/app/auth/security.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Tuple
|
||||
from uuid import uuid4
|
||||
|
||||
from jose import jwt, JWTError
|
||||
from passlib.context import CryptContext
|
||||
from app.core.config import settings
|
||||
from .token import TokenPayload, TokenResponse
|
||||
|
||||
# Configuration
|
||||
SECRET_KEY = settings.SECRET_KEY
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
||||
REFRESH_TOKEN_EXPIRE_DAYS = 7
|
||||
|
||||
# Password hashing context
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a plain password against its hash."""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""Generate password hash."""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def create_tokens(user_id: str) -> TokenResponse:
|
||||
"""
|
||||
Create both access and refresh tokens for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID
|
||||
|
||||
Returns:
|
||||
TokenResponse containing both tokens and metadata
|
||||
"""
|
||||
access_token = create_access_token({"sub": user_id})
|
||||
refresh_token = create_refresh_token({"sub": user_id})
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
token_type="bearer",
|
||||
expires_in=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
|
||||
def create_token(
|
||||
data: dict,
|
||||
expires_delta: Optional[timedelta] = None,
|
||||
token_type: str = "access"
|
||||
) -> str:
|
||||
"""Create a JWT token with the specified type and expiration."""
|
||||
to_encode = data.copy()
|
||||
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + (
|
||||
timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) if token_type == "access"
|
||||
else timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
)
|
||||
|
||||
to_encode.update({
|
||||
"exp": expire,
|
||||
"type": token_type,
|
||||
"iat": datetime.utcnow(),
|
||||
"jti": str(uuid4())
|
||||
})
|
||||
|
||||
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
def decode_token(token: str) -> TokenPayload:
|
||||
"""
|
||||
Decode and validate a JWT token.
|
||||
|
||||
Args:
|
||||
token: The JWT token to decode
|
||||
|
||||
Returns:
|
||||
TokenPayload containing the decoded data
|
||||
|
||||
Raises:
|
||||
JWTError: If token is invalid or expired
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
return TokenPayload(
|
||||
sub=payload["sub"],
|
||||
type=payload["type"],
|
||||
exp=datetime.fromtimestamp(payload["exp"]),
|
||||
iat=datetime.fromtimestamp(payload["iat"]),
|
||||
jti=payload.get("jti")
|
||||
)
|
||||
except JWTError as e:
|
||||
raise JWTError(f"Invalid token: {str(e)}")
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""Create a new access token."""
|
||||
return create_token(data, expires_delta, "access")
|
||||
|
||||
|
||||
def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""Create a new refresh token."""
|
||||
return create_token(data, expires_delta, "refresh")
|
||||
@@ -3,7 +3,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
from app.api.main import api_router
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -39,3 +39,5 @@ async def root():
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
app.include_router(api_router, prefix=settings.API_V1_STR)
|
||||
0
backend/app/schemas/__init__.py
Normal file
0
backend/app/schemas/__init__.py
Normal file
54
backend/app/schemas/token.py
Normal file
54
backend/app/schemas/token.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class TokenBase(BaseModel):
|
||||
"""Base token schema with common attributes."""
|
||||
token_type: str = Field(default="bearer", description="Type of authentication token")
|
||||
expires_in: int = Field(description="Token expiration time in seconds")
|
||||
|
||||
|
||||
class Token(TokenBase):
|
||||
"""Schema for authentication response containing both access and refresh tokens."""
|
||||
access_token: str = Field(description="JWT access token")
|
||||
refresh_token: str = Field(description="JWT refresh token for obtaining new access tokens")
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
"""Schema representing the decoded JWT token payload."""
|
||||
sub: str = Field(description="Subject identifier (user ID)")
|
||||
type: str = Field(description="Token type (access or refresh)")
|
||||
exp: datetime = Field(description="Token expiration timestamp")
|
||||
iat: datetime = Field(description="Token issued at timestamp")
|
||||
jti: Optional[str] = Field(None, description="JWT ID - unique identifier for the token")
|
||||
|
||||
|
||||
class RefreshToken(BaseModel):
|
||||
"""Schema for refresh token requests."""
|
||||
refresh_token: str = Field(
|
||||
...,
|
||||
description="JWT refresh token used to obtain new access tokens"
|
||||
)
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
"""Schema for detailed token information response."""
|
||||
access_token: str = Field(description="JWT access token")
|
||||
refresh_token: str = Field(description="JWT refresh token")
|
||||
token_type: str = Field(default="bearer")
|
||||
expires_in: int = Field(description="Token expiration time in seconds")
|
||||
scope: Optional[str] = Field(None, description="Token scope")
|
||||
user_id: str = Field(description="ID of the authenticated user")
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
|
||||
"refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
|
||||
"token_type": "bearer",
|
||||
"expires_in": 1800,
|
||||
"scope": "read write",
|
||||
"user_id": "123e4567-e89b-12d3-a456-426614174000"
|
||||
}
|
||||
}
|
||||
66
backend/app/schemas/user.py
Normal file
66
backend/app/schemas/user.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel, EmailStr, Field, field_validator
|
||||
from datetime import datetime
|
||||
from passlib.hash import bcrypt
|
||||
|
||||
|
||||
# Base schema with shared user attributes
|
||||
class UserBase(BaseModel):
|
||||
"""Base schema with common user attributes."""
|
||||
email: EmailStr
|
||||
first_name: str
|
||||
last_name: str
|
||||
|
||||
|
||||
# Schema for creating a new user
|
||||
class UserCreate(UserBase):
|
||||
"""Schema for user registration."""
|
||||
password: str = Field(
|
||||
...,
|
||||
min_length=8,
|
||||
description="Password must be at least 8 characters"
|
||||
)
|
||||
|
||||
@field_validator('password')
|
||||
def password_strength(cls, v):
|
||||
# Add more complex password validation if needed
|
||||
if len(v) < 8:
|
||||
raise ValueError('Password must be at least 8 characters')
|
||||
return v
|
||||
|
||||
def hash_password(self) -> str:
|
||||
"""Hash the password before saving it to the database."""
|
||||
return bcrypt.hash(self.password)
|
||||
|
||||
|
||||
# Schema for updating user details
|
||||
class UserUpdate(BaseModel):
|
||||
"""Schema for updating user information."""
|
||||
email: Optional[EmailStr] = None
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
phone_number: Optional[str] = None
|
||||
is_active: Optional[bool] = None
|
||||
preferences: Optional[dict] = None # Provide preferences support
|
||||
|
||||
|
||||
# Schema for user responses (read-only fields)
|
||||
class UserResponse(UserBase):
|
||||
"""Schema for user responses in API."""
|
||||
id: UUID
|
||||
is_active: bool
|
||||
is_superuser: bool # Include roles or superuser flags if needed
|
||||
preferences: Optional[dict] = None # Include preferences in response
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
orm_mode = True # Enable mapping SQLAlchemy models to this schema
|
||||
|
||||
|
||||
# Schema for user authentication (e.g., login requests)
|
||||
class UserAuth(BaseModel):
|
||||
"""Schema for user authentication."""
|
||||
email: EmailStr
|
||||
password: str
|
||||
Reference in New Issue
Block a user