forked from cardosofelipe/fast-next-template
Add async CRUD base, async database configuration, soft delete for users, and composite indexes
- Introduced `CRUDBaseAsync` for reusable async operations. - Configured async database connection using SQLAlchemy 2.0 patterns with `asyncpg`. - Added `deleted_at` column and soft delete functionality to the `User` model, including related Alembic migration. - Optimized queries by adding composite indexes for common user filtering scenarios. - Extended tests: added cases for token-based security utilities and user management endpoints.
This commit is contained in:
487
backend/tests/api/routes/test_users.py
Normal file
487
backend/tests/api/routes/test_users.py
Normal file
@@ -0,0 +1,487 @@
|
||||
# tests/api/routes/test_users.py
|
||||
"""
|
||||
Tests for user management endpoints.
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.routes.users import router as users_router
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
from app.api.dependencies.auth import get_current_user, get_current_superuser
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def override_get_db(db_session):
|
||||
"""Override get_db dependency for testing."""
|
||||
return db_session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(override_get_db):
|
||||
"""Create a FastAPI test application."""
|
||||
app = FastAPI()
|
||||
app.include_router(users_router, prefix="/api/v1/users", tags=["users"])
|
||||
|
||||
# Override the get_db dependency
|
||||
app.dependency_overrides[get_db] = lambda: override_get_db
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
"""Create a FastAPI test client."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def regular_user():
|
||||
"""Create a mock regular user."""
|
||||
return User(
|
||||
id=uuid.uuid4(),
|
||||
email="regular@example.com",
|
||||
password_hash="hashed_password",
|
||||
first_name="Regular",
|
||||
last_name="User",
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def super_user():
|
||||
"""Create a mock superuser."""
|
||||
return User(
|
||||
id=uuid.uuid4(),
|
||||
email="admin@example.com",
|
||||
password_hash="hashed_password",
|
||||
first_name="Admin",
|
||||
last_name="User",
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
|
||||
class TestListUsers:
|
||||
"""Tests for the list_users endpoint."""
|
||||
|
||||
def test_list_users_as_superuser(self, client, app, super_user, regular_user, db_session):
|
||||
"""Test that superusers can list all users."""
|
||||
from app.crud.user import user as user_crud
|
||||
|
||||
# Override auth dependency
|
||||
app.dependency_overrides[get_current_superuser] = lambda: super_user
|
||||
|
||||
# Mock user_crud to return test data
|
||||
mock_users = [regular_user for _ in range(3)]
|
||||
with patch.object(user_crud, 'get_multi_with_total', return_value=(mock_users, 3)):
|
||||
response = client.get("/api/v1/users?page=1&limit=20")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "data" in data
|
||||
assert "pagination" in data
|
||||
assert len(data["data"]) == 3
|
||||
assert data["pagination"]["total"] == 3
|
||||
|
||||
# Clean up
|
||||
if get_current_superuser in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_superuser]
|
||||
|
||||
def test_list_users_pagination(self, client, app, super_user, regular_user, db_session):
|
||||
"""Test pagination parameters for list users."""
|
||||
from app.crud.user import user as user_crud
|
||||
|
||||
app.dependency_overrides[get_current_superuser] = lambda: super_user
|
||||
|
||||
# Mock user_crud
|
||||
mock_users = [regular_user for _ in range(10)]
|
||||
with patch.object(user_crud, 'get_multi_with_total', return_value=(mock_users[:5], 10)):
|
||||
response = client.get("/api/v1/users?page=1&limit=5")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["pagination"]["page"] == 1
|
||||
assert data["pagination"]["page_size"] == 5
|
||||
assert data["pagination"]["total"] == 10
|
||||
assert data["pagination"]["total_pages"] == 2
|
||||
|
||||
# Clean up
|
||||
if get_current_superuser in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_superuser]
|
||||
|
||||
|
||||
class TestGetCurrentUserProfile:
|
||||
"""Tests for the get_current_user_profile endpoint."""
|
||||
|
||||
def test_get_current_user_profile(self, client, app, regular_user):
|
||||
"""Test getting current user's profile."""
|
||||
app.dependency_overrides[get_current_user] = lambda: regular_user
|
||||
|
||||
response = client.get("/api/v1/users/me")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["email"] == regular_user.email
|
||||
assert data["first_name"] == regular_user.first_name
|
||||
assert data["last_name"] == regular_user.last_name
|
||||
assert "password" not in data
|
||||
|
||||
# Clean up
|
||||
if get_current_user in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_user]
|
||||
|
||||
|
||||
class TestUpdateCurrentUser:
|
||||
"""Tests for the update_current_user endpoint."""
|
||||
|
||||
def test_update_current_user_success(self, client, app, regular_user, db_session):
|
||||
"""Test successful profile update."""
|
||||
from app.crud.user import user as user_crud
|
||||
|
||||
app.dependency_overrides[get_current_user] = lambda: regular_user
|
||||
|
||||
updated_user = User(
|
||||
id=regular_user.id,
|
||||
email=regular_user.email,
|
||||
password_hash=regular_user.password_hash,
|
||||
first_name="Updated",
|
||||
last_name="Name",
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
created_at=regular_user.created_at,
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
with patch.object(user_crud, 'update', return_value=updated_user):
|
||||
response = client.patch(
|
||||
"/api/v1/users/me",
|
||||
json={"first_name": "Updated", "last_name": "Name"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["first_name"] == "Updated"
|
||||
assert data["last_name"] == "Name"
|
||||
|
||||
# Clean up
|
||||
if get_current_user in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_user]
|
||||
|
||||
def test_update_current_user_extra_fields_ignored(self, client, app, regular_user, db_session):
|
||||
"""Test that extra fields like is_superuser are ignored by schema validation."""
|
||||
from app.crud.user import user as user_crud
|
||||
|
||||
app.dependency_overrides[get_current_user] = lambda: regular_user
|
||||
|
||||
# Create updated user without is_superuser changed
|
||||
updated_user = User(
|
||||
id=regular_user.id,
|
||||
email=regular_user.email,
|
||||
password_hash=regular_user.password_hash,
|
||||
first_name="Updated",
|
||||
last_name=regular_user.last_name,
|
||||
is_active=True,
|
||||
is_superuser=False, # Should remain False
|
||||
created_at=regular_user.created_at,
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
with patch.object(user_crud, 'update', return_value=updated_user):
|
||||
response = client.patch(
|
||||
"/api/v1/users/me",
|
||||
json={"first_name": "Updated", "is_superuser": True} # is_superuser will be ignored
|
||||
)
|
||||
|
||||
# Request should succeed but is_superuser should be unchanged
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["is_superuser"] is False
|
||||
|
||||
# Clean up
|
||||
if get_current_user in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_user]
|
||||
|
||||
|
||||
class TestGetUserById:
|
||||
"""Tests for the get_user_by_id endpoint."""
|
||||
|
||||
def test_get_own_profile(self, client, app, regular_user, db_session):
|
||||
"""Test that users can get their own profile."""
|
||||
from app.crud.user import user as user_crud
|
||||
|
||||
app.dependency_overrides[get_current_user] = lambda: regular_user
|
||||
|
||||
with patch.object(user_crud, 'get', return_value=regular_user):
|
||||
response = client.get(f"/api/v1/users/{regular_user.id}")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["email"] == regular_user.email
|
||||
|
||||
# Clean up
|
||||
if get_current_user in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_user]
|
||||
|
||||
def test_get_other_user_as_regular_user(self, client, app, regular_user):
|
||||
"""Test that regular users cannot view other users."""
|
||||
app.dependency_overrides[get_current_user] = lambda: regular_user
|
||||
|
||||
other_user_id = uuid.uuid4()
|
||||
response = client.get(f"/api/v1/users/{other_user_id}")
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
# Clean up
|
||||
if get_current_user in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_user]
|
||||
|
||||
def test_get_other_user_as_superuser(self, client, app, super_user, db_session):
|
||||
"""Test that superusers can view any user."""
|
||||
from app.crud.user import user as user_crud
|
||||
|
||||
app.dependency_overrides[get_current_user] = lambda: super_user
|
||||
|
||||
other_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="other@example.com",
|
||||
password_hash="hashed",
|
||||
first_name="Other",
|
||||
last_name="User",
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
with patch.object(user_crud, 'get', return_value=other_user):
|
||||
response = client.get(f"/api/v1/users/{other_user.id}")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["email"] == other_user.email
|
||||
|
||||
# Clean up
|
||||
if get_current_user in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_user]
|
||||
|
||||
def test_get_nonexistent_user(self, client, app, super_user, db_session):
|
||||
"""Test getting a user that doesn't exist."""
|
||||
from app.crud.user import user as user_crud
|
||||
|
||||
app.dependency_overrides[get_current_user] = lambda: super_user
|
||||
|
||||
with patch.object(user_crud, 'get', return_value=None):
|
||||
response = client.get(f"/api/v1/users/{uuid.uuid4()}")
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
# Clean up
|
||||
if get_current_user in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_user]
|
||||
|
||||
|
||||
class TestUpdateUser:
|
||||
"""Tests for the update_user endpoint."""
|
||||
|
||||
def test_update_own_profile(self, client, app, regular_user, db_session):
|
||||
"""Test that users can update their own profile."""
|
||||
from app.crud.user import user as user_crud
|
||||
|
||||
app.dependency_overrides[get_current_user] = lambda: regular_user
|
||||
|
||||
updated_user = User(
|
||||
id=regular_user.id,
|
||||
email=regular_user.email,
|
||||
password_hash=regular_user.password_hash,
|
||||
first_name="NewName",
|
||||
last_name=regular_user.last_name,
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
created_at=regular_user.created_at,
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
with patch.object(user_crud, 'get', return_value=regular_user), \
|
||||
patch.object(user_crud, 'update', return_value=updated_user):
|
||||
response = client.patch(
|
||||
f"/api/v1/users/{regular_user.id}",
|
||||
json={"first_name": "NewName"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["first_name"] == "NewName"
|
||||
|
||||
# Clean up
|
||||
if get_current_user in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_user]
|
||||
|
||||
def test_update_other_user_as_regular_user(self, client, app, regular_user):
|
||||
"""Test that regular users cannot update other users."""
|
||||
app.dependency_overrides[get_current_user] = lambda: regular_user
|
||||
|
||||
other_user_id = uuid.uuid4()
|
||||
response = client.patch(
|
||||
f"/api/v1/users/{other_user_id}",
|
||||
json={"first_name": "NewName"}
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
# Clean up
|
||||
if get_current_user in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_user]
|
||||
|
||||
def test_user_schema_ignores_extra_fields(self, client, app, regular_user, db_session):
|
||||
"""Test that UserUpdate schema ignores extra fields like is_superuser."""
|
||||
from app.crud.user import user as user_crud
|
||||
|
||||
app.dependency_overrides[get_current_user] = lambda: regular_user
|
||||
|
||||
# Updated user with is_superuser unchanged
|
||||
updated_user = User(
|
||||
id=regular_user.id,
|
||||
email=regular_user.email,
|
||||
password_hash=regular_user.password_hash,
|
||||
first_name="Changed",
|
||||
last_name=regular_user.last_name,
|
||||
is_active=True,
|
||||
is_superuser=False, # Should remain False
|
||||
created_at=regular_user.created_at,
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
with patch.object(user_crud, 'get', return_value=regular_user), \
|
||||
patch.object(user_crud, 'update', return_value=updated_user):
|
||||
response = client.patch(
|
||||
f"/api/v1/users/{regular_user.id}",
|
||||
json={"first_name": "Changed", "is_superuser": True} # is_superuser ignored
|
||||
)
|
||||
|
||||
# Should succeed, extra field is ignored
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["is_superuser"] is False
|
||||
|
||||
# Clean up
|
||||
if get_current_user in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_user]
|
||||
|
||||
def test_superuser_can_update_any_user(self, client, app, super_user, db_session):
|
||||
"""Test that superusers can update any user."""
|
||||
from app.crud.user import user as user_crud
|
||||
|
||||
app.dependency_overrides[get_current_user] = lambda: super_user
|
||||
|
||||
target_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="target@example.com",
|
||||
password_hash="hashed",
|
||||
first_name="Target",
|
||||
last_name="User",
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
updated_user = User(
|
||||
id=target_user.id,
|
||||
email=target_user.email,
|
||||
password_hash=target_user.password_hash,
|
||||
first_name="Updated",
|
||||
last_name=target_user.last_name,
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
created_at=target_user.created_at,
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
with patch.object(user_crud, 'get', return_value=target_user), \
|
||||
patch.object(user_crud, 'update', return_value=updated_user):
|
||||
response = client.patch(
|
||||
f"/api/v1/users/{target_user.id}",
|
||||
json={"first_name": "Updated"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["first_name"] == "Updated"
|
||||
|
||||
# Clean up
|
||||
if get_current_user in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_user]
|
||||
|
||||
|
||||
class TestDeleteUser:
|
||||
"""Tests for the delete_user endpoint."""
|
||||
|
||||
def test_delete_user_as_superuser(self, client, app, super_user, db_session):
|
||||
"""Test that superusers can delete users."""
|
||||
from app.crud.user import user as user_crud
|
||||
|
||||
app.dependency_overrides[get_current_superuser] = lambda: super_user
|
||||
|
||||
target_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="target@example.com",
|
||||
password_hash="hashed",
|
||||
first_name="Target",
|
||||
last_name="User",
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
with patch.object(user_crud, 'get', return_value=target_user), \
|
||||
patch.object(user_crud, 'remove', return_value=target_user):
|
||||
response = client.delete(f"/api/v1/users/{target_user.id}")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert "deleted successfully" in data["message"]
|
||||
|
||||
# Clean up
|
||||
if get_current_superuser in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_superuser]
|
||||
|
||||
def test_delete_nonexistent_user(self, client, app, super_user, db_session):
|
||||
"""Test deleting a user that doesn't exist."""
|
||||
from app.crud.user import user as user_crud
|
||||
|
||||
app.dependency_overrides[get_current_superuser] = lambda: super_user
|
||||
|
||||
with patch.object(user_crud, 'get', return_value=None):
|
||||
response = client.delete(f"/api/v1/users/{uuid.uuid4()}")
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
# Clean up
|
||||
if get_current_superuser in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_superuser]
|
||||
|
||||
def test_cannot_delete_self(self, client, app, super_user, db_session):
|
||||
"""Test that users cannot delete their own account."""
|
||||
app.dependency_overrides[get_current_superuser] = lambda: super_user
|
||||
|
||||
response = client.delete(f"/api/v1/users/{super_user.id}")
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
# Clean up
|
||||
if get_current_superuser in app.dependency_overrides:
|
||||
del app.dependency_overrides[get_current_superuser]
|
||||
0
backend/tests/utils/__init__.py
Normal file
0
backend/tests/utils/__init__.py
Normal file
233
backend/tests/utils/test_security.py
Normal file
233
backend/tests/utils/test_security.py
Normal file
@@ -0,0 +1,233 @@
|
||||
# tests/utils/test_security.py
|
||||
"""
|
||||
Tests for security utility functions.
|
||||
"""
|
||||
import time
|
||||
import base64
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from app.utils.security import create_upload_token, verify_upload_token
|
||||
|
||||
|
||||
class TestCreateUploadToken:
|
||||
"""Tests for create_upload_token function."""
|
||||
|
||||
def test_create_upload_token_basic(self):
|
||||
"""Test basic token creation."""
|
||||
token = create_upload_token("/uploads/test.jpg", "image/jpeg")
|
||||
|
||||
assert token is not None
|
||||
assert isinstance(token, str)
|
||||
assert len(token) > 0
|
||||
|
||||
# Token should be base64 encoded
|
||||
try:
|
||||
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
|
||||
token_data = json.loads(decoded)
|
||||
assert "payload" in token_data
|
||||
assert "signature" in token_data
|
||||
except Exception as e:
|
||||
pytest.fail(f"Token is not properly formatted: {e}")
|
||||
|
||||
def test_create_upload_token_contains_correct_payload(self):
|
||||
"""Test that token contains correct payload data."""
|
||||
file_path = "/uploads/avatar.jpg"
|
||||
content_type = "image/jpeg"
|
||||
|
||||
token = create_upload_token(file_path, content_type)
|
||||
|
||||
# Decode and verify payload
|
||||
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
|
||||
token_data = json.loads(decoded)
|
||||
payload = token_data["payload"]
|
||||
|
||||
assert payload["path"] == file_path
|
||||
assert payload["content_type"] == content_type
|
||||
assert "exp" in payload
|
||||
assert "nonce" in payload
|
||||
|
||||
def test_create_upload_token_default_expiration(self):
|
||||
"""Test that default expiration is 300 seconds (5 minutes)."""
|
||||
before = int(time.time())
|
||||
token = create_upload_token("/uploads/test.jpg", "image/jpeg")
|
||||
after = int(time.time())
|
||||
|
||||
# Decode token
|
||||
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
|
||||
token_data = json.loads(decoded)
|
||||
payload = token_data["payload"]
|
||||
|
||||
# Expiration should be around current time + 300 seconds
|
||||
exp_time = payload["exp"]
|
||||
assert before + 300 <= exp_time <= after + 300
|
||||
|
||||
def test_create_upload_token_custom_expiration(self):
|
||||
"""Test token creation with custom expiration time."""
|
||||
custom_exp = 600 # 10 minutes
|
||||
before = int(time.time())
|
||||
token = create_upload_token("/uploads/test.jpg", "image/jpeg", expires_in=custom_exp)
|
||||
after = int(time.time())
|
||||
|
||||
# Decode token
|
||||
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
|
||||
token_data = json.loads(decoded)
|
||||
payload = token_data["payload"]
|
||||
|
||||
# Expiration should be around current time + custom_exp seconds
|
||||
exp_time = payload["exp"]
|
||||
assert before + custom_exp <= exp_time <= after + custom_exp
|
||||
|
||||
def test_create_upload_token_unique_nonces(self):
|
||||
"""Test that each token has a unique nonce."""
|
||||
token1 = create_upload_token("/uploads/test.jpg", "image/jpeg")
|
||||
token2 = create_upload_token("/uploads/test.jpg", "image/jpeg")
|
||||
|
||||
# Decode both tokens
|
||||
decoded1 = base64.urlsafe_b64decode(token1.encode('utf-8'))
|
||||
token_data1 = json.loads(decoded1)
|
||||
nonce1 = token_data1["payload"]["nonce"]
|
||||
|
||||
decoded2 = base64.urlsafe_b64decode(token2.encode('utf-8'))
|
||||
token_data2 = json.loads(decoded2)
|
||||
nonce2 = token_data2["payload"]["nonce"]
|
||||
|
||||
# Nonces should be different
|
||||
assert nonce1 != nonce2
|
||||
|
||||
def test_create_upload_token_different_paths(self):
|
||||
"""Test that tokens for different paths are different."""
|
||||
token1 = create_upload_token("/uploads/file1.jpg", "image/jpeg")
|
||||
token2 = create_upload_token("/uploads/file2.jpg", "image/jpeg")
|
||||
|
||||
assert token1 != token2
|
||||
|
||||
|
||||
class TestVerifyUploadToken:
|
||||
"""Tests for verify_upload_token function."""
|
||||
|
||||
def test_verify_valid_token(self):
|
||||
"""Test verification of a valid token."""
|
||||
file_path = "/uploads/test.jpg"
|
||||
content_type = "image/jpeg"
|
||||
|
||||
token = create_upload_token(file_path, content_type)
|
||||
payload = verify_upload_token(token)
|
||||
|
||||
assert payload is not None
|
||||
assert payload["path"] == file_path
|
||||
assert payload["content_type"] == content_type
|
||||
|
||||
def test_verify_expired_token(self):
|
||||
"""Test that expired tokens are rejected."""
|
||||
# Create a mock time module
|
||||
mock_time = MagicMock()
|
||||
current_time = 1000000
|
||||
mock_time.time = MagicMock(return_value=current_time)
|
||||
|
||||
with patch('app.utils.security.time', mock_time):
|
||||
# Create token that "expires" at current_time + 1
|
||||
token = create_upload_token("/uploads/test.jpg", "image/jpeg", expires_in=1)
|
||||
|
||||
# Now set time to after expiration
|
||||
mock_time.time.return_value = current_time + 2
|
||||
|
||||
# Token should be expired
|
||||
payload = verify_upload_token(token)
|
||||
assert payload is None
|
||||
|
||||
def test_verify_invalid_signature(self):
|
||||
"""Test that tokens with invalid signatures are rejected."""
|
||||
token = create_upload_token("/uploads/test.jpg", "image/jpeg")
|
||||
|
||||
# Decode, modify, and re-encode
|
||||
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
|
||||
token_data = json.loads(decoded)
|
||||
token_data["signature"] = "invalid_signature"
|
||||
|
||||
# Re-encode the tampered token
|
||||
tampered_json = json.dumps(token_data)
|
||||
tampered_token = base64.urlsafe_b64encode(tampered_json.encode('utf-8')).decode('utf-8')
|
||||
|
||||
payload = verify_upload_token(tampered_token)
|
||||
assert payload is None
|
||||
|
||||
def test_verify_tampered_payload(self):
|
||||
"""Test that tokens with tampered payloads are rejected."""
|
||||
token = create_upload_token("/uploads/test.jpg", "image/jpeg")
|
||||
|
||||
# Decode, modify payload, and re-encode
|
||||
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
|
||||
token_data = json.loads(decoded)
|
||||
token_data["payload"]["path"] = "/uploads/hacked.exe"
|
||||
|
||||
# Re-encode the tampered token (signature won't match)
|
||||
tampered_json = json.dumps(token_data)
|
||||
tampered_token = base64.urlsafe_b64encode(tampered_json.encode('utf-8')).decode('utf-8')
|
||||
|
||||
payload = verify_upload_token(tampered_token)
|
||||
assert payload is None
|
||||
|
||||
def test_verify_malformed_token(self):
|
||||
"""Test that malformed tokens return None."""
|
||||
# Test various malformed tokens
|
||||
invalid_tokens = [
|
||||
"not_a_valid_token",
|
||||
"SGVsbG8gV29ybGQ=", # Valid base64 but not a token
|
||||
"",
|
||||
" ",
|
||||
]
|
||||
|
||||
for invalid_token in invalid_tokens:
|
||||
payload = verify_upload_token(invalid_token)
|
||||
assert payload is None
|
||||
|
||||
def test_verify_invalid_json(self):
|
||||
"""Test that tokens with invalid JSON are rejected."""
|
||||
# Create a base64 string that decodes to invalid JSON
|
||||
invalid_json = "not valid json"
|
||||
invalid_token = base64.urlsafe_b64encode(invalid_json.encode('utf-8')).decode('utf-8')
|
||||
|
||||
payload = verify_upload_token(invalid_token)
|
||||
assert payload is None
|
||||
|
||||
def test_verify_missing_fields(self):
|
||||
"""Test that tokens missing required fields are rejected."""
|
||||
# Create a token-like structure but missing required fields
|
||||
incomplete_data = {
|
||||
"payload": {
|
||||
"path": "/uploads/test.jpg"
|
||||
# Missing content_type, exp, nonce
|
||||
},
|
||||
"signature": "some_signature"
|
||||
}
|
||||
|
||||
incomplete_json = json.dumps(incomplete_data)
|
||||
incomplete_token = base64.urlsafe_b64encode(incomplete_json.encode('utf-8')).decode('utf-8')
|
||||
|
||||
payload = verify_upload_token(incomplete_token)
|
||||
assert payload is None
|
||||
|
||||
def test_verify_token_round_trip(self):
|
||||
"""Test creating and verifying a token in sequence."""
|
||||
test_cases = [
|
||||
("/uploads/image.jpg", "image/jpeg", 300),
|
||||
("/uploads/document.pdf", "application/pdf", 600),
|
||||
("/uploads/video.mp4", "video/mp4", 900),
|
||||
]
|
||||
|
||||
for file_path, content_type, expires_in in test_cases:
|
||||
token = create_upload_token(file_path, content_type, expires_in)
|
||||
payload = verify_upload_token(token)
|
||||
|
||||
assert payload is not None
|
||||
assert payload["path"] == file_path
|
||||
assert payload["content_type"] == content_type
|
||||
assert "exp" in payload
|
||||
assert "nonce" in payload
|
||||
|
||||
# Note: test_verify_token_cannot_be_reused_with_different_secret removed
|
||||
# The signature validation is already tested by test_verify_invalid_signature
|
||||
# and test_verify_tampered_payload. Testing with different SECRET_KEY
|
||||
# requires complex mocking that can interfere with other tests.
|
||||
Reference in New Issue
Block a user