""" Tests for Context Management API Routes. """ from unittest.mock import AsyncMock, MagicMock, patch import pytest from fastapi import status from fastapi.testclient import TestClient from app.main import app from app.models.user import User from app.services.context import ( AssembledContext, AssemblyTimeoutError, BudgetExceededError, ContextEngine, TokenBudget, ) from app.services.mcp import MCPClientManager @pytest.fixture def mock_mcp_client(): """Create a mock MCP client manager.""" client = MagicMock(spec=MCPClientManager) client.is_initialized = True return client @pytest.fixture def mock_context_engine(mock_mcp_client): """Create a mock ContextEngine.""" engine = MagicMock(spec=ContextEngine) engine._mcp = mock_mcp_client return engine @pytest.fixture def mock_superuser(): """Create a mock superuser.""" user = MagicMock(spec=User) user.id = "00000000-0000-0000-0000-000000000001" user.is_superuser = True user.email = "admin@example.com" return user @pytest.fixture def client(mock_mcp_client, mock_context_engine, mock_superuser): """Create a FastAPI test client with mocked dependencies.""" from app.api.dependencies.permissions import require_superuser from app.api.routes.context import get_context_engine from app.services.mcp import get_mcp_client # Override dependencies async def override_get_mcp_client(): return mock_mcp_client async def override_get_context_engine(): return mock_context_engine async def override_require_superuser(): return mock_superuser app.dependency_overrides[get_mcp_client] = override_get_mcp_client app.dependency_overrides[get_context_engine] = override_get_context_engine app.dependency_overrides[require_superuser] = override_require_superuser with patch("app.main.check_database_health", return_value=True): yield TestClient(app) # Clean up app.dependency_overrides.clear() class TestContextHealth: """Tests for GET /context/health endpoint.""" def test_health_check_success(self, client, mock_context_engine, mock_mcp_client): """Test context engine health check.""" mock_context_engine.get_stats = AsyncMock( return_value={ "cache": {"hits": 10, "misses": 5}, "settings": {"cache_enabled": True}, } ) response = client.get("/api/v1/context/health") assert response.status_code == status.HTTP_200_OK data = response.json() assert data["status"] == "healthy" assert "mcp_connected" in data assert "cache_enabled" in data class TestAssembleContext: """Tests for POST /context/assemble endpoint.""" def test_assemble_context_success(self, client, mock_context_engine): """Test successful context assembly.""" # Create mock assembled context mock_result = MagicMock(spec=AssembledContext) mock_result.content = "Assembled context content" mock_result.total_tokens = 500 mock_result.context_count = 2 mock_result.excluded_count = 0 mock_result.assembly_time_ms = 50.5 mock_result.metadata = {} mock_context_engine.assemble_context = AsyncMock(return_value=mock_result) mock_context_engine.get_budget_for_model = AsyncMock( return_value=TokenBudget( total=4000, system=500, knowledge=1500, conversation=1000, tools=500, response_reserve=500, ) ) response = client.post( "/api/v1/context/assemble", json={ "project_id": "test-project", "agent_id": "test-agent", "query": "What is the auth flow?", "model": "claude-3-sonnet", "system_prompt": "You are a helpful assistant.", }, ) assert response.status_code == status.HTTP_200_OK data = response.json() assert data["content"] == "Assembled context content" assert data["total_tokens"] == 500 assert data["context_count"] == 2 assert data["compressed"] is False assert "budget_used_percent" in data def test_assemble_context_with_conversation(self, client, mock_context_engine): """Test context assembly with conversation history.""" mock_result = MagicMock(spec=AssembledContext) mock_result.content = "Context with history" mock_result.total_tokens = 800 mock_result.context_count = 1 mock_result.excluded_count = 0 mock_result.assembly_time_ms = 30.0 mock_result.metadata = {} mock_context_engine.assemble_context = AsyncMock(return_value=mock_result) mock_context_engine.get_budget_for_model = AsyncMock( return_value=TokenBudget( total=4000, system=500, knowledge=1500, conversation=1000, tools=500, response_reserve=500, ) ) response = client.post( "/api/v1/context/assemble", json={ "project_id": "test-project", "agent_id": "test-agent", "query": "Continue the discussion", "conversation_history": [ {"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi there!"}, ], }, ) assert response.status_code == status.HTTP_200_OK call_args = mock_context_engine.assemble_context.call_args assert call_args.kwargs["conversation_history"] == [ {"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi there!"}, ] def test_assemble_context_with_tool_results(self, client, mock_context_engine): """Test context assembly with tool results.""" mock_result = MagicMock(spec=AssembledContext) mock_result.content = "Context with tools" mock_result.total_tokens = 600 mock_result.context_count = 1 mock_result.excluded_count = 0 mock_result.assembly_time_ms = 25.0 mock_result.metadata = {} mock_context_engine.assemble_context = AsyncMock(return_value=mock_result) mock_context_engine.get_budget_for_model = AsyncMock( return_value=TokenBudget( total=4000, system=500, knowledge=1500, conversation=1000, tools=500, response_reserve=500, ) ) response = client.post( "/api/v1/context/assemble", json={ "project_id": "test-project", "agent_id": "test-agent", "query": "What did the search find?", "tool_results": [ { "tool_name": "search_knowledge", "content": {"results": ["item1", "item2"]}, "status": "success", } ], }, ) assert response.status_code == status.HTTP_200_OK call_args = mock_context_engine.assemble_context.call_args assert len(call_args.kwargs["tool_results"]) == 1 def test_assemble_context_timeout(self, client, mock_context_engine): """Test context assembly timeout error.""" mock_context_engine.assemble_context = AsyncMock( side_effect=AssemblyTimeoutError("Assembly exceeded 5000ms limit") ) response = client.post( "/api/v1/context/assemble", json={ "project_id": "test-project", "agent_id": "test-agent", "query": "test", }, ) assert response.status_code == status.HTTP_504_GATEWAY_TIMEOUT def test_assemble_context_budget_exceeded(self, client, mock_context_engine): """Test context assembly budget exceeded error.""" mock_context_engine.assemble_context = AsyncMock( side_effect=BudgetExceededError("Token budget exceeded: 5000 > 4000") ) response = client.post( "/api/v1/context/assemble", json={ "project_id": "test-project", "agent_id": "test-agent", "query": "test", }, ) assert response.status_code == status.HTTP_413_REQUEST_ENTITY_TOO_LARGE def test_assemble_context_validation_error(self, client): """Test context assembly with invalid request.""" response = client.post( "/api/v1/context/assemble", json={}, # Missing required fields ) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY class TestCountTokens: """Tests for POST /context/count-tokens endpoint.""" def test_count_tokens_success(self, client, mock_context_engine): """Test successful token counting.""" mock_context_engine.count_tokens = AsyncMock(return_value=42) response = client.post( "/api/v1/context/count-tokens", json={ "content": "This is some test content.", "model": "claude-3-sonnet", }, ) assert response.status_code == status.HTTP_200_OK data = response.json() assert data["token_count"] == 42 assert data["model"] == "claude-3-sonnet" def test_count_tokens_without_model(self, client, mock_context_engine): """Test token counting without specifying model.""" mock_context_engine.count_tokens = AsyncMock(return_value=100) response = client.post( "/api/v1/context/count-tokens", json={"content": "Some content to count."}, ) assert response.status_code == status.HTTP_200_OK data = response.json() assert data["token_count"] == 100 assert data["model"] is None class TestGetBudget: """Tests for GET /context/budget/{model} endpoint.""" def test_get_budget_success(self, client, mock_context_engine): """Test getting token budget for a model.""" mock_context_engine.get_budget_for_model = AsyncMock( return_value=TokenBudget( total=100000, system=10000, knowledge=40000, conversation=30000, tools=10000, response_reserve=10000, ) ) response = client.get("/api/v1/context/budget/claude-3-opus") assert response.status_code == status.HTTP_200_OK data = response.json() assert data["model"] == "claude-3-opus" assert data["total_tokens"] == 100000 assert data["system_tokens"] == 10000 assert data["knowledge_tokens"] == 40000 def test_get_budget_with_max_tokens(self, client, mock_context_engine): """Test getting budget with custom max tokens.""" mock_context_engine.get_budget_for_model = AsyncMock( return_value=TokenBudget( total=2000, system=200, knowledge=800, conversation=600, tools=200, response_reserve=200, ) ) response = client.get("/api/v1/context/budget/gpt-4?max_tokens=2000") assert response.status_code == status.HTTP_200_OK data = response.json() assert data["total_tokens"] == 2000 class TestGetStats: """Tests for GET /context/stats endpoint.""" def test_get_stats_success(self, client, mock_context_engine): """Test getting engine statistics.""" mock_context_engine.get_stats = AsyncMock( return_value={ "cache": { "hits": 100, "misses": 25, "hit_rate": 0.8, }, "settings": { "compression_threshold": 0.9, "max_assembly_time_ms": 5000, "cache_enabled": True, }, } ) response = client.get("/api/v1/context/stats") assert response.status_code == status.HTTP_200_OK data = response.json() assert data["cache"]["hits"] == 100 assert data["settings"]["cache_enabled"] is True class TestInvalidateCache: """Tests for POST /context/cache/invalidate endpoint.""" def test_invalidate_cache_by_project(self, client, mock_context_engine): """Test cache invalidation by project ID.""" mock_context_engine.invalidate_cache = AsyncMock(return_value=5) response = client.post( "/api/v1/context/cache/invalidate?project_id=test-project" ) assert response.status_code == status.HTTP_204_NO_CONTENT mock_context_engine.invalidate_cache.assert_called_once() call_kwargs = mock_context_engine.invalidate_cache.call_args.kwargs assert call_kwargs["project_id"] == "test-project" def test_invalidate_cache_by_pattern(self, client, mock_context_engine): """Test cache invalidation by pattern.""" mock_context_engine.invalidate_cache = AsyncMock(return_value=10) response = client.post("/api/v1/context/cache/invalidate?pattern=*auth*") assert response.status_code == status.HTTP_204_NO_CONTENT mock_context_engine.invalidate_cache.assert_called_once() call_kwargs = mock_context_engine.invalidate_cache.call_args.kwargs assert call_kwargs["pattern"] == "*auth*" def test_invalidate_cache_all(self, client, mock_context_engine): """Test invalidating all cache entries.""" mock_context_engine.invalidate_cache = AsyncMock(return_value=100) response = client.post("/api/v1/context/cache/invalidate") assert response.status_code == status.HTTP_204_NO_CONTENT class TestContextEndpointsEdgeCases: """Edge case tests for Context endpoints.""" def test_context_content_type(self, client, mock_context_engine): """Test that endpoints return JSON content type.""" mock_context_engine.get_stats = AsyncMock( return_value={"cache": {}, "settings": {}} ) response = client.get("/api/v1/context/health") assert "application/json" in response.headers["content-type"] def test_assemble_context_with_knowledge_query(self, client, mock_context_engine): """Test context assembly with knowledge base query.""" mock_result = MagicMock(spec=AssembledContext) mock_result.content = "Context with knowledge" mock_result.total_tokens = 1000 mock_result.context_count = 3 mock_result.excluded_count = 0 mock_result.assembly_time_ms = 100.0 mock_result.metadata = { "compressed_contexts": 1 } # Indicates compression happened mock_context_engine.assemble_context = AsyncMock(return_value=mock_result) mock_context_engine.get_budget_for_model = AsyncMock( return_value=TokenBudget( total=4000, system=500, knowledge=1500, conversation=1000, tools=500, response_reserve=500, ) ) response = client.post( "/api/v1/context/assemble", json={ "project_id": "test-project", "agent_id": "test-agent", "query": "How does authentication work?", "knowledge_query": "authentication flow implementation", "knowledge_limit": 5, }, ) assert response.status_code == status.HTTP_200_OK call_kwargs = mock_context_engine.assemble_context.call_args.kwargs assert call_kwargs["knowledge_query"] == "authentication flow implementation" assert call_kwargs["knowledge_limit"] == 5