from datetime import timedelta, datetime, timezone
from unittest.mock import AsyncMock

import pytest
from jose import jwt, JWTError
from sqlalchemy.ext.asyncio import AsyncSession

from app.auth.security import (
    get_password_hash,
    verify_password,
    create_access_token,
    create_refresh_token,
    decode_token,
    SECRET_KEY,
    ALGORITHM,
)
from app.schemas.token import TokenPayload

# Sample subject ("sub" claim) shared by the token tests in this module.
USER_ID = "123e4567-e89b-12d3-a456-426614174000"
# Sample token identifier ("jti" claim) for tests that need one.
TEST_JTI = "test-jti"


def _mock_db() -> AsyncMock:
    """Return an ``AsyncSession`` stand-in to pass as ``decode_token(..., db=...)``.

    Nothing is pre-configured. Awaiting an unconfigured coroutine method
    (e.g. ``db.get``) therefore yields a ``MagicMock``, which is truthy,
    rather than ``None``. Tests that need a "not found" lookup must set
    ``db.get`` explicitly.
    """
    return AsyncMock(spec=AsyncSession)


def test_password_hashing():
    """Hashing never returns the plaintext; verification accepts only the original."""
    plain_password = "securepassword123"
    hashed_password = get_password_hash(plain_password)

    # Ensure hashed passwords are not the same as the plaintext.
    assert hashed_password != plain_password

    # Test password verification.
    assert verify_password(plain_password, hashed_password)
    assert not verify_password("wrongpassword", hashed_password)


def test_access_token_creation():
    """``create_access_token`` embeds the subject and sets ``type`` to ``"access"``."""
    token = create_access_token({"sub": USER_ID})
    decoded_payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    assert decoded_payload.get("sub") == USER_ID
    assert decoded_payload.get("type") == "access"


def test_refresh_token_creation():
    """``create_refresh_token`` embeds the subject and sets ``type`` to ``"refresh"``."""
    token = create_refresh_token({"sub": USER_ID})
    decoded_payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    assert decoded_payload.get("sub") == USER_ID
    assert decoded_payload.get("type") == "refresh"


@pytest.mark.asyncio
async def test_decode_token_valid():
    """A fresh access token whose revocation lookup finds nothing decodes cleanly."""
    access_token = create_access_token({"sub": USER_ID, "jti": TEST_JTI})

    mock_db = _mock_db()
    # Mock is_token_revoked to return False: db.get returning None means
    # no revocation record is found.
    mock_db.get = AsyncMock(return_value=None)

    token_payload = await decode_token(access_token, db=mock_db)
    assert isinstance(token_payload, TokenPayload)
    assert token_payload.sub == USER_ID
    assert token_payload.type == "access"


@pytest.mark.asyncio
async def test_decode_token_expired():
    """An access token created already expired is rejected with ``JWTError``."""
    token = create_access_token(
        {"sub": USER_ID, "jti": TEST_JTI}, expires_delta=timedelta(seconds=-1)
    )
    # NOTE(review): db.get is left unconfigured, so the revocation lookup sees
    # a truthy MagicMock. The asserted message is the *revocation* one, not an
    # expiry one. Confirm against decode_token's check order. If this test is
    # meant to cover expiry, set `mock_db.get = AsyncMock(return_value=None)`
    # and assert decode_token's expiry message instead.
    mock_db = _mock_db()

    with pytest.raises(JWTError) as exc_info:
        await decode_token(token, db=mock_db)
    assert str(exc_info.value) == "Token has been revoked."


@pytest.mark.asyncio
async def test_decode_token_missing_exp():
    """A correctly signed token with no ``exp`` claim is reported as malformed."""
    payload = {"sub": USER_ID, "type": "access", "jti": TEST_JTI}
    token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)

    mock_db = _mock_db()

    with pytest.raises(JWTError) as exc_info:
        await decode_token(token, db=mock_db)
    assert str(exc_info.value) == "Malformed token. Missing required claim(s)."


@pytest.mark.asyncio
async def test_decode_token_missing_sub():
    """A correctly signed, unexpired token with no ``sub`` claim is reported as malformed."""
    # Use an aware UTC "now" so the exp value does not depend on how a naive
    # local datetime is interpreted.
    payload = {
        "exp": datetime.now(timezone.utc).timestamp() + 60,
        "type": "access",
        "jti": TEST_JTI,
    }
    token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)

    mock_db = _mock_db()

    with pytest.raises(JWTError) as exc_info:
        await decode_token(token, db=mock_db)
    assert str(exc_info.value) == "Malformed token. Missing required claim(s)."


@pytest.mark.asyncio
async def test_decode_token_invalid_signature():
    """A token signed with a different secret fails signature verification."""
    # Sign with a key that does not match SECRET_KEY.
    token = jwt.encode(
        {"sub": "123", "type": "access", "jti": TEST_JTI},
        "wrong_secret",
        algorithm=ALGORITHM,
    )

    mock_db = _mock_db()

    with pytest.raises(JWTError) as exc_info:
        await decode_token(token, db=mock_db)
    assert str(exc_info.value) == "Signature verification failed."


@pytest.mark.asyncio
async def test_decode_token_malformed():
    """A string that is not a decodable JWT is rejected as an invalid token."""
    malformed_token = "malformed.header.payload"

    mock_db = _mock_db()

    with pytest.raises(JWTError) as exc_info:
        await decode_token(malformed_token, db=mock_db)
    assert str(exc_info.value) == "Invalid token."
@pytest.mark.asyncio
async def test_decode_token_invalid_type():
    """Supplying a refresh token where an access token is required must fail.

    ``decode_token`` is called with ``required_type="access"`` on a token
    produced by ``create_refresh_token``. The resulting ``JWTError`` must
    name both the expected and the actual token type.
    """
    subject = "123e4567-e89b-12d3-a456-426614174000"
    refresh_token = create_refresh_token({"sub": subject, "jti": "test-jti"})
    session = AsyncMock(spec=AsyncSession)

    with pytest.raises(JWTError) as raised:
        await decode_token(refresh_token, required_type="access", db=session)

    expected_message = "Invalid token type: expected 'access', got 'refresh'."
    assert str(raised.value) == expected_message