diff --git a/backend/app/services/context/__init__.py b/backend/app/services/context/__init__.py index de5cf51..4d84ef8 100644 --- a/backend/app/services/context/__init__.py +++ b/backend/app/services/context/__init__.py @@ -88,6 +88,9 @@ from .adapters import ( # Cache from .cache import ContextCache +# Engine +from .engine import ContextEngine, create_context_engine + # Prioritization from .prioritization import ( ContextRanker, @@ -137,6 +140,9 @@ __all__ = [ "TokenCalculator", # Cache "ContextCache", + # Engine + "ContextEngine", + "create_context_engine", # Compression "ContextCompressor", "TruncationResult", diff --git a/backend/app/services/context/engine.py b/backend/app/services/context/engine.py new file mode 100644 index 0000000..a70566d --- /dev/null +++ b/backend/app/services/context/engine.py @@ -0,0 +1,470 @@ +""" +Context Management Engine. + +Main orchestration layer for context assembly and optimization. +Provides a high-level API for assembling optimized context for LLM requests. +""" + +import logging +from typing import TYPE_CHECKING, Any + +from .adapters import get_adapter +from .assembly import ContextPipeline +from .budget import BudgetAllocator, TokenBudget, TokenCalculator +from .cache import ContextCache +from .compression import ContextCompressor +from .config import ContextSettings, get_context_settings +from .prioritization import ContextRanker +from .scoring import CompositeScorer +from .types import ( + AssembledContext, + BaseContext, + ConversationContext, + KnowledgeContext, + MessageRole, + SystemContext, + TaskContext, + ToolContext, +) + +if TYPE_CHECKING: + from redis.asyncio import Redis + + from app.services.mcp.client_manager import MCPClientManager + +logger = logging.getLogger(__name__) + + +class ContextEngine: + """ + Main context management engine. + + Provides high-level API for context assembly and optimization. + Integrates all components: scoring, ranking, compression, formatting, and caching. + + Usage: + engine = ContextEngine(mcp_manager=mcp, redis=redis) + + # Assemble context for an LLM request + result = await engine.assemble_context( + project_id="proj-123", + agent_id="agent-456", + query="implement user authentication", + model="claude-3-sonnet", + system_prompt="You are an expert developer.", + knowledge_query="authentication best practices", + ) + + # Use the assembled context + print(result.content) + print(f"Tokens: {result.total_tokens}") + """ + + def __init__( + self, + mcp_manager: "MCPClientManager | None" = None, + redis: "Redis | None" = None, + settings: ContextSettings | None = None, + ) -> None: + """ + Initialize the context engine. + + Args: + mcp_manager: MCP client manager for LLM Gateway/Knowledge Base + redis: Redis connection for caching + settings: Context settings + """ + self._mcp = mcp_manager + self._settings = settings or get_context_settings() + + # Initialize components + self._calculator = TokenCalculator(mcp_manager=mcp_manager) + self._scorer = CompositeScorer( + mcp_manager=mcp_manager, settings=self._settings + ) + self._ranker = ContextRanker( + scorer=self._scorer, calculator=self._calculator + ) + self._compressor = ContextCompressor(calculator=self._calculator) + self._allocator = BudgetAllocator(self._settings) + self._cache = ContextCache(redis=redis, settings=self._settings) + + # Pipeline for assembly + self._pipeline = ContextPipeline( + mcp_manager=mcp_manager, + settings=self._settings, + calculator=self._calculator, + scorer=self._scorer, + ranker=self._ranker, + compressor=self._compressor, + ) + + def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None: + """ + Set MCP manager for all components. + + Args: + mcp_manager: MCP client manager + """ + self._mcp = mcp_manager + self._calculator.set_mcp_manager(mcp_manager) + self._scorer.set_mcp_manager(mcp_manager) + self._pipeline.set_mcp_manager(mcp_manager) + + def set_redis(self, redis: "Redis") -> None: + """ + Set Redis connection for caching. + + Args: + redis: Redis connection + """ + self._cache.set_redis(redis) + + async def assemble_context( + self, + project_id: str, + agent_id: str, + query: str, + model: str, + max_tokens: int | None = None, + system_prompt: str | None = None, + task_description: str | None = None, + knowledge_query: str | None = None, + knowledge_limit: int = 10, + conversation_history: list[dict[str, str]] | None = None, + tool_results: list[dict[str, Any]] | None = None, + custom_contexts: list[BaseContext] | None = None, + custom_budget: TokenBudget | None = None, + compress: bool = True, + format_output: bool = True, + use_cache: bool = True, + ) -> AssembledContext: + """ + Assemble optimized context for an LLM request. + + This is the main entry point for context management. + It gathers context from various sources, scores and ranks them, + compresses if needed, and formats for the target model. + + Args: + project_id: Project identifier + agent_id: Agent identifier + query: User's query or current request + model: Target model name + max_tokens: Maximum context tokens (uses model default if None) + system_prompt: System prompt/instructions + task_description: Current task description + knowledge_query: Query for knowledge base search + knowledge_limit: Max number of knowledge results + conversation_history: List of {"role": str, "content": str} + tool_results: List of tool results to include + custom_contexts: Additional custom contexts + custom_budget: Custom token budget + compress: Whether to apply compression + format_output: Whether to format for the model + use_cache: Whether to use caching + + Returns: + AssembledContext with optimized content + + Raises: + AssemblyTimeoutError: If assembly exceeds timeout + BudgetExceededError: If context exceeds budget + """ + # Gather all contexts + contexts: list[BaseContext] = [] + + # 1. System context + if system_prompt: + contexts.append( + SystemContext( + content=system_prompt, + source="system_prompt", + ) + ) + + # 2. Task context + if task_description: + contexts.append( + TaskContext( + content=task_description, + source=f"task:{project_id}:{agent_id}", + ) + ) + + # 3. Knowledge context from Knowledge Base + if knowledge_query and self._mcp: + knowledge_contexts = await self._fetch_knowledge( + project_id=project_id, + agent_id=agent_id, + query=knowledge_query, + limit=knowledge_limit, + ) + contexts.extend(knowledge_contexts) + + # 4. Conversation history + if conversation_history: + contexts.extend(self._convert_conversation(conversation_history)) + + # 5. Tool results + if tool_results: + contexts.extend(self._convert_tool_results(tool_results)) + + # 6. Custom contexts + if custom_contexts: + contexts.extend(custom_contexts) + + # Check cache if enabled + if use_cache and self._cache.is_enabled: + fingerprint = self._cache.compute_fingerprint(contexts, query, model) + cached = await self._cache.get_assembled(fingerprint) + if cached: + logger.debug(f"Cache hit for context assembly: {fingerprint}") + return cached + + # Run assembly pipeline + result = await self._pipeline.assemble( + contexts=contexts, + query=query, + model=model, + max_tokens=max_tokens, + custom_budget=custom_budget, + compress=compress, + format_output=format_output, + ) + + # Cache result if enabled + if use_cache and self._cache.is_enabled: + fingerprint = self._cache.compute_fingerprint(contexts, query, model) + await self._cache.set_assembled(fingerprint, result) + + return result + + async def _fetch_knowledge( + self, + project_id: str, + agent_id: str, + query: str, + limit: int = 10, + ) -> list[KnowledgeContext]: + """ + Fetch relevant knowledge from Knowledge Base via MCP. + + Args: + project_id: Project identifier + agent_id: Agent identifier + query: Search query + limit: Maximum results + + Returns: + List of KnowledgeContext instances + """ + if not self._mcp: + return [] + + try: + result = await self._mcp.call_tool( + "knowledge-base", + "search_knowledge", + { + "project_id": project_id, + "agent_id": agent_id, + "query": query, + "search_type": "hybrid", + "limit": limit, + }, + ) + + contexts = [] + for chunk in result.data.get("results", []): + contexts.append( + KnowledgeContext( + content=chunk.get("content", ""), + source=chunk.get("source_path", "unknown"), + relevance_score=chunk.get("score", 0.0), + metadata={ + "chunk_id": chunk.get("chunk_id"), + "document_id": chunk.get("document_id"), + }, + ) + ) + + logger.debug(f"Fetched {len(contexts)} knowledge chunks for query: {query}") + return contexts + + except Exception as e: + logger.warning(f"Failed to fetch knowledge: {e}") + return [] + + def _convert_conversation( + self, + history: list[dict[str, str]], + ) -> list[ConversationContext]: + """ + Convert conversation history to ConversationContext instances. + + Args: + history: List of {"role": str, "content": str} + + Returns: + List of ConversationContext instances + """ + contexts = [] + for i, turn in enumerate(history): + role_str = turn.get("role", "user").lower() + role = MessageRole.ASSISTANT if role_str == "assistant" else MessageRole.USER + + contexts.append( + ConversationContext( + content=turn.get("content", ""), + source=f"conversation:{i}", + role=role, + metadata={"role": role_str, "turn": i}, + ) + ) + + return contexts + + def _convert_tool_results( + self, + results: list[dict[str, Any]], + ) -> list[ToolContext]: + """ + Convert tool results to ToolContext instances. + + Args: + results: List of tool result dictionaries + + Returns: + List of ToolContext instances + """ + contexts = [] + for result in results: + tool_name = result.get("tool_name", "unknown") + content = result.get("content", result.get("result", "")) + + # Handle dict content + if isinstance(content, dict): + import json + content = json.dumps(content, indent=2) + + contexts.append( + ToolContext( + content=str(content), + source=f"tool:{tool_name}", + metadata={ + "tool_name": tool_name, + "status": result.get("status", "success"), + }, + ) + ) + + return contexts + + async def get_budget_for_model( + self, + model: str, + max_tokens: int | None = None, + ) -> TokenBudget: + """ + Get the token budget for a specific model. + + Args: + model: Model name + max_tokens: Optional max tokens override + + Returns: + TokenBudget instance + """ + if max_tokens: + return self._allocator.create_budget(max_tokens) + return self._allocator.create_budget_for_model(model) + + async def count_tokens( + self, + content: str, + model: str | None = None, + ) -> int: + """ + Count tokens in content. + + Args: + content: Content to count + model: Model for model-specific tokenization + + Returns: + Token count + """ + # Check cache first + cached = await self._cache.get_token_count(content, model) + if cached is not None: + return cached + + count = await self._calculator.count_tokens(content, model) + + # Cache the result + await self._cache.set_token_count(content, count, model) + + return count + + async def invalidate_cache( + self, + project_id: str | None = None, + pattern: str | None = None, + ) -> int: + """ + Invalidate cache entries. + + Args: + project_id: Invalidate all cache for a project + pattern: Custom pattern to match + + Returns: + Number of entries invalidated + """ + if pattern: + return await self._cache.invalidate(pattern) + elif project_id: + return await self._cache.invalidate(f"*{project_id}*") + else: + return await self._cache.clear_all() + + async def get_stats(self) -> dict[str, Any]: + """ + Get engine statistics. + + Returns: + Dictionary with engine stats + """ + return { + "cache": await self._cache.get_stats(), + "settings": { + "compression_threshold": self._settings.compression_threshold, + "max_assembly_time_ms": self._settings.max_assembly_time_ms, + "cache_enabled": self._settings.cache_enabled, + }, + } + + +# Convenience factory function +def create_context_engine( + mcp_manager: "MCPClientManager | None" = None, + redis: "Redis | None" = None, + settings: ContextSettings | None = None, +) -> ContextEngine: + """ + Create a context engine instance. + + Args: + mcp_manager: MCP client manager + redis: Redis connection + settings: Context settings + + Returns: + Configured ContextEngine instance + """ + return ContextEngine( + mcp_manager=mcp_manager, + redis=redis, + settings=settings, + ) diff --git a/backend/tests/services/context/test_engine.py b/backend/tests/services/context/test_engine.py new file mode 100644 index 0000000..87202f7 --- /dev/null +++ b/backend/tests/services/context/test_engine.py @@ -0,0 +1,458 @@ +"""Tests for ContextEngine.""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from app.services.context.config import ContextSettings +from app.services.context.engine import ContextEngine, create_context_engine +from app.services.context.types import ( + AssembledContext, + ConversationContext, + KnowledgeContext, + MessageRole, + SystemContext, + TaskContext, + ToolContext, +) + + +class TestContextEngineCreation: + """Tests for ContextEngine creation.""" + + def test_creation_minimal(self) -> None: + """Test creating engine with minimal config.""" + engine = ContextEngine() + + assert engine._mcp is None + assert engine._settings is not None + assert engine._calculator is not None + assert engine._scorer is not None + assert engine._ranker is not None + assert engine._compressor is not None + assert engine._cache is not None + assert engine._pipeline is not None + + def test_creation_with_settings(self) -> None: + """Test creating engine with custom settings.""" + settings = ContextSettings( + compression_threshold=0.7, + cache_enabled=False, + ) + engine = ContextEngine(settings=settings) + + assert engine._settings.compression_threshold == 0.7 + assert engine._settings.cache_enabled is False + + def test_creation_with_redis(self) -> None: + """Test creating engine with Redis.""" + mock_redis = MagicMock() + settings = ContextSettings(cache_enabled=True) + engine = ContextEngine(redis=mock_redis, settings=settings) + + assert engine._cache.is_enabled + + def test_set_mcp_manager(self) -> None: + """Test setting MCP manager.""" + engine = ContextEngine() + mock_mcp = MagicMock() + + engine.set_mcp_manager(mock_mcp) + + assert engine._mcp is mock_mcp + + def test_set_redis(self) -> None: + """Test setting Redis connection.""" + engine = ContextEngine() + mock_redis = MagicMock() + + engine.set_redis(mock_redis) + + assert engine._cache._redis is mock_redis + + +class TestContextEngineHelpers: + """Tests for ContextEngine helper methods.""" + + def test_convert_conversation(self) -> None: + """Test converting conversation history.""" + engine = ContextEngine() + + history = [ + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"}, + ] + + contexts = engine._convert_conversation(history) + + assert len(contexts) == 3 + assert all(isinstance(c, ConversationContext) for c in contexts) + assert contexts[0].role == MessageRole.USER + assert contexts[1].role == MessageRole.ASSISTANT + assert contexts[0].content == "Hello!" + assert contexts[0].metadata["turn"] == 0 + + def test_convert_tool_results(self) -> None: + """Test converting tool results.""" + engine = ContextEngine() + + results = [ + {"tool_name": "search", "content": "Result 1", "status": "success"}, + {"tool_name": "read", "result": {"file": "test.txt"}, "status": "success"}, + ] + + contexts = engine._convert_tool_results(results) + + assert len(contexts) == 2 + assert all(isinstance(c, ToolContext) for c in contexts) + assert contexts[0].content == "Result 1" + assert contexts[0].metadata["tool_name"] == "search" + # Dict content should be JSON serialized + assert "file" in contexts[1].content + assert "test.txt" in contexts[1].content + + +class TestContextEngineAssembly: + """Tests for context assembly.""" + + @pytest.mark.asyncio + async def test_assemble_minimal(self) -> None: + """Test assembling with minimal inputs.""" + engine = ContextEngine() + + result = await engine.assemble_context( + project_id="proj-123", + agent_id="agent-456", + query="test query", + model="claude-3-sonnet", + use_cache=False, # Disable cache for test + ) + + assert isinstance(result, AssembledContext) + assert result.context_count == 0 # No contexts provided + + @pytest.mark.asyncio + async def test_assemble_with_system_prompt(self) -> None: + """Test assembling with system prompt.""" + engine = ContextEngine() + + result = await engine.assemble_context( + project_id="proj-123", + agent_id="agent-456", + query="test query", + model="claude-3-sonnet", + system_prompt="You are a helpful assistant.", + use_cache=False, + ) + + assert result.context_count == 1 + assert "helpful assistant" in result.content + + @pytest.mark.asyncio + async def test_assemble_with_task(self) -> None: + """Test assembling with task description.""" + engine = ContextEngine() + + result = await engine.assemble_context( + project_id="proj-123", + agent_id="agent-456", + query="implement feature", + model="claude-3-sonnet", + task_description="Implement user authentication", + use_cache=False, + ) + + assert result.context_count == 1 + assert "authentication" in result.content + + @pytest.mark.asyncio + async def test_assemble_with_conversation(self) -> None: + """Test assembling with conversation history.""" + engine = ContextEngine() + + result = await engine.assemble_context( + project_id="proj-123", + agent_id="agent-456", + query="continue", + model="claude-3-sonnet", + conversation_history=[ + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Hi!"}, + ], + use_cache=False, + ) + + assert result.context_count == 2 + assert "Hello" in result.content + + @pytest.mark.asyncio + async def test_assemble_with_tool_results(self) -> None: + """Test assembling with tool results.""" + engine = ContextEngine() + + result = await engine.assemble_context( + project_id="proj-123", + agent_id="agent-456", + query="continue", + model="claude-3-sonnet", + tool_results=[ + {"tool_name": "search", "content": "Found 5 results"}, + ], + use_cache=False, + ) + + assert result.context_count == 1 + assert "Found 5 results" in result.content + + @pytest.mark.asyncio + async def test_assemble_with_custom_contexts(self) -> None: + """Test assembling with custom contexts.""" + engine = ContextEngine() + + custom = [ + KnowledgeContext( + content="Custom knowledge.", + source="custom", + relevance_score=0.9, + ) + ] + + result = await engine.assemble_context( + project_id="proj-123", + agent_id="agent-456", + query="test", + model="claude-3-sonnet", + custom_contexts=custom, + use_cache=False, + ) + + assert result.context_count == 1 + assert "Custom knowledge" in result.content + + @pytest.mark.asyncio + async def test_assemble_full_workflow(self) -> None: + """Test full assembly workflow.""" + engine = ContextEngine() + + result = await engine.assemble_context( + project_id="proj-123", + agent_id="agent-456", + query="implement login", + model="claude-3-sonnet", + system_prompt="You are an expert Python developer.", + task_description="Implement user authentication.", + conversation_history=[ + {"role": "user", "content": "Can you help me implement JWT auth?"}, + ], + tool_results=[ + {"tool_name": "file_create", "content": "Created auth.py"}, + ], + use_cache=False, + ) + + assert result.context_count >= 4 + assert result.total_tokens > 0 + assert result.model == "claude-3-sonnet" + + # Check for expected content + assert "expert Python developer" in result.content + assert "authentication" in result.content + + +class TestContextEngineKnowledge: + """Tests for knowledge fetching.""" + + @pytest.mark.asyncio + async def test_fetch_knowledge_no_mcp(self) -> None: + """Test fetching knowledge without MCP returns empty.""" + engine = ContextEngine() + + result = await engine._fetch_knowledge( + project_id="proj-123", + agent_id="agent-456", + query="test", + ) + + assert result == [] + + @pytest.mark.asyncio + async def test_fetch_knowledge_with_mcp(self) -> None: + """Test fetching knowledge with MCP.""" + mock_mcp = AsyncMock() + mock_mcp.call_tool.return_value.data = { + "results": [ + { + "content": "Document content", + "source_path": "docs/api.md", + "score": 0.9, + "chunk_id": "chunk-1", + }, + { + "content": "Another document", + "source_path": "docs/auth.md", + "score": 0.8, + }, + ] + } + + engine = ContextEngine(mcp_manager=mock_mcp) + + result = await engine._fetch_knowledge( + project_id="proj-123", + agent_id="agent-456", + query="authentication", + ) + + assert len(result) == 2 + assert all(isinstance(c, KnowledgeContext) for c in result) + assert result[0].content == "Document content" + assert result[0].source == "docs/api.md" + assert result[0].relevance_score == 0.9 + + @pytest.mark.asyncio + async def test_fetch_knowledge_error_handling(self) -> None: + """Test knowledge fetch error handling.""" + mock_mcp = AsyncMock() + mock_mcp.call_tool.side_effect = Exception("MCP error") + + engine = ContextEngine(mcp_manager=mock_mcp) + + # Should not raise, returns empty + result = await engine._fetch_knowledge( + project_id="proj-123", + agent_id="agent-456", + query="test", + ) + + assert result == [] + + +class TestContextEngineCaching: + """Tests for caching behavior.""" + + @pytest.mark.asyncio + async def test_cache_disabled(self) -> None: + """Test assembly with cache disabled.""" + engine = ContextEngine() + + result = await engine.assemble_context( + project_id="proj-123", + agent_id="agent-456", + query="test", + model="claude-3-sonnet", + system_prompt="Test prompt", + use_cache=False, + ) + + assert not result.cache_hit + + @pytest.mark.asyncio + async def test_cache_hit(self) -> None: + """Test cache hit.""" + mock_redis = AsyncMock() + settings = ContextSettings(cache_enabled=True) + engine = ContextEngine(redis=mock_redis, settings=settings) + + # First call - cache miss + mock_redis.get.return_value = None + + result1 = await engine.assemble_context( + project_id="proj-123", + agent_id="agent-456", + query="test", + model="claude-3-sonnet", + system_prompt="Test prompt", + ) + + # Second call - mock cache hit + mock_redis.get.return_value = result1.to_json() + + result2 = await engine.assemble_context( + project_id="proj-123", + agent_id="agent-456", + query="test", + model="claude-3-sonnet", + system_prompt="Test prompt", + ) + + assert result2.cache_hit + + +class TestContextEngineUtilities: + """Tests for utility methods.""" + + @pytest.mark.asyncio + async def test_get_budget_for_model(self) -> None: + """Test getting budget for model.""" + engine = ContextEngine() + + budget = await engine.get_budget_for_model("claude-3-sonnet") + + assert budget.total > 0 + assert budget.system > 0 + assert budget.knowledge > 0 + + @pytest.mark.asyncio + async def test_get_budget_with_max_tokens(self) -> None: + """Test getting budget with max tokens.""" + engine = ContextEngine() + + budget = await engine.get_budget_for_model("claude-3-sonnet", max_tokens=5000) + + assert budget.total == 5000 + + @pytest.mark.asyncio + async def test_count_tokens(self) -> None: + """Test token counting.""" + engine = ContextEngine() + + count = await engine.count_tokens("Hello world") + + assert count > 0 + + @pytest.mark.asyncio + async def test_invalidate_cache(self) -> None: + """Test cache invalidation.""" + mock_redis = AsyncMock() + + async def mock_scan_iter(match=None): + for key in ["ctx:1", "ctx:2"]: + yield key + + mock_redis.scan_iter = mock_scan_iter + + settings = ContextSettings(cache_enabled=True) + engine = ContextEngine(redis=mock_redis, settings=settings) + + deleted = await engine.invalidate_cache(pattern="*test*") + + assert deleted >= 0 + + @pytest.mark.asyncio + async def test_get_stats(self) -> None: + """Test getting engine stats.""" + engine = ContextEngine() + + stats = await engine.get_stats() + + assert "cache" in stats + assert "settings" in stats + assert "compression_threshold" in stats["settings"] + + +class TestCreateContextEngine: + """Tests for factory function.""" + + def test_create_context_engine(self) -> None: + """Test factory function.""" + engine = create_context_engine() + + assert isinstance(engine, ContextEngine) + + def test_create_context_engine_with_settings(self) -> None: + """Test factory with settings.""" + settings = ContextSettings(cache_enabled=False) + engine = create_context_engine(settings=settings) + + assert engine._settings.cache_enabled is False