Refactor token handling and introduce token revocation logic
Updated `decode_token` for stricter validation of token claims and explicit error handling. Added utilities for token revocation and verification, improving
This commit is contained in:
@@ -4,7 +4,7 @@ from jose import JWTError, jwt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from auth.security import SECRET_KEY, ALGORITHM
|
||||
from app.auth.security import decode_token
|
||||
from app.models.user import User
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token")
|
||||
@@ -14,28 +14,22 @@ async def get_current_user(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
user_id: str = payload.get("sub")
|
||||
token_type: str = payload.get("type")
|
||||
payload = decode_token(token) # Use updated decode_token.
|
||||
user_id: str = payload.sub
|
||||
token_type: str = payload.type
|
||||
|
||||
if user_id is None or token_type != "access":
|
||||
raise credentials_exception
|
||||
raise HTTPException(status_code=401, detail="Invalid token type.")
|
||||
|
||||
except JWTError:
|
||||
raise credentials_exception
|
||||
user = await db.get(User, user_id)
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail="User not found.")
|
||||
|
||||
user = await db.get(User, user_id)
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
return user
|
||||
except JWTError as e:
|
||||
raise HTTPException(status_code=401, detail=str(e))
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def get_current_active_user(
|
||||
|
||||
@@ -2,7 +2,8 @@ from datetime import datetime, timedelta
|
||||
from typing import Optional, Tuple
|
||||
from uuid import uuid4
|
||||
|
||||
from jose import jwt, JWTError
|
||||
from black import timezone
|
||||
from jose import jwt, ExpiredSignatureError, JWTError
|
||||
from passlib.context import CryptContext
|
||||
from app.core.config import settings
|
||||
from app.schemas.token import TokenPayload, TokenResponse
|
||||
@@ -75,30 +76,54 @@ def create_token(
|
||||
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
def decode_token(token: str) -> TokenPayload:
|
||||
def decode_token(token: str, required_type: str = "access") -> TokenPayload:
|
||||
"""
|
||||
Decode and validate a JWT token.
|
||||
Decode and validate a JWT token with explicit edge-case handling.
|
||||
|
||||
Args:
|
||||
token: The JWT token to decode
|
||||
token: The JWT token to decode.
|
||||
required_type: The expected token type (default: "access").
|
||||
|
||||
Returns:
|
||||
TokenPayload containing the decoded data
|
||||
TokenPayload containing the decoded data.
|
||||
|
||||
Raises:
|
||||
JWTError: If token is invalid or expired
|
||||
JWTError: If the token is expired, invalid, or malformed.
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
|
||||
# Explicitly validate required claims (`exp`, `sub`, `type`)
|
||||
if "exp" not in payload or "sub" not in payload or "type" not in payload:
|
||||
raise KeyError("Missing required claim.")
|
||||
|
||||
# Verify token expiration (`exp`)
|
||||
if datetime.now() > datetime.fromtimestamp(payload["exp"]):
|
||||
raise ExpiredSignatureError("Token has expired.")
|
||||
|
||||
# Verify token type (`type`)
|
||||
if payload["type"] != required_type:
|
||||
expected_type = required_type
|
||||
actual_type = payload["type"]
|
||||
raise JWTError(f"Invalid token type: expected '{expected_type}', got '{actual_type}'.")
|
||||
|
||||
# Create TokenPayload object from token claims
|
||||
return TokenPayload(
|
||||
sub=payload["sub"],
|
||||
type=payload["type"],
|
||||
exp=datetime.fromtimestamp(payload["exp"]),
|
||||
iat=datetime.fromtimestamp(payload["iat"]),
|
||||
iat=datetime.fromtimestamp(payload.get("iat", 0)),
|
||||
jti=payload.get("jti")
|
||||
)
|
||||
|
||||
except KeyError as e:
|
||||
raise JWTError("Malformed token. Missing required claim.") from e
|
||||
except ExpiredSignatureError as e:
|
||||
raise JWTError("Token expired. Please refresh your token to continue.") from e
|
||||
except JWTError as e:
|
||||
raise JWTError(f"Invalid token: {str(e)}")
|
||||
raise JWTError(str(e)) from e
|
||||
|
||||
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
@@ -108,4 +133,4 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
|
||||
|
||||
def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""Create a new refresh token."""
|
||||
return create_token(data, expires_delta, "refresh")
|
||||
return create_token(data, expires_delta, "refresh")
|
||||
|
||||
16
backend/app/auth/utlis.py
Normal file
16
backend/app/auth/utlis.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# auth/utils.py
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.models.token import RevokedToken
|
||||
|
||||
|
||||
async def revoke_token(jti: str, token_type: str, user_id: str, db: AsyncSession):
|
||||
"""Revoke a token by adding its `jti` to the database."""
|
||||
revoked_token = RevokedToken(jti=jti, token_type=token_type, user_id=user_id)
|
||||
db.add(revoked_token)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def is_token_revoked(jti: str, db: AsyncSession):
|
||||
"""Check if a token with the given `jti` is revoked."""
|
||||
result = await db.get(RevokedToken, jti)
|
||||
return result is not None
|
||||
0
backend/app/models/token.py
Normal file
0
backend/app/models/token.py
Normal file
@@ -7,3 +7,4 @@ addopts = --disable-warnings
|
||||
markers =
|
||||
sqlite: marks tests that should run on SQLite (mocked).
|
||||
postgres: marks tests that require a real PostgreSQL database.
|
||||
asyncio_default_fixture_loop_scope = function
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
@@ -21,13 +22,22 @@ def mock_user():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_success(mock_user):
|
||||
valid_token = jwt.encode({"sub": str(mock_user.id), "type": "access"}, SECRET_KEY, algorithm=ALGORITHM)
|
||||
# Create a valid access token with required claims
|
||||
valid_token = jwt.encode(
|
||||
{"sub": str(mock_user.id), "type": "access", "exp": datetime.now(tz=timezone.utc).timestamp() + 3600},
|
||||
SECRET_KEY,
|
||||
algorithm=ALGORITHM
|
||||
)
|
||||
|
||||
# Mock database session
|
||||
mock_db = AsyncMock()
|
||||
mock_db.get.return_value = mock_user
|
||||
mock_db.get.return_value = mock_user # Ensure `db.get()` returns the mock_user
|
||||
|
||||
# Call the dependency
|
||||
user = await get_current_user(token=valid_token, db=mock_db)
|
||||
assert user == mock_user
|
||||
|
||||
# Assertions
|
||||
assert user == mock_user, "Returned user does not match the mocked user."
|
||||
mock_db.get.assert_called_once_with(User, mock_user.id)
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
from datetime import timedelta
|
||||
from jose import jwt, JWTError
|
||||
from datetime import timedelta, datetime, timezone
|
||||
from jose import jwt, JWTError, ExpiredSignatureError
|
||||
from app.auth.security import (
|
||||
get_password_hash, verify_password,
|
||||
create_access_token, create_refresh_token,
|
||||
@@ -52,17 +52,59 @@ def test_decode_token_expired():
|
||||
user_id = "123e4567-e89b-12d3-a456-426614174000"
|
||||
token = create_access_token({"sub": user_id}, expires_delta=timedelta(seconds=-1))
|
||||
|
||||
with pytest.raises(JWTError):
|
||||
with pytest.raises(JWTError) as exc_info:
|
||||
decode_token(token)
|
||||
|
||||
assert str(exc_info.value) == "Token expired. Please refresh your token to continue."
|
||||
|
||||
|
||||
def test_decode_token_missing_exp():
|
||||
# Create a token without the `exp` claim
|
||||
payload = {"sub": "123e4567-e89b-12d3-a456-426614174000", "type": "access"}
|
||||
token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
with pytest.raises(JWTError) as exc_info:
|
||||
decode_token(token)
|
||||
|
||||
assert str(exc_info.value) == "Malformed token. Missing required claim."
|
||||
|
||||
|
||||
|
||||
def test_decode_token_missing_sub():
|
||||
# Create a token without the `sub` claim
|
||||
payload = {"exp": datetime.now().timestamp() + 60, "type": "access"}
|
||||
token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
with pytest.raises(JWTError) as exc_info:
|
||||
decode_token(token)
|
||||
|
||||
assert str(exc_info.value) == "Malformed token. Missing required claim."
|
||||
|
||||
|
||||
def test_decode_token_invalid_signature():
|
||||
token = jwt.encode({"some": "data"}, "invalid_key", algorithm=ALGORITHM)
|
||||
with pytest.raises(JWTError):
|
||||
# Use a different secret key for signing
|
||||
token = jwt.encode({"sub": "123", "type": "access"}, "wrong_secret", algorithm=ALGORITHM)
|
||||
|
||||
with pytest.raises(JWTError) as exc_info:
|
||||
decode_token(token)
|
||||
|
||||
assert str(exc_info.value) == "Invalid token."
|
||||
|
||||
|
||||
def test_decode_token_malformed():
|
||||
malformed_token = "malformed.header.payload"
|
||||
with pytest.raises(JWTError):
|
||||
decode_token(malformed_token)
|
||||
|
||||
with pytest.raises(JWTError) as exc_info:
|
||||
decode_token(malformed_token)
|
||||
|
||||
assert str(exc_info.value) == "Invalid token."
|
||||
|
||||
|
||||
def test_decode_token_invalid_type():
|
||||
user_id = "123e4567-e89b-12d3-a456-426614174000"
|
||||
token = create_refresh_token({"sub": user_id}) # Token type is "refresh"
|
||||
|
||||
with pytest.raises(JWTError) as exc_info:
|
||||
decode_token(token, required_type="access") # Expecting an access token
|
||||
|
||||
assert str(exc_info.value) == "Invalid token type: expected 'access', got 'refresh'."
|
||||
|
||||
Reference in New Issue
Block a user