diff --git a/backend/app/auth/dependencies.py b/backend/app/auth/dependencies.py index d62088e..edadeb3 100644 --- a/backend/app/auth/dependencies.py +++ b/backend/app/auth/dependencies.py @@ -4,7 +4,7 @@ from jose import JWTError, jwt from sqlalchemy.ext.asyncio import AsyncSession from app.core.database import get_db -from auth.security import SECRET_KEY, ALGORITHM +from app.auth.security import decode_token from app.models.user import User oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token") @@ -14,28 +14,22 @@ async def get_current_user( token: str = Depends(oauth2_scheme), db: AsyncSession = Depends(get_db) ): - credentials_exception = HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"}, - ) - try: - payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) - user_id: str = payload.get("sub") - token_type: str = payload.get("type") + payload = decode_token(token) # Use updated decode_token. + user_id: str = payload.sub + token_type: str = payload.type if user_id is None or token_type != "access": - raise credentials_exception + raise HTTPException(status_code=401, detail="Invalid token type.") - except JWTError: - raise credentials_exception + user = await db.get(User, user_id) + if user is None: + raise HTTPException(status_code=401, detail="User not found.") - user = await db.get(User, user_id) - if user is None: - raise credentials_exception + return user + except JWTError as e: + raise HTTPException(status_code=401, detail=str(e)) - return user async def get_current_active_user( diff --git a/backend/app/auth/security.py b/backend/app/auth/security.py index 71fa2ce..da6ba37 100644 --- a/backend/app/auth/security.py +++ b/backend/app/auth/security.py @@ -2,7 +2,8 @@ from datetime import datetime, timedelta from typing import Optional, Tuple from uuid import uuid4 -from jose import jwt, JWTError +from black import timezone +from jose import jwt, ExpiredSignatureError, JWTError from passlib.context import CryptContext from app.core.config import settings from app.schemas.token import TokenPayload, TokenResponse @@ -75,30 +76,54 @@ def create_token( return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) -def decode_token(token: str) -> TokenPayload: +def decode_token(token: str, required_type: str = "access") -> TokenPayload: """ - Decode and validate a JWT token. + Decode and validate a JWT token with explicit edge-case handling. Args: - token: The JWT token to decode + token: The JWT token to decode. + required_type: The expected token type (default: "access"). Returns: - TokenPayload containing the decoded data + TokenPayload containing the decoded data. Raises: - JWTError: If token is invalid or expired + JWTError: If the token is expired, invalid, or malformed. """ try: payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + + # Explicitly validate required claims (`exp`, `sub`, `type`) + if "exp" not in payload or "sub" not in payload or "type" not in payload: + raise KeyError("Missing required claim.") + + # Verify token expiration (`exp`) + if datetime.now() > datetime.fromtimestamp(payload["exp"]): + raise ExpiredSignatureError("Token has expired.") + + # Verify token type (`type`) + if payload["type"] != required_type: + expected_type = required_type + actual_type = payload["type"] + raise JWTError(f"Invalid token type: expected '{expected_type}', got '{actual_type}'.") + + # Create TokenPayload object from token claims return TokenPayload( sub=payload["sub"], type=payload["type"], exp=datetime.fromtimestamp(payload["exp"]), - iat=datetime.fromtimestamp(payload["iat"]), + iat=datetime.fromtimestamp(payload.get("iat", 0)), jti=payload.get("jti") ) + + except KeyError as e: + raise JWTError("Malformed token. Missing required claim.") from e + except ExpiredSignatureError as e: + raise JWTError("Token expired. Please refresh your token to continue.") from e except JWTError as e: - raise JWTError(f"Invalid token: {str(e)}") + raise JWTError(str(e)) from e + + def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str: @@ -108,4 +133,4 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) - def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) -> str: """Create a new refresh token.""" - return create_token(data, expires_delta, "refresh") \ No newline at end of file + return create_token(data, expires_delta, "refresh") diff --git a/backend/app/auth/utlis.py b/backend/app/auth/utlis.py new file mode 100644 index 0000000..8d7ecd7 --- /dev/null +++ b/backend/app/auth/utlis.py @@ -0,0 +1,16 @@ +# auth/utils.py +from sqlalchemy.ext.asyncio import AsyncSession +from app.models.token import RevokedToken + + +async def revoke_token(jti: str, token_type: str, user_id: str, db: AsyncSession): + """Revoke a token by adding its `jti` to the database.""" + revoked_token = RevokedToken(jti=jti, token_type=token_type, user_id=user_id) + db.add(revoked_token) + await db.commit() + + +async def is_token_revoked(jti: str, db: AsyncSession): + """Check if a token with the given `jti` is revoked.""" + result = await db.get(RevokedToken, jti) + return result is not None \ No newline at end of file diff --git a/backend/app/models/token.py b/backend/app/models/token.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/pytest.ini b/backend/pytest.ini index 225b22b..b46b024 100644 --- a/backend/pytest.ini +++ b/backend/pytest.ini @@ -7,3 +7,4 @@ addopts = --disable-warnings markers = sqlite: marks tests that should run on SQLite (mocked). postgres: marks tests that require a real PostgreSQL database. +asyncio_default_fixture_loop_scope = function diff --git a/backend/tests/auth/dependencies.py b/backend/tests/auth/dependencies.py index 771bc78..faccd6f 100644 --- a/backend/tests/auth/dependencies.py +++ b/backend/tests/auth/dependencies.py @@ -1,3 +1,4 @@ +from datetime import datetime, timezone from unittest.mock import AsyncMock import pytest @@ -21,13 +22,22 @@ def mock_user(): @pytest.mark.asyncio async def test_get_current_user_success(mock_user): - valid_token = jwt.encode({"sub": str(mock_user.id), "type": "access"}, SECRET_KEY, algorithm=ALGORITHM) + # Create a valid access token with required claims + valid_token = jwt.encode( + {"sub": str(mock_user.id), "type": "access", "exp": datetime.now(tz=timezone.utc).timestamp() + 3600}, + SECRET_KEY, + algorithm=ALGORITHM + ) + # Mock database session mock_db = AsyncMock() - mock_db.get.return_value = mock_user + mock_db.get.return_value = mock_user # Ensure `db.get()` returns the mock_user + # Call the dependency user = await get_current_user(token=valid_token, db=mock_db) - assert user == mock_user + + # Assertions + assert user == mock_user, "Returned user does not match the mocked user." mock_db.get.assert_called_once_with(User, mock_user.id) diff --git a/backend/tests/auth/test_security.py b/backend/tests/auth/test_security.py index 93317fd..95c89f5 100644 --- a/backend/tests/auth/test_security.py +++ b/backend/tests/auth/test_security.py @@ -1,6 +1,6 @@ import pytest -from datetime import timedelta -from jose import jwt, JWTError +from datetime import timedelta, datetime, timezone +from jose import jwt, JWTError, ExpiredSignatureError from app.auth.security import ( get_password_hash, verify_password, create_access_token, create_refresh_token, @@ -52,17 +52,59 @@ def test_decode_token_expired(): user_id = "123e4567-e89b-12d3-a456-426614174000" token = create_access_token({"sub": user_id}, expires_delta=timedelta(seconds=-1)) - with pytest.raises(JWTError): + with pytest.raises(JWTError) as exc_info: decode_token(token) + assert str(exc_info.value) == "Token expired. Please refresh your token to continue." + + +def test_decode_token_missing_exp(): + # Create a token without the `exp` claim + payload = {"sub": "123e4567-e89b-12d3-a456-426614174000", "type": "access"} + token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM) + + with pytest.raises(JWTError) as exc_info: + decode_token(token) + + assert str(exc_info.value) == "Malformed token. Missing required claim." + + + +def test_decode_token_missing_sub(): + # Create a token without the `sub` claim + payload = {"exp": datetime.now().timestamp() + 60, "type": "access"} + token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM) + + with pytest.raises(JWTError) as exc_info: + decode_token(token) + + assert str(exc_info.value) == "Malformed token. Missing required claim." + def test_decode_token_invalid_signature(): - token = jwt.encode({"some": "data"}, "invalid_key", algorithm=ALGORITHM) - with pytest.raises(JWTError): + # Use a different secret key for signing + token = jwt.encode({"sub": "123", "type": "access"}, "wrong_secret", algorithm=ALGORITHM) + + with pytest.raises(JWTError) as exc_info: decode_token(token) + assert str(exc_info.value) == "Invalid token." + def test_decode_token_malformed(): malformed_token = "malformed.header.payload" - with pytest.raises(JWTError): - decode_token(malformed_token) \ No newline at end of file + + with pytest.raises(JWTError) as exc_info: + decode_token(malformed_token) + + assert str(exc_info.value) == "Invalid token." + + +def test_decode_token_invalid_type(): + user_id = "123e4567-e89b-12d3-a456-426614174000" + token = create_refresh_token({"sub": user_id}) # Token type is "refresh" + + with pytest.raises(JWTError) as exc_info: + decode_token(token, required_type="access") # Expecting an access token + + assert str(exc_info.value) == "Invalid token type: expected 'access', got 'refresh'."