Refactor and enhance token decoding error handling
Improved the `decode_token` function to clarify and extend error handling for token validation and decoding. Enhanced error messages for invalid tokens, added checks for missing claims, and ensured clear differentiation of failure scenarios. Updated imports and added a `scope` field to token response for completeness.
This commit is contained in:
@@ -15,7 +15,7 @@ async def get_current_user(
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
try:
|
||||
payload = decode_token(token) # Use updated decode_token.
|
||||
payload = await decode_token(token) # Use updated decode_token.
|
||||
user_id: str = payload.sub
|
||||
token_type: str = payload.type
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -10,7 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.core.config import settings
|
||||
from app.core.database import get_db
|
||||
from app.schemas.token import TokenPayload, TokenResponse
|
||||
from auth.utlis import is_token_revoked
|
||||
from auth.utils import is_token_revoked
|
||||
|
||||
# Configuration
|
||||
SECRET_KEY = settings.SECRET_KEY
|
||||
@@ -65,9 +65,9 @@ def create_token(
|
||||
to_encode = data.copy()
|
||||
|
||||
if expires_delta:
|
||||
expire = datetime.now() + expires_delta
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now() + (
|
||||
expire = datetime.now(timezone.utc) + (
|
||||
timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) if token_type == "access"
|
||||
else timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
)
|
||||
@@ -75,7 +75,7 @@ def create_token(
|
||||
to_encode.update({
|
||||
"exp": expire,
|
||||
"type": token_type,
|
||||
"iat": datetime.now(),
|
||||
"iat": datetime.now(timezone.utc),
|
||||
})
|
||||
if "jti" not in to_encode:
|
||||
to_encode["jti"] = str(uuid4()) # Ensure unique `jti` is always added
|
||||
@@ -120,7 +120,16 @@ async def decode_token(
|
||||
"""
|
||||
try:
|
||||
# Step 1: Decode the JWT token
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
SECRET_KEY,
|
||||
algorithms=[ALGORITHM],
|
||||
options={
|
||||
"verify_exp": True,
|
||||
"verify_iat": True,
|
||||
"require": ["exp", "iat", "sub", "type", "jti"]
|
||||
}
|
||||
)
|
||||
|
||||
except ExpiredSignatureError:
|
||||
raise JWTError("Token has expired. Please refresh your token or login again.")
|
||||
@@ -144,7 +153,7 @@ async def decode_token(
|
||||
|
||||
# Step 3: Validate Expiry
|
||||
expiration = datetime.fromtimestamp(payload["exp"])
|
||||
if datetime.now() > expiration:
|
||||
if datetime.now(timezone.utc) > expiration:
|
||||
raise JWTError("Token has expired. Please refresh your token or login again.")
|
||||
|
||||
# Step 4: Validate Token Type
|
||||
|
||||
45
backend/app/auth/utils.py
Normal file
45
backend/app/auth/utils.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.models.token import RevokedToken
|
||||
|
||||
|
||||
async def revoke_token(jti: str, token_type: str, user_id: str, db: AsyncSession):
|
||||
"""Revoke a token by storing its `jti` in the revoked_tokens table."""
|
||||
revoked_token = RevokedToken(jti=jti, token_type=token_type, user_id=user_id)
|
||||
db.add(revoked_token)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def is_token_revoked(jti: str, db: AsyncSession) -> bool:
|
||||
"""Check whether the token's JTI is in the revoked_tokens table."""
|
||||
from sqlalchemy import select
|
||||
result = await db.execute(select(RevokedToken).where(RevokedToken.jti == jti))
|
||||
revoked = result.scalar_one_or_none()
|
||||
return revoked is not None
|
||||
|
||||
|
||||
async def cleanup_expired_tokens(db: AsyncSession):
|
||||
"""Delete revoked tokens that are past their expiration time."""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# For access tokens (shorter expiry)
|
||||
expire_before = now - timedelta(days=1) # Keep for 1 day past expiry
|
||||
await db.execute(
|
||||
delete(RevokedToken).where(
|
||||
(RevokedToken.token_type == "access") &
|
||||
(RevokedToken.created_at < expire_before)
|
||||
)
|
||||
)
|
||||
|
||||
# For refresh tokens (longer expiry)
|
||||
expire_before = now - timedelta(days=14) # Keep for 14 days past expiry
|
||||
await db.execute(
|
||||
delete(RevokedToken).where(
|
||||
(RevokedToken.token_type == "refresh") &
|
||||
(RevokedToken.created_at < expire_before)
|
||||
)
|
||||
)
|
||||
|
||||
await db.commit()
|
||||
@@ -1,15 +0,0 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.models.token import RevokedToken
|
||||
|
||||
|
||||
async def revoke_token(jti: str, token_type: str, user_id: str, db: AsyncSession):
|
||||
"""Revoke a token by storing its `jti` in the revoked_tokens table."""
|
||||
revoked_token = RevokedToken(jti=jti, token_type=token_type, user_id=user_id)
|
||||
db.add(revoked_token)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def is_token_revoked(jti: str, db: AsyncSession) -> bool:
|
||||
"""Check whether the token's `jti` is in the revoked_tokens table."""
|
||||
revoked = await db.get(RevokedToken, jti)
|
||||
return revoked is not None
|
||||
@@ -1,3 +1,5 @@
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import HTMLResponse
|
||||
@@ -6,6 +8,11 @@ from app.core.config import settings
|
||||
from app.api.main import api_router
|
||||
import logging
|
||||
|
||||
from auth.utils import cleanup_expired_tokens
|
||||
from app.core.database import SessionLocal
|
||||
|
||||
scheduler = AsyncIOScheduler()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
logger.info(f"Starting app!!!")
|
||||
@@ -25,6 +32,26 @@ app.add_middleware(
|
||||
)
|
||||
|
||||
|
||||
# Create a function that gets its own database session
|
||||
async def scheduled_cleanup():
|
||||
async with SessionLocal() as db:
|
||||
await cleanup_expired_tokens(db)
|
||||
|
||||
@app.on_event("startup")
|
||||
async def start_scheduler():
|
||||
# Run every day at 3 AM
|
||||
scheduler.add_job(
|
||||
scheduled_cleanup,
|
||||
CronTrigger(hour=10, minute=0),
|
||||
id="token_cleanup",
|
||||
name="Clean up expired revoked tokens"
|
||||
)
|
||||
scheduler.start()
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def stop_scheduler():
|
||||
scheduler.shutdown()
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
async def root():
|
||||
return """
|
||||
|
||||
@@ -4,6 +4,7 @@ uvicorn>=0.34.0
|
||||
pydantic>=2.10.6
|
||||
pydantic-settings>=2.2.1
|
||||
python-multipart>=0.0.19
|
||||
fastapi-utils==0.8.0
|
||||
|
||||
# Database
|
||||
sqlalchemy>=2.0.29
|
||||
@@ -30,7 +31,7 @@ httpx>=0.27.0
|
||||
tenacity>=8.2.3
|
||||
pytz>=2024.1
|
||||
pillow>=10.3.0
|
||||
|
||||
apscheduler==3.11.0
|
||||
# Testing
|
||||
pytest>=8.0.0
|
||||
pytest-asyncio>=0.23.5
|
||||
@@ -47,4 +48,5 @@ mypy>=1.8.0
|
||||
python-jose==3.4.0
|
||||
bcrypt==4.2.1
|
||||
cryptography==44.0.1
|
||||
passlib==1.7.4
|
||||
passlib==1.7.4
|
||||
freezegun~=1.5.1
|
||||
Reference in New Issue
Block a user