diff --git a/backend/tests/api/routes/test_context.py b/backend/tests/api/routes/test_context.py new file mode 100644 index 0000000..1cbcd90 --- /dev/null +++ b/backend/tests/api/routes/test_context.py @@ -0,0 +1,466 @@ +""" +Tests for Context Management API Routes. +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import status +from fastapi.testclient import TestClient + +from app.main import app +from app.models.user import User +from app.services.context import ( + AssembledContext, + AssemblyTimeoutError, + BudgetExceededError, + ContextEngine, + TokenBudget, +) +from app.services.mcp import MCPClientManager + + +@pytest.fixture +def mock_mcp_client(): + """Create a mock MCP client manager.""" + client = MagicMock(spec=MCPClientManager) + client.is_initialized = True + return client + + +@pytest.fixture +def mock_context_engine(mock_mcp_client): + """Create a mock ContextEngine.""" + engine = MagicMock(spec=ContextEngine) + engine._mcp = mock_mcp_client + return engine + + +@pytest.fixture +def mock_superuser(): + """Create a mock superuser.""" + user = MagicMock(spec=User) + user.id = "00000000-0000-0000-0000-000000000001" + user.is_superuser = True + user.email = "admin@example.com" + return user + + +@pytest.fixture +def client(mock_mcp_client, mock_context_engine, mock_superuser): + """Create a FastAPI test client with mocked dependencies.""" + from app.api.dependencies.permissions import require_superuser + from app.api.routes.context import get_context_engine + from app.services.mcp import get_mcp_client + + # Override dependencies + async def override_get_mcp_client(): + return mock_mcp_client + + async def override_get_context_engine(): + return mock_context_engine + + async def override_require_superuser(): + return mock_superuser + + app.dependency_overrides[get_mcp_client] = override_get_mcp_client + app.dependency_overrides[get_context_engine] = override_get_context_engine + app.dependency_overrides[require_superuser] = override_require_superuser + + with patch("app.main.check_database_health", return_value=True): + yield TestClient(app) + + # Clean up + app.dependency_overrides.clear() + + +class TestContextHealth: + """Tests for GET /context/health endpoint.""" + + def test_health_check_success(self, client, mock_context_engine, mock_mcp_client): + """Test context engine health check.""" + mock_context_engine.get_stats = AsyncMock( + return_value={ + "cache": {"hits": 10, "misses": 5}, + "settings": {"cache_enabled": True}, + } + ) + + response = client.get("/api/v1/context/health") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["status"] == "healthy" + assert "mcp_connected" in data + assert "cache_enabled" in data + + +class TestAssembleContext: + """Tests for POST /context/assemble endpoint.""" + + def test_assemble_context_success(self, client, mock_context_engine): + """Test successful context assembly.""" + # Create mock assembled context + mock_result = MagicMock(spec=AssembledContext) + mock_result.content = "Assembled context content" + mock_result.total_tokens = 500 + mock_result.context_count = 2 + mock_result.excluded_count = 0 + mock_result.assembly_time_ms = 50.5 + mock_result.metadata = {} + + mock_context_engine.assemble_context = AsyncMock(return_value=mock_result) + mock_context_engine.get_budget_for_model = AsyncMock( + return_value=TokenBudget( + total=4000, + system=500, + knowledge=1500, + conversation=1000, + tools=500, + response_reserve=500, + ) + ) + + response = client.post( + "/api/v1/context/assemble", + json={ + "project_id": "test-project", + "agent_id": "test-agent", + "query": "What is the auth flow?", + "model": "claude-3-sonnet", + "system_prompt": "You are a helpful assistant.", + }, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["content"] == "Assembled context content" + assert data["total_tokens"] == 500 + assert data["context_count"] == 2 + assert data["compressed"] is False + assert "budget_used_percent" in data + + def test_assemble_context_with_conversation(self, client, mock_context_engine): + """Test context assembly with conversation history.""" + mock_result = MagicMock(spec=AssembledContext) + mock_result.content = "Context with history" + mock_result.total_tokens = 800 + mock_result.context_count = 1 + mock_result.excluded_count = 0 + mock_result.assembly_time_ms = 30.0 + mock_result.metadata = {} + + mock_context_engine.assemble_context = AsyncMock(return_value=mock_result) + mock_context_engine.get_budget_for_model = AsyncMock( + return_value=TokenBudget( + total=4000, + system=500, + knowledge=1500, + conversation=1000, + tools=500, + response_reserve=500, + ) + ) + + response = client.post( + "/api/v1/context/assemble", + json={ + "project_id": "test-project", + "agent_id": "test-agent", + "query": "Continue the discussion", + "conversation_history": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ], + }, + ) + + assert response.status_code == status.HTTP_200_OK + call_args = mock_context_engine.assemble_context.call_args + assert call_args.kwargs["conversation_history"] == [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + + def test_assemble_context_with_tool_results(self, client, mock_context_engine): + """Test context assembly with tool results.""" + mock_result = MagicMock(spec=AssembledContext) + mock_result.content = "Context with tools" + mock_result.total_tokens = 600 + mock_result.context_count = 1 + mock_result.excluded_count = 0 + mock_result.assembly_time_ms = 25.0 + mock_result.metadata = {} + + mock_context_engine.assemble_context = AsyncMock(return_value=mock_result) + mock_context_engine.get_budget_for_model = AsyncMock( + return_value=TokenBudget( + total=4000, + system=500, + knowledge=1500, + conversation=1000, + tools=500, + response_reserve=500, + ) + ) + + response = client.post( + "/api/v1/context/assemble", + json={ + "project_id": "test-project", + "agent_id": "test-agent", + "query": "What did the search find?", + "tool_results": [ + { + "tool_name": "search_knowledge", + "content": {"results": ["item1", "item2"]}, + "status": "success", + } + ], + }, + ) + + assert response.status_code == status.HTTP_200_OK + call_args = mock_context_engine.assemble_context.call_args + assert len(call_args.kwargs["tool_results"]) == 1 + + def test_assemble_context_timeout(self, client, mock_context_engine): + """Test context assembly timeout error.""" + mock_context_engine.assemble_context = AsyncMock( + side_effect=AssemblyTimeoutError("Assembly exceeded 5000ms limit") + ) + + response = client.post( + "/api/v1/context/assemble", + json={ + "project_id": "test-project", + "agent_id": "test-agent", + "query": "test", + }, + ) + + assert response.status_code == status.HTTP_504_GATEWAY_TIMEOUT + + def test_assemble_context_budget_exceeded(self, client, mock_context_engine): + """Test context assembly budget exceeded error.""" + mock_context_engine.assemble_context = AsyncMock( + side_effect=BudgetExceededError("Token budget exceeded: 5000 > 4000") + ) + + response = client.post( + "/api/v1/context/assemble", + json={ + "project_id": "test-project", + "agent_id": "test-agent", + "query": "test", + }, + ) + + assert response.status_code == status.HTTP_413_REQUEST_ENTITY_TOO_LARGE + + def test_assemble_context_validation_error(self, client): + """Test context assembly with invalid request.""" + response = client.post( + "/api/v1/context/assemble", + json={}, # Missing required fields + ) + + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + +class TestCountTokens: + """Tests for POST /context/count-tokens endpoint.""" + + def test_count_tokens_success(self, client, mock_context_engine): + """Test successful token counting.""" + mock_context_engine.count_tokens = AsyncMock(return_value=42) + + response = client.post( + "/api/v1/context/count-tokens", + json={ + "content": "This is some test content.", + "model": "claude-3-sonnet", + }, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["token_count"] == 42 + assert data["model"] == "claude-3-sonnet" + + def test_count_tokens_without_model(self, client, mock_context_engine): + """Test token counting without specifying model.""" + mock_context_engine.count_tokens = AsyncMock(return_value=100) + + response = client.post( + "/api/v1/context/count-tokens", + json={"content": "Some content to count."}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["token_count"] == 100 + assert data["model"] is None + + +class TestGetBudget: + """Tests for GET /context/budget/{model} endpoint.""" + + def test_get_budget_success(self, client, mock_context_engine): + """Test getting token budget for a model.""" + mock_context_engine.get_budget_for_model = AsyncMock( + return_value=TokenBudget( + total=100000, + system=10000, + knowledge=40000, + conversation=30000, + tools=10000, + response_reserve=10000, + ) + ) + + response = client.get("/api/v1/context/budget/claude-3-opus") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["model"] == "claude-3-opus" + assert data["total_tokens"] == 100000 + assert data["system_tokens"] == 10000 + assert data["knowledge_tokens"] == 40000 + + def test_get_budget_with_max_tokens(self, client, mock_context_engine): + """Test getting budget with custom max tokens.""" + mock_context_engine.get_budget_for_model = AsyncMock( + return_value=TokenBudget( + total=2000, + system=200, + knowledge=800, + conversation=600, + tools=200, + response_reserve=200, + ) + ) + + response = client.get("/api/v1/context/budget/gpt-4?max_tokens=2000") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["total_tokens"] == 2000 + + +class TestGetStats: + """Tests for GET /context/stats endpoint.""" + + def test_get_stats_success(self, client, mock_context_engine): + """Test getting engine statistics.""" + mock_context_engine.get_stats = AsyncMock( + return_value={ + "cache": { + "hits": 100, + "misses": 25, + "hit_rate": 0.8, + }, + "settings": { + "compression_threshold": 0.9, + "max_assembly_time_ms": 5000, + "cache_enabled": True, + }, + } + ) + + response = client.get("/api/v1/context/stats") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["cache"]["hits"] == 100 + assert data["settings"]["cache_enabled"] is True + + +class TestInvalidateCache: + """Tests for POST /context/cache/invalidate endpoint.""" + + def test_invalidate_cache_by_project(self, client, mock_context_engine): + """Test cache invalidation by project ID.""" + mock_context_engine.invalidate_cache = AsyncMock(return_value=5) + + response = client.post( + "/api/v1/context/cache/invalidate?project_id=test-project" + ) + + assert response.status_code == status.HTTP_204_NO_CONTENT + mock_context_engine.invalidate_cache.assert_called_once() + call_kwargs = mock_context_engine.invalidate_cache.call_args.kwargs + assert call_kwargs["project_id"] == "test-project" + + def test_invalidate_cache_by_pattern(self, client, mock_context_engine): + """Test cache invalidation by pattern.""" + mock_context_engine.invalidate_cache = AsyncMock(return_value=10) + + response = client.post("/api/v1/context/cache/invalidate?pattern=*auth*") + + assert response.status_code == status.HTTP_204_NO_CONTENT + mock_context_engine.invalidate_cache.assert_called_once() + call_kwargs = mock_context_engine.invalidate_cache.call_args.kwargs + assert call_kwargs["pattern"] == "*auth*" + + def test_invalidate_cache_all(self, client, mock_context_engine): + """Test invalidating all cache entries.""" + mock_context_engine.invalidate_cache = AsyncMock(return_value=100) + + response = client.post("/api/v1/context/cache/invalidate") + + assert response.status_code == status.HTTP_204_NO_CONTENT + + +class TestContextEndpointsEdgeCases: + """Edge case tests for Context endpoints.""" + + def test_context_content_type(self, client, mock_context_engine): + """Test that endpoints return JSON content type.""" + mock_context_engine.get_stats = AsyncMock( + return_value={"cache": {}, "settings": {}} + ) + + response = client.get("/api/v1/context/health") + + assert "application/json" in response.headers["content-type"] + + def test_assemble_context_with_knowledge_query(self, client, mock_context_engine): + """Test context assembly with knowledge base query.""" + mock_result = MagicMock(spec=AssembledContext) + mock_result.content = "Context with knowledge" + mock_result.total_tokens = 1000 + mock_result.context_count = 3 + mock_result.excluded_count = 0 + mock_result.assembly_time_ms = 100.0 + mock_result.metadata = { + "compressed_contexts": 1 + } # Indicates compression happened + + mock_context_engine.assemble_context = AsyncMock(return_value=mock_result) + mock_context_engine.get_budget_for_model = AsyncMock( + return_value=TokenBudget( + total=4000, + system=500, + knowledge=1500, + conversation=1000, + tools=500, + response_reserve=500, + ) + ) + + response = client.post( + "/api/v1/context/assemble", + json={ + "project_id": "test-project", + "agent_id": "test-agent", + "query": "How does authentication work?", + "knowledge_query": "authentication flow implementation", + "knowledge_limit": 5, + }, + ) + + assert response.status_code == status.HTTP_200_OK + call_kwargs = mock_context_engine.assemble_context.call_args.kwargs + assert call_kwargs["knowledge_query"] == "authentication flow implementation" + assert call_kwargs["knowledge_limit"] == 5