"""Tests for ContextEngine.""" from unittest.mock import AsyncMock, MagicMock import pytest from app.services.context.config import ContextSettings from app.services.context.engine import ContextEngine, create_context_engine from app.services.context.types import ( AssembledContext, ConversationContext, KnowledgeContext, MessageRole, SystemContext, TaskContext, ToolContext, ) class TestContextEngineCreation: """Tests for ContextEngine creation.""" def test_creation_minimal(self) -> None: """Test creating engine with minimal config.""" engine = ContextEngine() assert engine._mcp is None assert engine._settings is not None assert engine._calculator is not None assert engine._scorer is not None assert engine._ranker is not None assert engine._compressor is not None assert engine._cache is not None assert engine._pipeline is not None def test_creation_with_settings(self) -> None: """Test creating engine with custom settings.""" settings = ContextSettings( compression_threshold=0.7, cache_enabled=False, ) engine = ContextEngine(settings=settings) assert engine._settings.compression_threshold == 0.7 assert engine._settings.cache_enabled is False def test_creation_with_redis(self) -> None: """Test creating engine with Redis.""" mock_redis = MagicMock() settings = ContextSettings(cache_enabled=True) engine = ContextEngine(redis=mock_redis, settings=settings) assert engine._cache.is_enabled def test_set_mcp_manager(self) -> None: """Test setting MCP manager.""" engine = ContextEngine() mock_mcp = MagicMock() engine.set_mcp_manager(mock_mcp) assert engine._mcp is mock_mcp def test_set_redis(self) -> None: """Test setting Redis connection.""" engine = ContextEngine() mock_redis = MagicMock() engine.set_redis(mock_redis) assert engine._cache._redis is mock_redis class TestContextEngineHelpers: """Tests for ContextEngine helper methods.""" def test_convert_conversation(self) -> None: """Test converting conversation history.""" engine = ContextEngine() history = [ {"role": "user", "content": "Hello!"}, {"role": "assistant", "content": "Hi there!"}, {"role": "user", "content": "How are you?"}, ] contexts = engine._convert_conversation(history) assert len(contexts) == 3 assert all(isinstance(c, ConversationContext) for c in contexts) assert contexts[0].role == MessageRole.USER assert contexts[1].role == MessageRole.ASSISTANT assert contexts[0].content == "Hello!" assert contexts[0].metadata["turn"] == 0 def test_convert_tool_results(self) -> None: """Test converting tool results.""" engine = ContextEngine() results = [ {"tool_name": "search", "content": "Result 1", "status": "success"}, {"tool_name": "read", "result": {"file": "test.txt"}, "status": "success"}, ] contexts = engine._convert_tool_results(results) assert len(contexts) == 2 assert all(isinstance(c, ToolContext) for c in contexts) assert contexts[0].content == "Result 1" assert contexts[0].metadata["tool_name"] == "search" # Dict content should be JSON serialized assert "file" in contexts[1].content assert "test.txt" in contexts[1].content class TestContextEngineAssembly: """Tests for context assembly.""" @pytest.mark.asyncio async def test_assemble_minimal(self) -> None: """Test assembling with minimal inputs.""" engine = ContextEngine() result = await engine.assemble_context( project_id="proj-123", agent_id="agent-456", query="test query", model="claude-3-sonnet", use_cache=False, # Disable cache for test ) assert isinstance(result, AssembledContext) assert result.context_count == 0 # No contexts provided @pytest.mark.asyncio async def test_assemble_with_system_prompt(self) -> None: """Test assembling with system prompt.""" engine = ContextEngine() result = await engine.assemble_context( project_id="proj-123", agent_id="agent-456", query="test query", model="claude-3-sonnet", system_prompt="You are a helpful assistant.", use_cache=False, ) assert result.context_count == 1 assert "helpful assistant" in result.content @pytest.mark.asyncio async def test_assemble_with_task(self) -> None: """Test assembling with task description.""" engine = ContextEngine() result = await engine.assemble_context( project_id="proj-123", agent_id="agent-456", query="implement feature", model="claude-3-sonnet", task_description="Implement user authentication", use_cache=False, ) assert result.context_count == 1 assert "authentication" in result.content @pytest.mark.asyncio async def test_assemble_with_conversation(self) -> None: """Test assembling with conversation history.""" engine = ContextEngine() result = await engine.assemble_context( project_id="proj-123", agent_id="agent-456", query="continue", model="claude-3-sonnet", conversation_history=[ {"role": "user", "content": "Hello!"}, {"role": "assistant", "content": "Hi!"}, ], use_cache=False, ) assert result.context_count == 2 assert "Hello" in result.content @pytest.mark.asyncio async def test_assemble_with_tool_results(self) -> None: """Test assembling with tool results.""" engine = ContextEngine() result = await engine.assemble_context( project_id="proj-123", agent_id="agent-456", query="continue", model="claude-3-sonnet", tool_results=[ {"tool_name": "search", "content": "Found 5 results"}, ], use_cache=False, ) assert result.context_count == 1 assert "Found 5 results" in result.content @pytest.mark.asyncio async def test_assemble_with_custom_contexts(self) -> None: """Test assembling with custom contexts.""" engine = ContextEngine() custom = [ KnowledgeContext( content="Custom knowledge.", source="custom", relevance_score=0.9, ) ] result = await engine.assemble_context( project_id="proj-123", agent_id="agent-456", query="test", model="claude-3-sonnet", custom_contexts=custom, use_cache=False, ) assert result.context_count == 1 assert "Custom knowledge" in result.content @pytest.mark.asyncio async def test_assemble_full_workflow(self) -> None: """Test full assembly workflow.""" engine = ContextEngine() result = await engine.assemble_context( project_id="proj-123", agent_id="agent-456", query="implement login", model="claude-3-sonnet", system_prompt="You are an expert Python developer.", task_description="Implement user authentication.", conversation_history=[ {"role": "user", "content": "Can you help me implement JWT auth?"}, ], tool_results=[ {"tool_name": "file_create", "content": "Created auth.py"}, ], use_cache=False, ) assert result.context_count >= 4 assert result.total_tokens > 0 assert result.model == "claude-3-sonnet" # Check for expected content assert "expert Python developer" in result.content assert "authentication" in result.content class TestContextEngineKnowledge: """Tests for knowledge fetching.""" @pytest.mark.asyncio async def test_fetch_knowledge_no_mcp(self) -> None: """Test fetching knowledge without MCP returns empty.""" engine = ContextEngine() result = await engine._fetch_knowledge( project_id="proj-123", agent_id="agent-456", query="test", ) assert result == [] @pytest.mark.asyncio async def test_fetch_knowledge_with_mcp(self) -> None: """Test fetching knowledge with MCP.""" mock_mcp = AsyncMock() mock_mcp.call_tool.return_value.data = { "results": [ { "content": "Document content", "source_path": "docs/api.md", "score": 0.9, "chunk_id": "chunk-1", }, { "content": "Another document", "source_path": "docs/auth.md", "score": 0.8, }, ] } engine = ContextEngine(mcp_manager=mock_mcp) result = await engine._fetch_knowledge( project_id="proj-123", agent_id="agent-456", query="authentication", ) assert len(result) == 2 assert all(isinstance(c, KnowledgeContext) for c in result) assert result[0].content == "Document content" assert result[0].source == "docs/api.md" assert result[0].relevance_score == 0.9 @pytest.mark.asyncio async def test_fetch_knowledge_error_handling(self) -> None: """Test knowledge fetch error handling.""" mock_mcp = AsyncMock() mock_mcp.call_tool.side_effect = Exception("MCP error") engine = ContextEngine(mcp_manager=mock_mcp) # Should not raise, returns empty result = await engine._fetch_knowledge( project_id="proj-123", agent_id="agent-456", query="test", ) assert result == [] class TestContextEngineCaching: """Tests for caching behavior.""" @pytest.mark.asyncio async def test_cache_disabled(self) -> None: """Test assembly with cache disabled.""" engine = ContextEngine() result = await engine.assemble_context( project_id="proj-123", agent_id="agent-456", query="test", model="claude-3-sonnet", system_prompt="Test prompt", use_cache=False, ) assert not result.cache_hit @pytest.mark.asyncio async def test_cache_hit(self) -> None: """Test cache hit.""" mock_redis = AsyncMock() settings = ContextSettings(cache_enabled=True) engine = ContextEngine(redis=mock_redis, settings=settings) # First call - cache miss mock_redis.get.return_value = None result1 = await engine.assemble_context( project_id="proj-123", agent_id="agent-456", query="test", model="claude-3-sonnet", system_prompt="Test prompt", ) # Second call - mock cache hit mock_redis.get.return_value = result1.to_json() result2 = await engine.assemble_context( project_id="proj-123", agent_id="agent-456", query="test", model="claude-3-sonnet", system_prompt="Test prompt", ) assert result2.cache_hit class TestContextEngineUtilities: """Tests for utility methods.""" @pytest.mark.asyncio async def test_get_budget_for_model(self) -> None: """Test getting budget for model.""" engine = ContextEngine() budget = await engine.get_budget_for_model("claude-3-sonnet") assert budget.total > 0 assert budget.system > 0 assert budget.knowledge > 0 @pytest.mark.asyncio async def test_get_budget_with_max_tokens(self) -> None: """Test getting budget with max tokens.""" engine = ContextEngine() budget = await engine.get_budget_for_model("claude-3-sonnet", max_tokens=5000) assert budget.total == 5000 @pytest.mark.asyncio async def test_count_tokens(self) -> None: """Test token counting.""" engine = ContextEngine() count = await engine.count_tokens("Hello world") assert count > 0 @pytest.mark.asyncio async def test_invalidate_cache(self) -> None: """Test cache invalidation.""" mock_redis = AsyncMock() async def mock_scan_iter(match=None): for key in ["ctx:1", "ctx:2"]: yield key mock_redis.scan_iter = mock_scan_iter settings = ContextSettings(cache_enabled=True) engine = ContextEngine(redis=mock_redis, settings=settings) deleted = await engine.invalidate_cache(pattern="*test*") assert deleted >= 0 @pytest.mark.asyncio async def test_get_stats(self) -> None: """Test getting engine stats.""" engine = ContextEngine() stats = await engine.get_stats() assert "cache" in stats assert "settings" in stats assert "compression_threshold" in stats["settings"] class TestCreateContextEngine: """Tests for factory function.""" def test_create_context_engine(self) -> None: """Test factory function.""" engine = create_context_engine() assert isinstance(engine, ContextEngine) def test_create_context_engine_with_settings(self) -> None: """Test factory with settings.""" settings = ContextSettings(cache_enabled=False) engine = create_context_engine(settings=settings) assert engine._settings.cache_enabled is False