""" Context Management Engine. Main orchestration layer for context assembly and optimization. Provides a high-level API for assembling optimized context for LLM requests. """ import logging from typing import TYPE_CHECKING, Any from uuid import UUID from .assembly import ContextPipeline from .budget import BudgetAllocator, TokenBudget, TokenCalculator from .cache import ContextCache from .compression import ContextCompressor from .config import ContextSettings, get_context_settings from .prioritization import ContextRanker from .scoring import CompositeScorer from .types import ( AssembledContext, BaseContext, ConversationContext, KnowledgeContext, MemoryContext, MessageRole, SystemContext, TaskContext, ToolContext, ) if TYPE_CHECKING: from redis.asyncio import Redis from app.services.mcp.client_manager import MCPClientManager from app.services.memory.integration import MemoryContextSource logger = logging.getLogger(__name__) class ContextEngine: """ Main context management engine. Provides high-level API for context assembly and optimization. Integrates all components: scoring, ranking, compression, formatting, and caching. Usage: engine = ContextEngine(mcp_manager=mcp, redis=redis) # Assemble context for an LLM request result = await engine.assemble_context( project_id="proj-123", agent_id="agent-456", query="implement user authentication", model="claude-3-sonnet", system_prompt="You are an expert developer.", knowledge_query="authentication best practices", ) # Use the assembled context print(result.content) print(f"Tokens: {result.total_tokens}") """ def __init__( self, mcp_manager: "MCPClientManager | None" = None, redis: "Redis | None" = None, settings: ContextSettings | None = None, memory_source: "MemoryContextSource | None" = None, ) -> None: """ Initialize the context engine. Args: mcp_manager: MCP client manager for LLM Gateway/Knowledge Base redis: Redis connection for caching settings: Context settings memory_source: Optional memory context source for agent memory """ self._mcp = mcp_manager self._settings = settings or get_context_settings() self._memory_source = memory_source # Initialize components self._calculator = TokenCalculator(mcp_manager=mcp_manager) self._scorer = CompositeScorer(mcp_manager=mcp_manager, settings=self._settings) self._ranker = ContextRanker(scorer=self._scorer, calculator=self._calculator) self._compressor = ContextCompressor(calculator=self._calculator) self._allocator = BudgetAllocator(self._settings) self._cache = ContextCache(redis=redis, settings=self._settings) # Pipeline for assembly self._pipeline = ContextPipeline( mcp_manager=mcp_manager, settings=self._settings, calculator=self._calculator, scorer=self._scorer, ranker=self._ranker, compressor=self._compressor, ) def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None: """ Set MCP manager for all components. Args: mcp_manager: MCP client manager """ self._mcp = mcp_manager self._calculator.set_mcp_manager(mcp_manager) self._scorer.set_mcp_manager(mcp_manager) self._pipeline.set_mcp_manager(mcp_manager) def set_redis(self, redis: "Redis") -> None: """ Set Redis connection for caching. Args: redis: Redis connection """ self._cache.set_redis(redis) def set_memory_source(self, memory_source: "MemoryContextSource") -> None: """ Set memory context source for agent memory integration. 
    async def assemble_context(
        self,
        project_id: str,
        agent_id: str,
        query: str,
        model: str,
        max_tokens: int | None = None,
        system_prompt: str | None = None,
        task_description: str | None = None,
        knowledge_query: str | None = None,
        knowledge_limit: int = 10,
        memory_query: str | None = None,
        memory_limit: int = 20,
        session_id: str | None = None,
        agent_type_id: str | None = None,
        conversation_history: list[dict[str, str]] | None = None,
        tool_results: list[dict[str, Any]] | None = None,
        custom_contexts: list[BaseContext] | None = None,
        custom_budget: TokenBudget | None = None,
        compress: bool = True,
        format_output: bool = True,
        use_cache: bool = True,
    ) -> AssembledContext:
        """
        Assemble optimized context for an LLM request.

        This is the main entry point for context management. It gathers
        context from various sources, scores and ranks it, compresses it
        if needed, and formats it for the target model.

        Args:
            project_id: Project identifier
            agent_id: Agent identifier
            query: User's query or current request
            model: Target model name
            max_tokens: Maximum context tokens (uses model default if None)
            system_prompt: System prompt/instructions
            task_description: Current task description
            knowledge_query: Query for knowledge base search
            knowledge_limit: Max number of knowledge results
            memory_query: Query for agent memory search
            memory_limit: Max number of memory results
            session_id: Session ID for working memory access
            agent_type_id: Agent type ID for procedural memory
            conversation_history: List of {"role": str, "content": str}
            tool_results: List of tool results to include
            custom_contexts: Additional custom contexts
            custom_budget: Custom token budget
            compress: Whether to apply compression
            format_output: Whether to format for the model
            use_cache: Whether to use caching

        Returns:
            AssembledContext with optimized content

        Raises:
            AssemblyTimeoutError: If assembly exceeds the timeout
            BudgetExceededError: If the context exceeds the budget
        """
        # Gather all contexts
        contexts: list[BaseContext] = []

        # 1. System context
        if system_prompt:
            contexts.append(
                SystemContext(
                    content=system_prompt,
                    source="system_prompt",
                )
            )

        # 2. Task context
        if task_description:
            contexts.append(
                TaskContext(
                    content=task_description,
                    source=f"task:{project_id}:{agent_id}",
                )
            )

        # 3. Knowledge context from Knowledge Base
        if knowledge_query and self._mcp:
            knowledge_contexts = await self._fetch_knowledge(
                project_id=project_id,
                agent_id=agent_id,
                query=knowledge_query,
                limit=knowledge_limit,
            )
            contexts.extend(knowledge_contexts)

        # 4. Memory context from Agent Memory System
        if memory_query and self._memory_source:
            memory_contexts = await self._fetch_memory(
                project_id=project_id,
                agent_id=agent_id,
                query=memory_query,
                limit=memory_limit,
                session_id=session_id,
                agent_type_id=agent_type_id,
            )
            contexts.extend(memory_contexts)

        # 5. Conversation history
        if conversation_history:
            contexts.extend(self._convert_conversation(conversation_history))

        # 6. Tool results
        if tool_results:
            contexts.extend(self._convert_tool_results(tool_results))

        # 7. Custom contexts
        if custom_contexts:
            contexts.extend(custom_contexts)

        # Check cache if enabled
        fingerprint: str | None = None
        if use_cache and self._cache.is_enabled:
            # Include project_id and agent_id for tenant isolation
            fingerprint = self._cache.compute_fingerprint(
                contexts, query, model, project_id=project_id, agent_id=agent_id
            )
            cached = await self._cache.get_assembled(fingerprint)
            if cached:
                logger.debug(f"Cache hit for context assembly: {fingerprint}")
                return cached

        # Run assembly pipeline
        result = await self._pipeline.assemble(
            contexts=contexts,
            query=query,
            model=model,
            max_tokens=max_tokens,
            custom_budget=custom_budget,
            compress=compress,
            format_output=format_output,
        )

        # Cache the result if enabled (reuse the fingerprint computed above)
        if use_cache and self._cache.is_enabled and fingerprint is not None:
            await self._cache.set_assembled(fingerprint, result)

        return result
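    # Illustrative call exercising several of the sources above (the IDs and
    # payloads are placeholders, not values this module defines):
    #
    #     result = await engine.assemble_context(
    #         project_id="proj-123",
    #         agent_id="agent-456",
    #         query="add OAuth login to the API",
    #         model="claude-3-sonnet",
    #         system_prompt="You are an expert developer.",
    #         knowledge_query="OAuth 2.0 best practices",
    #         conversation_history=[{"role": "user", "content": "Use FastAPI."}],
    #         tool_results=[{"tool_name": "run_tests", "content": {"passed": 42}}],
    #     )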
    async def _fetch_knowledge(
        self,
        project_id: str,
        agent_id: str,
        query: str,
        limit: int = 10,
    ) -> list[KnowledgeContext]:
        """
        Fetch relevant knowledge from Knowledge Base via MCP.

        Args:
            project_id: Project identifier
            agent_id: Agent identifier
            query: Search query
            limit: Maximum results

        Returns:
            List of KnowledgeContext instances
        """
        if not self._mcp:
            return []

        try:
            result = await self._mcp.call_tool(
                "knowledge-base",
                "search_knowledge",
                {
                    "project_id": project_id,
                    "agent_id": agent_id,
                    "query": query,
                    "search_type": "hybrid",
                    "limit": limit,
                },
            )

            # Check both ToolResult.success AND response success
            if not result.success:
                logger.warning(f"Knowledge search failed: {result.error}")
                return []
            if not isinstance(result.data, dict) or not result.data.get(
                "success", True
            ):
                logger.warning("Knowledge search returned unsuccessful response")
                return []

            contexts = []
            results = result.data.get("results", [])
            for chunk in results:
                contexts.append(
                    KnowledgeContext(
                        content=chunk.get("content", ""),
                        source=chunk.get("source_path", "unknown"),
                        relevance_score=chunk.get("score", 0.0),
                        metadata={
                            # Server returns 'id' not 'chunk_id'
                            "chunk_id": chunk.get("id"),
                            "document_id": chunk.get("document_id"),
                        },
                    )
                )

            logger.debug(f"Fetched {len(contexts)} knowledge chunks for query: {query}")
            return contexts

        except Exception as e:
            logger.warning(f"Failed to fetch knowledge: {e}")
            return []
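    # Response shape assumed by the parsing above (inferred from the fields
    # read; the actual knowledge-base contract may include additional keys):
    #
    #     {
    #         "success": true,
    #         "results": [
    #             {
    #                 "id": "chunk-1",
    #                 "document_id": "doc-9",
    #                 "content": "...",
    #                 "source_path": "docs/auth.md",
    #                 "score": 0.87
    #             }
    #         ]
    #     }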
    async def _fetch_memory(
        self,
        project_id: str,
        agent_id: str,
        query: str,
        limit: int = 20,
        session_id: str | None = None,
        agent_type_id: str | None = None,
    ) -> list[MemoryContext]:
        """
        Fetch relevant memories from Agent Memory System.

        Args:
            project_id: Project identifier
            agent_id: Agent identifier
            query: Search query
            limit: Maximum results
            session_id: Session ID for working memory
            agent_type_id: Agent type ID for procedural memory

        Returns:
            List of MemoryContext instances
        """
        if not self._memory_source:
            return []

        try:
            # Import here to avoid circular imports
            from app.services.memory.integration.context_source import MemoryFetchConfig

            # Configure fetch limits per memory type, capped so no single
            # type dominates the overall limit
            config = MemoryFetchConfig(
                working_limit=min(limit // 4, 5),
                episodic_limit=min(limit // 2, 10),
                semantic_limit=min(limit // 2, 10),
                procedural_limit=min(limit // 4, 5),
                include_working=session_id is not None,
            )

            result = await self._memory_source.fetch_context(
                query=query,
                project_id=UUID(project_id),
                agent_instance_id=UUID(agent_id) if agent_id else None,
                agent_type_id=UUID(agent_type_id) if agent_type_id else None,
                session_id=session_id,
                config=config,
            )

            logger.debug(
                f"Fetched {len(result.contexts)} memory contexts for query: {query}, "
                f"by_type: {result.by_type}"
            )
            return result.contexts[:limit]

        except Exception as e:
            logger.warning(f"Failed to fetch memory: {e}")
            return []

    def _convert_conversation(
        self,
        history: list[dict[str, str]],
    ) -> list[ConversationContext]:
        """
        Convert conversation history to ConversationContext instances.

        Args:
            history: List of {"role": str, "content": str}

        Returns:
            List of ConversationContext instances
        """
        contexts = []
        for i, turn in enumerate(history):
            role_str = turn.get("role", "user").lower()
            # Any non-assistant role (including "system") falls back to USER
            role = (
                MessageRole.ASSISTANT if role_str == "assistant" else MessageRole.USER
            )
            contexts.append(
                ConversationContext(
                    content=turn.get("content", ""),
                    source=f"conversation:{i}",
                    role=role,
                    metadata={"role": role_str, "turn": i},
                )
            )
        return contexts

    def _convert_tool_results(
        self,
        results: list[dict[str, Any]],
    ) -> list[ToolContext]:
        """
        Convert tool results to ToolContext instances.

        Args:
            results: List of tool result dictionaries

        Returns:
            List of ToolContext instances
        """
        contexts = []
        for result in results:
            tool_name = result.get("tool_name", "unknown")
            content = result.get("content", result.get("result", ""))

            # Serialize dict content as pretty-printed JSON
            if isinstance(content, dict):
                content = json.dumps(content, indent=2)

            contexts.append(
                ToolContext(
                    content=str(content),
                    source=f"tool:{tool_name}",
                    metadata={
                        "tool_name": tool_name,
                        "status": result.get("status", "success"),
                    },
                )
            )
        return contexts

    async def get_budget_for_model(
        self,
        model: str,
        max_tokens: int | None = None,
    ) -> TokenBudget:
        """
        Get the token budget for a specific model.

        Args:
            model: Model name
            max_tokens: Optional max tokens override

        Returns:
            TokenBudget instance
        """
        if max_tokens is not None:
            return self._allocator.create_budget(max_tokens)
        return self._allocator.create_budget_for_model(model)

    async def count_tokens(
        self,
        content: str,
        model: str | None = None,
    ) -> int:
        """
        Count tokens in content.

        Args:
            content: Content to count
            model: Model for model-specific tokenization

        Returns:
            Token count
        """
        # Check cache first
        cached = await self._cache.get_token_count(content, model)
        if cached is not None:
            return cached

        count = await self._calculator.count_tokens(content, model)

        # Cache the result
        await self._cache.set_token_count(content, count, model)
        return count
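    # count_tokens() behaves as a read-through cache: the Redis-backed count
    # is consulted first and the calculator only runs on a miss. Illustrative
    # use (the prompt string is a placeholder):
    #
    #     tokens = await engine.count_tokens(
    #         "Summarize the design doc.", model="claude-3-sonnet"
    #     )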
    async def invalidate_cache(
        self,
        project_id: str | None = None,
        pattern: str | None = None,
    ) -> int:
        """
        Invalidate cache entries.

        Args:
            project_id: Invalidate all cache entries for a project
            pattern: Custom pattern to match

        Returns:
            Number of entries invalidated
        """
        if pattern:
            return await self._cache.invalidate(pattern)
        elif project_id:
            return await self._cache.invalidate(f"*{project_id}*")
        else:
            return await self._cache.clear_all()

    async def get_stats(self) -> dict[str, Any]:
        """
        Get engine statistics.

        Returns:
            Dictionary with engine stats
        """
        return {
            "cache": await self._cache.get_stats(),
            "settings": {
                "compression_threshold": self._settings.compression_threshold,
                "max_assembly_time_ms": self._settings.max_assembly_time_ms,
                "cache_enabled": self._settings.cache_enabled,
            },
        }


# Convenience factory function
def create_context_engine(
    mcp_manager: "MCPClientManager | None" = None,
    redis: "Redis | None" = None,
    settings: ContextSettings | None = None,
    memory_source: "MemoryContextSource | None" = None,
) -> ContextEngine:
    """
    Create a context engine instance.

    Args:
        mcp_manager: MCP client manager
        redis: Redis connection
        settings: Context settings
        memory_source: Optional memory context source

    Returns:
        Configured ContextEngine instance
    """
    return ContextEngine(
        mcp_manager=mcp_manager,
        redis=redis,
        settings=settings,
        memory_source=memory_source,
    )
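# Minimal end-to-end sketch tying the factory and the main entry point
# together (assumes configured `mcp` and `redis` objects exist in the
# caller's scope; identifiers are illustrative):
#
#     import asyncio
#
#     async def main() -> None:
#         engine = create_context_engine(mcp_manager=mcp, redis=redis)
#         result = await engine.assemble_context(
#             project_id="proj-123",
#             agent_id="agent-456",
#             query="implement user authentication",
#             model="claude-3-sonnet",
#         )
#         print(result.content)
#         print(f"Tokens: {result.total_tokens}")
#
#     asyncio.run(main())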