Refactor and enhance token decoding error handling

Improved the `decode_token` function to clarify and extend error handling for token validation and decoding. Enhanced error messages for invalid tokens, added checks for missing claims, and ensured clear differentiation of failure scenarios. Updated imports and added a `scope` field to token response for completeness.
This commit is contained in:
2025-02-28 19:05:08 +01:00
parent 0bc9263d24
commit 453016629f
6 changed files with 93 additions and 25 deletions

View File

@@ -15,7 +15,7 @@ async def get_current_user(
db: AsyncSession = Depends(get_db)
):
try:
payload = decode_token(token) # Use updated decode_token.
payload = await decode_token(token) # Use updated decode_token.
user_id: str = payload.sub
token_type: str = payload.type

View File

@@ -1,4 +1,4 @@
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import Optional
from uuid import uuid4
@@ -10,7 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.database import get_db
from app.schemas.token import TokenPayload, TokenResponse
from auth.utlis import is_token_revoked
from auth.utils import is_token_revoked
# Configuration
SECRET_KEY = settings.SECRET_KEY
@@ -65,9 +65,9 @@ def create_token(
to_encode = data.copy()
if expires_delta:
expire = datetime.now() + expires_delta
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now() + (
expire = datetime.now(timezone.utc) + (
timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) if token_type == "access"
else timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
)
@@ -75,7 +75,7 @@ def create_token(
to_encode.update({
"exp": expire,
"type": token_type,
"iat": datetime.now(),
"iat": datetime.now(timezone.utc),
})
if "jti" not in to_encode:
to_encode["jti"] = str(uuid4()) # Ensure unique `jti` is always added
@@ -120,7 +120,16 @@ async def decode_token(
"""
try:
# Step 1: Decode the JWT token
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
payload = jwt.decode(
token,
SECRET_KEY,
algorithms=[ALGORITHM],
options={
"verify_exp": True,
"verify_iat": True,
"require": ["exp", "iat", "sub", "type", "jti"]
}
)
except ExpiredSignatureError:
raise JWTError("Token has expired. Please refresh your token or login again.")
@@ -144,7 +153,7 @@ async def decode_token(
# Step 3: Validate Expiry
expiration = datetime.fromtimestamp(payload["exp"])
if datetime.now() > expiration:
if datetime.now(timezone.utc) > expiration:
raise JWTError("Token has expired. Please refresh your token or login again.")
# Step 4: Validate Token Type

45
backend/app/auth/utils.py Normal file
View File

@@ -0,0 +1,45 @@
from datetime import datetime, timezone, timedelta
from sqlalchemy import delete
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.token import RevokedToken
async def revoke_token(jti: str, token_type: str, user_id: str, db: AsyncSession):
"""Revoke a token by storing its `jti` in the revoked_tokens table."""
revoked_token = RevokedToken(jti=jti, token_type=token_type, user_id=user_id)
db.add(revoked_token)
await db.commit()
async def is_token_revoked(jti: str, db: AsyncSession) -> bool:
"""Check whether the token's JTI is in the revoked_tokens table."""
from sqlalchemy import select
result = await db.execute(select(RevokedToken).where(RevokedToken.jti == jti))
revoked = result.scalar_one_or_none()
return revoked is not None
async def cleanup_expired_tokens(db: AsyncSession):
"""Delete revoked tokens that are past their expiration time."""
now = datetime.now(timezone.utc)
# For access tokens (shorter expiry)
expire_before = now - timedelta(days=1) # Keep for 1 day past expiry
await db.execute(
delete(RevokedToken).where(
(RevokedToken.token_type == "access") &
(RevokedToken.created_at < expire_before)
)
)
# For refresh tokens (longer expiry)
expire_before = now - timedelta(days=14) # Keep for 14 days past expiry
await db.execute(
delete(RevokedToken).where(
(RevokedToken.token_type == "refresh") &
(RevokedToken.created_at < expire_before)
)
)
await db.commit()

View File

@@ -1,15 +0,0 @@
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.token import RevokedToken
async def revoke_token(jti: str, token_type: str, user_id: str, db: AsyncSession):
"""Revoke a token by storing its `jti` in the revoked_tokens table."""
revoked_token = RevokedToken(jti=jti, token_type=token_type, user_id=user_id)
db.add(revoked_token)
await db.commit()
async def is_token_revoked(jti: str, db: AsyncSession) -> bool:
"""Check whether the token's `jti` is in the revoked_tokens table."""
revoked = await db.get(RevokedToken, jti)
return revoked is not None

View File

@@ -1,3 +1,5 @@
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
@@ -6,6 +8,11 @@ from app.core.config import settings
from app.api.main import api_router
import logging
from auth.utils import cleanup_expired_tokens
from app.core.database import SessionLocal
scheduler = AsyncIOScheduler()
logger = logging.getLogger(__name__)
logger.info(f"Starting app!!!")
@@ -25,6 +32,26 @@ app.add_middleware(
)
# Create a function that gets its own database session
async def scheduled_cleanup():
async with SessionLocal() as db:
await cleanup_expired_tokens(db)
@app.on_event("startup")
async def start_scheduler():
# Run every day at 3 AM
scheduler.add_job(
scheduled_cleanup,
CronTrigger(hour=10, minute=0),
id="token_cleanup",
name="Clean up expired revoked tokens"
)
scheduler.start()
@app.on_event("shutdown")
async def stop_scheduler():
scheduler.shutdown()
@app.get("/", response_class=HTMLResponse)
async def root():
return """

View File

@@ -4,6 +4,7 @@ uvicorn>=0.34.0
pydantic>=2.10.6
pydantic-settings>=2.2.1
python-multipart>=0.0.19
fastapi-utils==0.8.0
# Database
sqlalchemy>=2.0.29
@@ -30,7 +31,7 @@ httpx>=0.27.0
tenacity>=8.2.3
pytz>=2024.1
pillow>=10.3.0
apscheduler==3.11.0
# Testing
pytest>=8.0.0
pytest-asyncio>=0.23.5
@@ -47,4 +48,5 @@ mypy>=1.8.0
python-jose==3.4.0
bcrypt==4.2.1
cryptography==44.0.1
passlib==1.7.4
passlib==1.7.4
freezegun~=1.5.1