forked from cardosofelipe/fast-next-template
feat(tests): add unit tests for Context Management API routes
- Added detailed unit tests for `/context` endpoints, covering health checks, context assembly, token counting, budget retrieval, and cache invalidation. - Included edge cases, error handling, and input validation for context-related operations. - Improved test coverage for the Context Management module with mocked dependencies and integration scenarios.
This commit is contained in:
466
backend/tests/api/routes/test_context.py
Normal file
466
backend/tests/api/routes/test_context.py
Normal file
@@ -0,0 +1,466 @@
|
||||
"""
|
||||
Tests for Context Management API Routes.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import status
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.main import app
|
||||
from app.models.user import User
|
||||
from app.services.context import (
|
||||
AssembledContext,
|
||||
AssemblyTimeoutError,
|
||||
BudgetExceededError,
|
||||
ContextEngine,
|
||||
TokenBudget,
|
||||
)
|
||||
from app.services.mcp import MCPClientManager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mcp_client():
|
||||
"""Create a mock MCP client manager."""
|
||||
client = MagicMock(spec=MCPClientManager)
|
||||
client.is_initialized = True
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_context_engine(mock_mcp_client):
|
||||
"""Create a mock ContextEngine."""
|
||||
engine = MagicMock(spec=ContextEngine)
|
||||
engine._mcp = mock_mcp_client
|
||||
return engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_superuser():
|
||||
"""Create a mock superuser."""
|
||||
user = MagicMock(spec=User)
|
||||
user.id = "00000000-0000-0000-0000-000000000001"
|
||||
user.is_superuser = True
|
||||
user.email = "admin@example.com"
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(mock_mcp_client, mock_context_engine, mock_superuser):
|
||||
"""Create a FastAPI test client with mocked dependencies."""
|
||||
from app.api.dependencies.permissions import require_superuser
|
||||
from app.api.routes.context import get_context_engine
|
||||
from app.services.mcp import get_mcp_client
|
||||
|
||||
# Override dependencies
|
||||
async def override_get_mcp_client():
|
||||
return mock_mcp_client
|
||||
|
||||
async def override_get_context_engine():
|
||||
return mock_context_engine
|
||||
|
||||
async def override_require_superuser():
|
||||
return mock_superuser
|
||||
|
||||
app.dependency_overrides[get_mcp_client] = override_get_mcp_client
|
||||
app.dependency_overrides[get_context_engine] = override_get_context_engine
|
||||
app.dependency_overrides[require_superuser] = override_require_superuser
|
||||
|
||||
with patch("app.main.check_database_health", return_value=True):
|
||||
yield TestClient(app)
|
||||
|
||||
# Clean up
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
class TestContextHealth:
|
||||
"""Tests for GET /context/health endpoint."""
|
||||
|
||||
def test_health_check_success(self, client, mock_context_engine, mock_mcp_client):
|
||||
"""Test context engine health check."""
|
||||
mock_context_engine.get_stats = AsyncMock(
|
||||
return_value={
|
||||
"cache": {"hits": 10, "misses": 5},
|
||||
"settings": {"cache_enabled": True},
|
||||
}
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/context/health")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["status"] == "healthy"
|
||||
assert "mcp_connected" in data
|
||||
assert "cache_enabled" in data
|
||||
|
||||
|
||||
class TestAssembleContext:
|
||||
"""Tests for POST /context/assemble endpoint."""
|
||||
|
||||
def test_assemble_context_success(self, client, mock_context_engine):
|
||||
"""Test successful context assembly."""
|
||||
# Create mock assembled context
|
||||
mock_result = MagicMock(spec=AssembledContext)
|
||||
mock_result.content = "Assembled context content"
|
||||
mock_result.total_tokens = 500
|
||||
mock_result.context_count = 2
|
||||
mock_result.excluded_count = 0
|
||||
mock_result.assembly_time_ms = 50.5
|
||||
mock_result.metadata = {}
|
||||
|
||||
mock_context_engine.assemble_context = AsyncMock(return_value=mock_result)
|
||||
mock_context_engine.get_budget_for_model = AsyncMock(
|
||||
return_value=TokenBudget(
|
||||
total=4000,
|
||||
system=500,
|
||||
knowledge=1500,
|
||||
conversation=1000,
|
||||
tools=500,
|
||||
response_reserve=500,
|
||||
)
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/context/assemble",
|
||||
json={
|
||||
"project_id": "test-project",
|
||||
"agent_id": "test-agent",
|
||||
"query": "What is the auth flow?",
|
||||
"model": "claude-3-sonnet",
|
||||
"system_prompt": "You are a helpful assistant.",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["content"] == "Assembled context content"
|
||||
assert data["total_tokens"] == 500
|
||||
assert data["context_count"] == 2
|
||||
assert data["compressed"] is False
|
||||
assert "budget_used_percent" in data
|
||||
|
||||
def test_assemble_context_with_conversation(self, client, mock_context_engine):
|
||||
"""Test context assembly with conversation history."""
|
||||
mock_result = MagicMock(spec=AssembledContext)
|
||||
mock_result.content = "Context with history"
|
||||
mock_result.total_tokens = 800
|
||||
mock_result.context_count = 1
|
||||
mock_result.excluded_count = 0
|
||||
mock_result.assembly_time_ms = 30.0
|
||||
mock_result.metadata = {}
|
||||
|
||||
mock_context_engine.assemble_context = AsyncMock(return_value=mock_result)
|
||||
mock_context_engine.get_budget_for_model = AsyncMock(
|
||||
return_value=TokenBudget(
|
||||
total=4000,
|
||||
system=500,
|
||||
knowledge=1500,
|
||||
conversation=1000,
|
||||
tools=500,
|
||||
response_reserve=500,
|
||||
)
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/context/assemble",
|
||||
json={
|
||||
"project_id": "test-project",
|
||||
"agent_id": "test-agent",
|
||||
"query": "Continue the discussion",
|
||||
"conversation_history": [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
call_args = mock_context_engine.assemble_context.call_args
|
||||
assert call_args.kwargs["conversation_history"] == [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
|
||||
def test_assemble_context_with_tool_results(self, client, mock_context_engine):
|
||||
"""Test context assembly with tool results."""
|
||||
mock_result = MagicMock(spec=AssembledContext)
|
||||
mock_result.content = "Context with tools"
|
||||
mock_result.total_tokens = 600
|
||||
mock_result.context_count = 1
|
||||
mock_result.excluded_count = 0
|
||||
mock_result.assembly_time_ms = 25.0
|
||||
mock_result.metadata = {}
|
||||
|
||||
mock_context_engine.assemble_context = AsyncMock(return_value=mock_result)
|
||||
mock_context_engine.get_budget_for_model = AsyncMock(
|
||||
return_value=TokenBudget(
|
||||
total=4000,
|
||||
system=500,
|
||||
knowledge=1500,
|
||||
conversation=1000,
|
||||
tools=500,
|
||||
response_reserve=500,
|
||||
)
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/context/assemble",
|
||||
json={
|
||||
"project_id": "test-project",
|
||||
"agent_id": "test-agent",
|
||||
"query": "What did the search find?",
|
||||
"tool_results": [
|
||||
{
|
||||
"tool_name": "search_knowledge",
|
||||
"content": {"results": ["item1", "item2"]},
|
||||
"status": "success",
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
call_args = mock_context_engine.assemble_context.call_args
|
||||
assert len(call_args.kwargs["tool_results"]) == 1
|
||||
|
||||
def test_assemble_context_timeout(self, client, mock_context_engine):
|
||||
"""Test context assembly timeout error."""
|
||||
mock_context_engine.assemble_context = AsyncMock(
|
||||
side_effect=AssemblyTimeoutError("Assembly exceeded 5000ms limit")
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/context/assemble",
|
||||
json={
|
||||
"project_id": "test-project",
|
||||
"agent_id": "test-agent",
|
||||
"query": "test",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_504_GATEWAY_TIMEOUT
|
||||
|
||||
def test_assemble_context_budget_exceeded(self, client, mock_context_engine):
|
||||
"""Test context assembly budget exceeded error."""
|
||||
mock_context_engine.assemble_context = AsyncMock(
|
||||
side_effect=BudgetExceededError("Token budget exceeded: 5000 > 4000")
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/context/assemble",
|
||||
json={
|
||||
"project_id": "test-project",
|
||||
"agent_id": "test-agent",
|
||||
"query": "test",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_413_REQUEST_ENTITY_TOO_LARGE
|
||||
|
||||
def test_assemble_context_validation_error(self, client):
|
||||
"""Test context assembly with invalid request."""
|
||||
response = client.post(
|
||||
"/api/v1/context/assemble",
|
||||
json={}, # Missing required fields
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
|
||||
class TestCountTokens:
|
||||
"""Tests for POST /context/count-tokens endpoint."""
|
||||
|
||||
def test_count_tokens_success(self, client, mock_context_engine):
|
||||
"""Test successful token counting."""
|
||||
mock_context_engine.count_tokens = AsyncMock(return_value=42)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/context/count-tokens",
|
||||
json={
|
||||
"content": "This is some test content.",
|
||||
"model": "claude-3-sonnet",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["token_count"] == 42
|
||||
assert data["model"] == "claude-3-sonnet"
|
||||
|
||||
def test_count_tokens_without_model(self, client, mock_context_engine):
|
||||
"""Test token counting without specifying model."""
|
||||
mock_context_engine.count_tokens = AsyncMock(return_value=100)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/context/count-tokens",
|
||||
json={"content": "Some content to count."},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["token_count"] == 100
|
||||
assert data["model"] is None
|
||||
|
||||
|
||||
class TestGetBudget:
|
||||
"""Tests for GET /context/budget/{model} endpoint."""
|
||||
|
||||
def test_get_budget_success(self, client, mock_context_engine):
|
||||
"""Test getting token budget for a model."""
|
||||
mock_context_engine.get_budget_for_model = AsyncMock(
|
||||
return_value=TokenBudget(
|
||||
total=100000,
|
||||
system=10000,
|
||||
knowledge=40000,
|
||||
conversation=30000,
|
||||
tools=10000,
|
||||
response_reserve=10000,
|
||||
)
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/context/budget/claude-3-opus")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["model"] == "claude-3-opus"
|
||||
assert data["total_tokens"] == 100000
|
||||
assert data["system_tokens"] == 10000
|
||||
assert data["knowledge_tokens"] == 40000
|
||||
|
||||
def test_get_budget_with_max_tokens(self, client, mock_context_engine):
|
||||
"""Test getting budget with custom max tokens."""
|
||||
mock_context_engine.get_budget_for_model = AsyncMock(
|
||||
return_value=TokenBudget(
|
||||
total=2000,
|
||||
system=200,
|
||||
knowledge=800,
|
||||
conversation=600,
|
||||
tools=200,
|
||||
response_reserve=200,
|
||||
)
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/context/budget/gpt-4?max_tokens=2000")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["total_tokens"] == 2000
|
||||
|
||||
|
||||
class TestGetStats:
|
||||
"""Tests for GET /context/stats endpoint."""
|
||||
|
||||
def test_get_stats_success(self, client, mock_context_engine):
|
||||
"""Test getting engine statistics."""
|
||||
mock_context_engine.get_stats = AsyncMock(
|
||||
return_value={
|
||||
"cache": {
|
||||
"hits": 100,
|
||||
"misses": 25,
|
||||
"hit_rate": 0.8,
|
||||
},
|
||||
"settings": {
|
||||
"compression_threshold": 0.9,
|
||||
"max_assembly_time_ms": 5000,
|
||||
"cache_enabled": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/context/stats")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["cache"]["hits"] == 100
|
||||
assert data["settings"]["cache_enabled"] is True
|
||||
|
||||
|
||||
class TestInvalidateCache:
|
||||
"""Tests for POST /context/cache/invalidate endpoint."""
|
||||
|
||||
def test_invalidate_cache_by_project(self, client, mock_context_engine):
|
||||
"""Test cache invalidation by project ID."""
|
||||
mock_context_engine.invalidate_cache = AsyncMock(return_value=5)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/context/cache/invalidate?project_id=test-project"
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
mock_context_engine.invalidate_cache.assert_called_once()
|
||||
call_kwargs = mock_context_engine.invalidate_cache.call_args.kwargs
|
||||
assert call_kwargs["project_id"] == "test-project"
|
||||
|
||||
def test_invalidate_cache_by_pattern(self, client, mock_context_engine):
|
||||
"""Test cache invalidation by pattern."""
|
||||
mock_context_engine.invalidate_cache = AsyncMock(return_value=10)
|
||||
|
||||
response = client.post("/api/v1/context/cache/invalidate?pattern=*auth*")
|
||||
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
mock_context_engine.invalidate_cache.assert_called_once()
|
||||
call_kwargs = mock_context_engine.invalidate_cache.call_args.kwargs
|
||||
assert call_kwargs["pattern"] == "*auth*"
|
||||
|
||||
def test_invalidate_cache_all(self, client, mock_context_engine):
|
||||
"""Test invalidating all cache entries."""
|
||||
mock_context_engine.invalidate_cache = AsyncMock(return_value=100)
|
||||
|
||||
response = client.post("/api/v1/context/cache/invalidate")
|
||||
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
|
||||
|
||||
class TestContextEndpointsEdgeCases:
|
||||
"""Edge case tests for Context endpoints."""
|
||||
|
||||
def test_context_content_type(self, client, mock_context_engine):
|
||||
"""Test that endpoints return JSON content type."""
|
||||
mock_context_engine.get_stats = AsyncMock(
|
||||
return_value={"cache": {}, "settings": {}}
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/context/health")
|
||||
|
||||
assert "application/json" in response.headers["content-type"]
|
||||
|
||||
def test_assemble_context_with_knowledge_query(self, client, mock_context_engine):
|
||||
"""Test context assembly with knowledge base query."""
|
||||
mock_result = MagicMock(spec=AssembledContext)
|
||||
mock_result.content = "Context with knowledge"
|
||||
mock_result.total_tokens = 1000
|
||||
mock_result.context_count = 3
|
||||
mock_result.excluded_count = 0
|
||||
mock_result.assembly_time_ms = 100.0
|
||||
mock_result.metadata = {
|
||||
"compressed_contexts": 1
|
||||
} # Indicates compression happened
|
||||
|
||||
mock_context_engine.assemble_context = AsyncMock(return_value=mock_result)
|
||||
mock_context_engine.get_budget_for_model = AsyncMock(
|
||||
return_value=TokenBudget(
|
||||
total=4000,
|
||||
system=500,
|
||||
knowledge=1500,
|
||||
conversation=1000,
|
||||
tools=500,
|
||||
response_reserve=500,
|
||||
)
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/context/assemble",
|
||||
json={
|
||||
"project_id": "test-project",
|
||||
"agent_id": "test-agent",
|
||||
"query": "How does authentication work?",
|
||||
"knowledge_query": "authentication flow implementation",
|
||||
"knowledge_limit": 5,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
call_kwargs = mock_context_engine.assemble_context.call_args.kwargs
|
||||
assert call_kwargs["knowledge_query"] == "authentication flow implementation"
|
||||
assert call_kwargs["knowledge_limit"] == 5
|
||||
Reference in New Issue
Block a user