diff --git a/backend/app/auth/dependencies.py b/backend/app/auth/dependencies.py index edadeb3..a431bb4 100644 --- a/backend/app/auth/dependencies.py +++ b/backend/app/auth/dependencies.py @@ -15,7 +15,7 @@ async def get_current_user( db: AsyncSession = Depends(get_db) ): try: - payload = decode_token(token) # Use updated decode_token. + payload = await decode_token(token) # Use updated decode_token. user_id: str = payload.sub token_type: str = payload.type diff --git a/backend/app/auth/security.py b/backend/app/auth/security.py index f9cc59a..69f3f77 100644 --- a/backend/app/auth/security.py +++ b/backend/app/auth/security.py @@ -1,4 +1,4 @@ -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from typing import Optional from uuid import uuid4 @@ -10,7 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.core.config import settings from app.core.database import get_db from app.schemas.token import TokenPayload, TokenResponse -from auth.utlis import is_token_revoked +from auth.utils import is_token_revoked # Configuration SECRET_KEY = settings.SECRET_KEY @@ -65,9 +65,9 @@ def create_token( to_encode = data.copy() if expires_delta: - expire = datetime.now() + expires_delta + expire = datetime.now(timezone.utc) + expires_delta else: - expire = datetime.now() + ( + expire = datetime.now(timezone.utc) + ( timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) if token_type == "access" else timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS) ) @@ -75,7 +75,7 @@ def create_token( to_encode.update({ "exp": expire, "type": token_type, - "iat": datetime.now(), + "iat": datetime.now(timezone.utc), }) if "jti" not in to_encode: to_encode["jti"] = str(uuid4()) # Ensure unique `jti` is always added @@ -120,7 +120,16 @@ async def decode_token( """ try: # Step 1: Decode the JWT token - payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + payload = jwt.decode( + token, + SECRET_KEY, + algorithms=[ALGORITHM], + options={ + "verify_exp": True, + "verify_iat": True, + "require": ["exp", "iat", "sub", "type", "jti"] + } + ) except ExpiredSignatureError: raise JWTError("Token has expired. Please refresh your token or login again.") @@ -144,7 +153,7 @@ async def decode_token( # Step 3: Validate Expiry expiration = datetime.fromtimestamp(payload["exp"]) - if datetime.now() > expiration: + if datetime.now(timezone.utc) > expiration: raise JWTError("Token has expired. Please refresh your token or login again.") # Step 4: Validate Token Type diff --git a/backend/app/auth/utils.py b/backend/app/auth/utils.py new file mode 100644 index 0000000..350b4e4 --- /dev/null +++ b/backend/app/auth/utils.py @@ -0,0 +1,45 @@ +from datetime import datetime, timezone, timedelta + +from sqlalchemy import delete +from sqlalchemy.ext.asyncio import AsyncSession +from app.models.token import RevokedToken + + +async def revoke_token(jti: str, token_type: str, user_id: str, db: AsyncSession): + """Revoke a token by storing its `jti` in the revoked_tokens table.""" + revoked_token = RevokedToken(jti=jti, token_type=token_type, user_id=user_id) + db.add(revoked_token) + await db.commit() + + +async def is_token_revoked(jti: str, db: AsyncSession) -> bool: + """Check whether the token's JTI is in the revoked_tokens table.""" + from sqlalchemy import select + result = await db.execute(select(RevokedToken).where(RevokedToken.jti == jti)) + revoked = result.scalar_one_or_none() + return revoked is not None + + +async def cleanup_expired_tokens(db: AsyncSession): + """Delete revoked tokens that are past their expiration time.""" + now = datetime.now(timezone.utc) + + # For access tokens (shorter expiry) + expire_before = now - timedelta(days=1) # Keep for 1 day past expiry + await db.execute( + delete(RevokedToken).where( + (RevokedToken.token_type == "access") & + (RevokedToken.created_at < expire_before) + ) + ) + + # For refresh tokens (longer expiry) + expire_before = now - timedelta(days=14) # Keep for 14 days past expiry + await db.execute( + delete(RevokedToken).where( + (RevokedToken.token_type == "refresh") & + (RevokedToken.created_at < expire_before) + ) + ) + + await db.commit() \ No newline at end of file diff --git a/backend/app/auth/utlis.py b/backend/app/auth/utlis.py deleted file mode 100644 index 3e6c6bd..0000000 --- a/backend/app/auth/utlis.py +++ /dev/null @@ -1,15 +0,0 @@ -from sqlalchemy.ext.asyncio import AsyncSession -from app.models.token import RevokedToken - - -async def revoke_token(jti: str, token_type: str, user_id: str, db: AsyncSession): - """Revoke a token by storing its `jti` in the revoked_tokens table.""" - revoked_token = RevokedToken(jti=jti, token_type=token_type, user_id=user_id) - db.add(revoked_token) - await db.commit() - - -async def is_token_revoked(jti: str, db: AsyncSession) -> bool: - """Check whether the token's `jti` is in the revoked_tokens table.""" - revoked = await db.get(RevokedToken, jti) - return revoked is not None diff --git a/backend/app/main.py b/backend/app/main.py index b3e03ca..697c8fc 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -1,3 +1,5 @@ +from apscheduler.schedulers.asyncio import AsyncIOScheduler +from apscheduler.triggers.cron import CronTrigger from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import HTMLResponse @@ -6,6 +8,11 @@ from app.core.config import settings from app.api.main import api_router import logging +from auth.utils import cleanup_expired_tokens +from app.core.database import SessionLocal + +scheduler = AsyncIOScheduler() + logger = logging.getLogger(__name__) logger.info(f"Starting app!!!") @@ -25,6 +32,26 @@ app.add_middleware( ) +# Create a function that gets its own database session +async def scheduled_cleanup(): + async with SessionLocal() as db: + await cleanup_expired_tokens(db) + +@app.on_event("startup") +async def start_scheduler(): + # Run every day at 3 AM + scheduler.add_job( + scheduled_cleanup, + CronTrigger(hour=10, minute=0), + id="token_cleanup", + name="Clean up expired revoked tokens" + ) + scheduler.start() + +@app.on_event("shutdown") +async def stop_scheduler(): + scheduler.shutdown() + @app.get("/", response_class=HTMLResponse) async def root(): return """ diff --git a/backend/requirements.txt b/backend/requirements.txt index d3da398..cbc5a8c 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -4,6 +4,7 @@ uvicorn>=0.34.0 pydantic>=2.10.6 pydantic-settings>=2.2.1 python-multipart>=0.0.19 +fastapi-utils==0.8.0 # Database sqlalchemy>=2.0.29 @@ -30,7 +31,7 @@ httpx>=0.27.0 tenacity>=8.2.3 pytz>=2024.1 pillow>=10.3.0 - +apscheduler==3.11.0 # Testing pytest>=8.0.0 pytest-asyncio>=0.23.5 @@ -47,4 +48,5 @@ mypy>=1.8.0 python-jose==3.4.0 bcrypt==4.2.1 cryptography==44.0.1 -passlib==1.7.4 \ No newline at end of file +passlib==1.7.4 +freezegun~=1.5.1 \ No newline at end of file