feat(tests): add unit tests for Context Management API routes

- Added detailed unit tests for `/context` endpoints, covering health checks, context assembly, token counting, budget retrieval, and cache invalidation.
- Included edge cases, error handling, and input validation for context-related operations.
- Improved test coverage for the Context Management module with mocked dependencies and integration scenarios.
This commit is contained in:
2026-01-05 01:02:49 +01:00
parent ad0c06851d
commit 4b149b8a52

View File

@@ -0,0 +1,466 @@
"""
Tests for Context Management API Routes.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import status
from fastapi.testclient import TestClient
from app.main import app
from app.models.user import User
from app.services.context import (
AssembledContext,
AssemblyTimeoutError,
BudgetExceededError,
ContextEngine,
TokenBudget,
)
from app.services.mcp import MCPClientManager
@pytest.fixture
def mock_mcp_client():
"""Create a mock MCP client manager."""
client = MagicMock(spec=MCPClientManager)
client.is_initialized = True
return client
@pytest.fixture
def mock_context_engine(mock_mcp_client):
"""Create a mock ContextEngine."""
engine = MagicMock(spec=ContextEngine)
engine._mcp = mock_mcp_client
return engine
@pytest.fixture
def mock_superuser():
"""Create a mock superuser."""
user = MagicMock(spec=User)
user.id = "00000000-0000-0000-0000-000000000001"
user.is_superuser = True
user.email = "admin@example.com"
return user
@pytest.fixture
def client(mock_mcp_client, mock_context_engine, mock_superuser):
"""Create a FastAPI test client with mocked dependencies."""
from app.api.dependencies.permissions import require_superuser
from app.api.routes.context import get_context_engine
from app.services.mcp import get_mcp_client
# Override dependencies
async def override_get_mcp_client():
return mock_mcp_client
async def override_get_context_engine():
return mock_context_engine
async def override_require_superuser():
return mock_superuser
app.dependency_overrides[get_mcp_client] = override_get_mcp_client
app.dependency_overrides[get_context_engine] = override_get_context_engine
app.dependency_overrides[require_superuser] = override_require_superuser
with patch("app.main.check_database_health", return_value=True):
yield TestClient(app)
# Clean up
app.dependency_overrides.clear()
class TestContextHealth:
"""Tests for GET /context/health endpoint."""
def test_health_check_success(self, client, mock_context_engine, mock_mcp_client):
"""Test context engine health check."""
mock_context_engine.get_stats = AsyncMock(
return_value={
"cache": {"hits": 10, "misses": 5},
"settings": {"cache_enabled": True},
}
)
response = client.get("/api/v1/context/health")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["status"] == "healthy"
assert "mcp_connected" in data
assert "cache_enabled" in data
class TestAssembleContext:
"""Tests for POST /context/assemble endpoint."""
def test_assemble_context_success(self, client, mock_context_engine):
"""Test successful context assembly."""
# Create mock assembled context
mock_result = MagicMock(spec=AssembledContext)
mock_result.content = "Assembled context content"
mock_result.total_tokens = 500
mock_result.context_count = 2
mock_result.excluded_count = 0
mock_result.assembly_time_ms = 50.5
mock_result.metadata = {}
mock_context_engine.assemble_context = AsyncMock(return_value=mock_result)
mock_context_engine.get_budget_for_model = AsyncMock(
return_value=TokenBudget(
total=4000,
system=500,
knowledge=1500,
conversation=1000,
tools=500,
response_reserve=500,
)
)
response = client.post(
"/api/v1/context/assemble",
json={
"project_id": "test-project",
"agent_id": "test-agent",
"query": "What is the auth flow?",
"model": "claude-3-sonnet",
"system_prompt": "You are a helpful assistant.",
},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["content"] == "Assembled context content"
assert data["total_tokens"] == 500
assert data["context_count"] == 2
assert data["compressed"] is False
assert "budget_used_percent" in data
def test_assemble_context_with_conversation(self, client, mock_context_engine):
"""Test context assembly with conversation history."""
mock_result = MagicMock(spec=AssembledContext)
mock_result.content = "Context with history"
mock_result.total_tokens = 800
mock_result.context_count = 1
mock_result.excluded_count = 0
mock_result.assembly_time_ms = 30.0
mock_result.metadata = {}
mock_context_engine.assemble_context = AsyncMock(return_value=mock_result)
mock_context_engine.get_budget_for_model = AsyncMock(
return_value=TokenBudget(
total=4000,
system=500,
knowledge=1500,
conversation=1000,
tools=500,
response_reserve=500,
)
)
response = client.post(
"/api/v1/context/assemble",
json={
"project_id": "test-project",
"agent_id": "test-agent",
"query": "Continue the discussion",
"conversation_history": [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
],
},
)
assert response.status_code == status.HTTP_200_OK
call_args = mock_context_engine.assemble_context.call_args
assert call_args.kwargs["conversation_history"] == [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
]
def test_assemble_context_with_tool_results(self, client, mock_context_engine):
"""Test context assembly with tool results."""
mock_result = MagicMock(spec=AssembledContext)
mock_result.content = "Context with tools"
mock_result.total_tokens = 600
mock_result.context_count = 1
mock_result.excluded_count = 0
mock_result.assembly_time_ms = 25.0
mock_result.metadata = {}
mock_context_engine.assemble_context = AsyncMock(return_value=mock_result)
mock_context_engine.get_budget_for_model = AsyncMock(
return_value=TokenBudget(
total=4000,
system=500,
knowledge=1500,
conversation=1000,
tools=500,
response_reserve=500,
)
)
response = client.post(
"/api/v1/context/assemble",
json={
"project_id": "test-project",
"agent_id": "test-agent",
"query": "What did the search find?",
"tool_results": [
{
"tool_name": "search_knowledge",
"content": {"results": ["item1", "item2"]},
"status": "success",
}
],
},
)
assert response.status_code == status.HTTP_200_OK
call_args = mock_context_engine.assemble_context.call_args
assert len(call_args.kwargs["tool_results"]) == 1
def test_assemble_context_timeout(self, client, mock_context_engine):
"""Test context assembly timeout error."""
mock_context_engine.assemble_context = AsyncMock(
side_effect=AssemblyTimeoutError("Assembly exceeded 5000ms limit")
)
response = client.post(
"/api/v1/context/assemble",
json={
"project_id": "test-project",
"agent_id": "test-agent",
"query": "test",
},
)
assert response.status_code == status.HTTP_504_GATEWAY_TIMEOUT
def test_assemble_context_budget_exceeded(self, client, mock_context_engine):
"""Test context assembly budget exceeded error."""
mock_context_engine.assemble_context = AsyncMock(
side_effect=BudgetExceededError("Token budget exceeded: 5000 > 4000")
)
response = client.post(
"/api/v1/context/assemble",
json={
"project_id": "test-project",
"agent_id": "test-agent",
"query": "test",
},
)
assert response.status_code == status.HTTP_413_REQUEST_ENTITY_TOO_LARGE
def test_assemble_context_validation_error(self, client):
"""Test context assembly with invalid request."""
response = client.post(
"/api/v1/context/assemble",
json={}, # Missing required fields
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
class TestCountTokens:
"""Tests for POST /context/count-tokens endpoint."""
def test_count_tokens_success(self, client, mock_context_engine):
"""Test successful token counting."""
mock_context_engine.count_tokens = AsyncMock(return_value=42)
response = client.post(
"/api/v1/context/count-tokens",
json={
"content": "This is some test content.",
"model": "claude-3-sonnet",
},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["token_count"] == 42
assert data["model"] == "claude-3-sonnet"
def test_count_tokens_without_model(self, client, mock_context_engine):
"""Test token counting without specifying model."""
mock_context_engine.count_tokens = AsyncMock(return_value=100)
response = client.post(
"/api/v1/context/count-tokens",
json={"content": "Some content to count."},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["token_count"] == 100
assert data["model"] is None
class TestGetBudget:
"""Tests for GET /context/budget/{model} endpoint."""
def test_get_budget_success(self, client, mock_context_engine):
"""Test getting token budget for a model."""
mock_context_engine.get_budget_for_model = AsyncMock(
return_value=TokenBudget(
total=100000,
system=10000,
knowledge=40000,
conversation=30000,
tools=10000,
response_reserve=10000,
)
)
response = client.get("/api/v1/context/budget/claude-3-opus")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["model"] == "claude-3-opus"
assert data["total_tokens"] == 100000
assert data["system_tokens"] == 10000
assert data["knowledge_tokens"] == 40000
def test_get_budget_with_max_tokens(self, client, mock_context_engine):
"""Test getting budget with custom max tokens."""
mock_context_engine.get_budget_for_model = AsyncMock(
return_value=TokenBudget(
total=2000,
system=200,
knowledge=800,
conversation=600,
tools=200,
response_reserve=200,
)
)
response = client.get("/api/v1/context/budget/gpt-4?max_tokens=2000")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["total_tokens"] == 2000
class TestGetStats:
"""Tests for GET /context/stats endpoint."""
def test_get_stats_success(self, client, mock_context_engine):
"""Test getting engine statistics."""
mock_context_engine.get_stats = AsyncMock(
return_value={
"cache": {
"hits": 100,
"misses": 25,
"hit_rate": 0.8,
},
"settings": {
"compression_threshold": 0.9,
"max_assembly_time_ms": 5000,
"cache_enabled": True,
},
}
)
response = client.get("/api/v1/context/stats")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["cache"]["hits"] == 100
assert data["settings"]["cache_enabled"] is True
class TestInvalidateCache:
"""Tests for POST /context/cache/invalidate endpoint."""
def test_invalidate_cache_by_project(self, client, mock_context_engine):
"""Test cache invalidation by project ID."""
mock_context_engine.invalidate_cache = AsyncMock(return_value=5)
response = client.post(
"/api/v1/context/cache/invalidate?project_id=test-project"
)
assert response.status_code == status.HTTP_204_NO_CONTENT
mock_context_engine.invalidate_cache.assert_called_once()
call_kwargs = mock_context_engine.invalidate_cache.call_args.kwargs
assert call_kwargs["project_id"] == "test-project"
def test_invalidate_cache_by_pattern(self, client, mock_context_engine):
"""Test cache invalidation by pattern."""
mock_context_engine.invalidate_cache = AsyncMock(return_value=10)
response = client.post("/api/v1/context/cache/invalidate?pattern=*auth*")
assert response.status_code == status.HTTP_204_NO_CONTENT
mock_context_engine.invalidate_cache.assert_called_once()
call_kwargs = mock_context_engine.invalidate_cache.call_args.kwargs
assert call_kwargs["pattern"] == "*auth*"
def test_invalidate_cache_all(self, client, mock_context_engine):
"""Test invalidating all cache entries."""
mock_context_engine.invalidate_cache = AsyncMock(return_value=100)
response = client.post("/api/v1/context/cache/invalidate")
assert response.status_code == status.HTTP_204_NO_CONTENT
class TestContextEndpointsEdgeCases:
"""Edge case tests for Context endpoints."""
def test_context_content_type(self, client, mock_context_engine):
"""Test that endpoints return JSON content type."""
mock_context_engine.get_stats = AsyncMock(
return_value={"cache": {}, "settings": {}}
)
response = client.get("/api/v1/context/health")
assert "application/json" in response.headers["content-type"]
def test_assemble_context_with_knowledge_query(self, client, mock_context_engine):
"""Test context assembly with knowledge base query."""
mock_result = MagicMock(spec=AssembledContext)
mock_result.content = "Context with knowledge"
mock_result.total_tokens = 1000
mock_result.context_count = 3
mock_result.excluded_count = 0
mock_result.assembly_time_ms = 100.0
mock_result.metadata = {
"compressed_contexts": 1
} # Indicates compression happened
mock_context_engine.assemble_context = AsyncMock(return_value=mock_result)
mock_context_engine.get_budget_for_model = AsyncMock(
return_value=TokenBudget(
total=4000,
system=500,
knowledge=1500,
conversation=1000,
tools=500,
response_reserve=500,
)
)
response = client.post(
"/api/v1/context/assemble",
json={
"project_id": "test-project",
"agent_id": "test-agent",
"query": "How does authentication work?",
"knowledge_query": "authentication flow implementation",
"knowledge_limit": 5,
},
)
assert response.status_code == status.HTTP_200_OK
call_kwargs = mock_context_engine.assemble_context.call_args.kwargs
assert call_kwargs["knowledge_query"] == "authentication flow implementation"
assert call_kwargs["knowledge_limit"] == 5