## Changes

### New Context Type
- Add MEMORY to ContextType enum for agent memory context
- Create MemoryContext class with subtypes (working, episodic, semantic, procedural)
- Factory methods: from_working_memory, from_episodic_memory, from_semantic_memory, from_procedural_memory

### Memory Context Source
- MemoryContextSource service fetches relevant memories for context assembly
- Configurable fetch limits per memory type
- Parallel fetching from all memory types

### Agent Lifecycle Hooks
- AgentLifecycleManager handles spawn, pause, resume, terminate events
- spawn: Initialize working memory with optional initial state
- pause: Create checkpoint of working memory
- resume: Restore from checkpoint
- terminate: Consolidate working memory to episodic memory
- LifecycleHooks for custom extension points

### Context Engine Integration
- Add memory_query parameter to assemble_context() (usage sketch below)
- Add session_id and agent_type_id for memory scoping
- Memory budget allocation (15% by default)
- set_memory_source() for runtime configuration

### Tests
- 48 new tests for MemoryContext, MemoryContextSource, and lifecycle hooks
- All 108 memory-related tests passing
- mypy and ruff checks passing

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
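For reviewers, a minimal sketch of the memory-aware assembly path, following the `ContextEngine` API in the engine module below. The module import path, the `mcp`/`redis` handles, and the construction of `MemoryContextSource` are assumptions for illustration, not part of this change:

```python
from app.services.context.engine import ContextEngine  # import path assumed
from app.services.memory.integration import MemoryContextSource


async def assemble_with_memory(mcp, redis, memory_source: MemoryContextSource) -> str:
    # Wire the memory source at construction time (set_memory_source() also
    # works for runtime configuration).
    engine = ContextEngine(mcp_manager=mcp, redis=redis, memory_source=memory_source)

    # project_id / agent_id / agent_type_id must be UUID strings when memory
    # is enabled, since _fetch_memory() parses them with UUID().
    result = await engine.assemble_context(
        project_id="11111111-1111-1111-1111-111111111111",
        agent_id="22222222-2222-2222-2222-222222222222",
        query="implement user authentication",
        model="claude-3-sonnet",
        system_prompt="You are an expert developer.",
        knowledge_query="authentication best practices",
        memory_query="prior decisions about authentication",  # new in this change
        memory_limit=20,
        session_id="sess-123",  # enables working-memory fetch
        agent_type_id="33333333-3333-3333-3333-333333333333",  # procedural memory scope
    )
    return result.content
```

Omitting `memory_query` (or not configuring a memory source) keeps the previous behaviour: the memory fetch is simply skipped.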
583 lines · 18 KiB · Python
"""
|
|
Context Management Engine.
|
|
|
|
Main orchestration layer for context assembly and optimization.
|
|
Provides a high-level API for assembling optimized context for LLM requests.
|
|
"""
|
|
|
|
import logging
|
|
from typing import TYPE_CHECKING, Any
|
|
from uuid import UUID
|
|
|
|
from .assembly import ContextPipeline
|
|
from .budget import BudgetAllocator, TokenBudget, TokenCalculator
|
|
from .cache import ContextCache
|
|
from .compression import ContextCompressor
|
|
from .config import ContextSettings, get_context_settings
|
|
from .prioritization import ContextRanker
|
|
from .scoring import CompositeScorer
|
|
from .types import (
|
|
AssembledContext,
|
|
BaseContext,
|
|
ConversationContext,
|
|
KnowledgeContext,
|
|
MemoryContext,
|
|
MessageRole,
|
|
SystemContext,
|
|
TaskContext,
|
|
ToolContext,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from redis.asyncio import Redis
|
|
|
|
from app.services.mcp.client_manager import MCPClientManager
|
|
from app.services.memory.integration import MemoryContextSource
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ContextEngine:
|
|
"""
|
|
Main context management engine.
|
|
|
|
Provides high-level API for context assembly and optimization.
|
|
Integrates all components: scoring, ranking, compression, formatting, and caching.
|
|
|
|
Usage:
|
|
engine = ContextEngine(mcp_manager=mcp, redis=redis)
|
|
|
|
# Assemble context for an LLM request
|
|
result = await engine.assemble_context(
|
|
project_id="proj-123",
|
|
agent_id="agent-456",
|
|
query="implement user authentication",
|
|
model="claude-3-sonnet",
|
|
system_prompt="You are an expert developer.",
|
|
knowledge_query="authentication best practices",
|
|
)
|
|
|
|
# Use the assembled context
|
|
print(result.content)
|
|
print(f"Tokens: {result.total_tokens}")
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
mcp_manager: "MCPClientManager | None" = None,
|
|
redis: "Redis | None" = None,
|
|
settings: ContextSettings | None = None,
|
|
memory_source: "MemoryContextSource | None" = None,
|
|
) -> None:
|
|
"""
|
|
Initialize the context engine.
|
|
|
|
Args:
|
|
mcp_manager: MCP client manager for LLM Gateway/Knowledge Base
|
|
redis: Redis connection for caching
|
|
settings: Context settings
|
|
memory_source: Optional memory context source for agent memory
|
|
"""
|
|
self._mcp = mcp_manager
|
|
self._settings = settings or get_context_settings()
|
|
self._memory_source = memory_source
|
|
|
|
# Initialize components
|
|
self._calculator = TokenCalculator(mcp_manager=mcp_manager)
|
|
self._scorer = CompositeScorer(mcp_manager=mcp_manager, settings=self._settings)
|
|
self._ranker = ContextRanker(scorer=self._scorer, calculator=self._calculator)
|
|
self._compressor = ContextCompressor(calculator=self._calculator)
|
|
self._allocator = BudgetAllocator(self._settings)
|
|
self._cache = ContextCache(redis=redis, settings=self._settings)
|
|
|
|
# Pipeline for assembly
|
|
self._pipeline = ContextPipeline(
|
|
mcp_manager=mcp_manager,
|
|
settings=self._settings,
|
|
calculator=self._calculator,
|
|
scorer=self._scorer,
|
|
ranker=self._ranker,
|
|
compressor=self._compressor,
|
|
)
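        # The pipeline carries out the flow described in assemble_context():
        # score and rank the gathered contexts, compress when needed, and
        # format the result for the target model.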

    def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
        """
        Set MCP manager for all components.

        Args:
            mcp_manager: MCP client manager
        """
        self._mcp = mcp_manager
        self._calculator.set_mcp_manager(mcp_manager)
        self._scorer.set_mcp_manager(mcp_manager)
        self._pipeline.set_mcp_manager(mcp_manager)

    def set_redis(self, redis: "Redis") -> None:
        """
        Set Redis connection for caching.

        Args:
            redis: Redis connection
        """
        self._cache.set_redis(redis)

    def set_memory_source(self, memory_source: "MemoryContextSource") -> None:
        """
        Set memory context source for agent memory integration.

        Args:
            memory_source: Memory context source
        """
        self._memory_source = memory_source

    async def assemble_context(
        self,
        project_id: str,
        agent_id: str,
        query: str,
        model: str,
        max_tokens: int | None = None,
        system_prompt: str | None = None,
        task_description: str | None = None,
        knowledge_query: str | None = None,
        knowledge_limit: int = 10,
        memory_query: str | None = None,
        memory_limit: int = 20,
        session_id: str | None = None,
        agent_type_id: str | None = None,
        conversation_history: list[dict[str, str]] | None = None,
        tool_results: list[dict[str, Any]] | None = None,
        custom_contexts: list[BaseContext] | None = None,
        custom_budget: TokenBudget | None = None,
        compress: bool = True,
        format_output: bool = True,
        use_cache: bool = True,
    ) -> AssembledContext:
        """
        Assemble optimized context for an LLM request.

        This is the main entry point for context management.
        It gathers context from various sources, scores and ranks them,
        compresses if needed, and formats for the target model.

        Args:
            project_id: Project identifier
            agent_id: Agent identifier
            query: User's query or current request
            model: Target model name
            max_tokens: Maximum context tokens (uses model default if None)
            system_prompt: System prompt/instructions
            task_description: Current task description
            knowledge_query: Query for knowledge base search
            knowledge_limit: Max number of knowledge results
            memory_query: Query for agent memory search
            memory_limit: Max number of memory results
            session_id: Session ID for working memory access
            agent_type_id: Agent type ID for procedural memory
            conversation_history: List of {"role": str, "content": str}
            tool_results: List of tool results to include
            custom_contexts: Additional custom contexts
            custom_budget: Custom token budget
            compress: Whether to apply compression
            format_output: Whether to format for the model
            use_cache: Whether to use caching

        Returns:
            AssembledContext with optimized content

        Raises:
            AssemblyTimeoutError: If assembly exceeds timeout
            BudgetExceededError: If context exceeds budget
        """
        # Gather all contexts
        contexts: list[BaseContext] = []

        # 1. System context
        if system_prompt:
            contexts.append(
                SystemContext(
                    content=system_prompt,
                    source="system_prompt",
                )
            )

        # 2. Task context
        if task_description:
            contexts.append(
                TaskContext(
                    content=task_description,
                    source=f"task:{project_id}:{agent_id}",
                )
            )

        # 3. Knowledge context from Knowledge Base
        if knowledge_query and self._mcp:
            knowledge_contexts = await self._fetch_knowledge(
                project_id=project_id,
                agent_id=agent_id,
                query=knowledge_query,
                limit=knowledge_limit,
            )
            contexts.extend(knowledge_contexts)

        # 4. Memory context from Agent Memory System
        if memory_query and self._memory_source:
            memory_contexts = await self._fetch_memory(
                project_id=project_id,
                agent_id=agent_id,
                query=memory_query,
                limit=memory_limit,
                session_id=session_id,
                agent_type_id=agent_type_id,
            )
            contexts.extend(memory_contexts)

        # 5. Conversation history
        if conversation_history:
            contexts.extend(self._convert_conversation(conversation_history))

        # 6. Tool results
        if tool_results:
            contexts.extend(self._convert_tool_results(tool_results))

        # 7. Custom contexts
        if custom_contexts:
            contexts.extend(custom_contexts)

        # Check cache if enabled
        fingerprint: str | None = None
        if use_cache and self._cache.is_enabled:
            # Include project_id and agent_id for tenant isolation
            fingerprint = self._cache.compute_fingerprint(
                contexts, query, model, project_id=project_id, agent_id=agent_id
            )
            cached = await self._cache.get_assembled(fingerprint)
            if cached:
                logger.debug(f"Cache hit for context assembly: {fingerprint}")
                return cached

        # Run assembly pipeline
        result = await self._pipeline.assemble(
            contexts=contexts,
            query=query,
            model=model,
            max_tokens=max_tokens,
            custom_budget=custom_budget,
            compress=compress,
            format_output=format_output,
        )

        # Cache result if enabled (reuse fingerprint computed above)
        if use_cache and self._cache.is_enabled and fingerprint is not None:
            await self._cache.set_assembled(fingerprint, result)

        return result

    async def _fetch_knowledge(
        self,
        project_id: str,
        agent_id: str,
        query: str,
        limit: int = 10,
    ) -> list[KnowledgeContext]:
        """
        Fetch relevant knowledge from Knowledge Base via MCP.

        Args:
            project_id: Project identifier
            agent_id: Agent identifier
            query: Search query
            limit: Maximum results

        Returns:
            List of KnowledgeContext instances
        """
        if not self._mcp:
            return []

        try:
            result = await self._mcp.call_tool(
                "knowledge-base",
                "search_knowledge",
                {
                    "project_id": project_id,
                    "agent_id": agent_id,
                    "query": query,
                    "search_type": "hybrid",
                    "limit": limit,
                },
            )

            # Check both ToolResult.success AND response success
            if not result.success:
                logger.warning(f"Knowledge search failed: {result.error}")
                return []

            if not isinstance(result.data, dict) or not result.data.get(
                "success", True
            ):
                logger.warning("Knowledge search returned unsuccessful response")
                return []

            contexts = []
            results = result.data.get("results", [])
            for chunk in results:
                contexts.append(
                    KnowledgeContext(
                        content=chunk.get("content", ""),
                        source=chunk.get("source_path", "unknown"),
                        relevance_score=chunk.get("score", 0.0),
                        metadata={
                            "chunk_id": chunk.get(
                                "id"
                            ),  # Server returns 'id' not 'chunk_id'
                            "document_id": chunk.get("document_id"),
                        },
                    )
                )

            logger.debug(f"Fetched {len(contexts)} knowledge chunks for query: {query}")
            return contexts

        except Exception as e:
            logger.warning(f"Failed to fetch knowledge: {e}")
            return []

    async def _fetch_memory(
        self,
        project_id: str,
        agent_id: str,
        query: str,
        limit: int = 20,
        session_id: str | None = None,
        agent_type_id: str | None = None,
    ) -> list[MemoryContext]:
        """
        Fetch relevant memories from Agent Memory System.

        Args:
            project_id: Project identifier
            agent_id: Agent identifier
            query: Search query
            limit: Maximum results
            session_id: Session ID for working memory
            agent_type_id: Agent type ID for procedural memory

        Returns:
            List of MemoryContext instances
        """
        if not self._memory_source:
            return []

        try:
            # Import here to avoid circular imports
            from app.services.memory.integration.context_source import MemoryFetchConfig

            # Configure fetch limits
            config = MemoryFetchConfig(
                working_limit=min(limit // 4, 5),
                episodic_limit=min(limit // 2, 10),
                semantic_limit=min(limit // 2, 10),
                procedural_limit=min(limit // 4, 5),
                include_working=session_id is not None,
            )
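            # The split above allots roughly half of memory_limit to episodic
            # and semantic recall and a quarter each to working and procedural
            # memory, capped per type; working memory is only fetched when a
            # session_id scopes the request.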

            result = await self._memory_source.fetch_context(
                query=query,
                project_id=UUID(project_id),
                agent_instance_id=UUID(agent_id) if agent_id else None,
                agent_type_id=UUID(agent_type_id) if agent_type_id else None,
                session_id=session_id,
                config=config,
            )

            logger.debug(
                f"Fetched {len(result.contexts)} memory contexts for query: {query}, "
                f"by_type: {result.by_type}"
            )
            return result.contexts[:limit]

        except Exception as e:
            logger.warning(f"Failed to fetch memory: {e}")
            return []

    def _convert_conversation(
        self,
        history: list[dict[str, str]],
    ) -> list[ConversationContext]:
        """
        Convert conversation history to ConversationContext instances.

        Args:
            history: List of {"role": str, "content": str}

        Returns:
            List of ConversationContext instances
        """
        contexts = []
        for i, turn in enumerate(history):
            role_str = turn.get("role", "user").lower()
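            # Any role other than "assistant" (including "system" or "tool"
            # entries) is treated as a user turn here.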
            role = (
                MessageRole.ASSISTANT if role_str == "assistant" else MessageRole.USER
            )

            contexts.append(
                ConversationContext(
                    content=turn.get("content", ""),
                    source=f"conversation:{i}",
                    role=role,
                    metadata={"role": role_str, "turn": i},
                )
            )

        return contexts

    def _convert_tool_results(
        self,
        results: list[dict[str, Any]],
    ) -> list[ToolContext]:
        """
        Convert tool results to ToolContext instances.

        Args:
            results: List of tool result dictionaries

        Returns:
            List of ToolContext instances
        """
        contexts = []
        for result in results:
            tool_name = result.get("tool_name", "unknown")
            content = result.get("content", result.get("result", ""))

            # Handle dict content
            if isinstance(content, dict):
                import json

                content = json.dumps(content, indent=2)

            contexts.append(
                ToolContext(
                    content=str(content),
                    source=f"tool:{tool_name}",
                    metadata={
                        "tool_name": tool_name,
                        "status": result.get("status", "success"),
                    },
                )
            )

        return contexts

    async def get_budget_for_model(
        self,
        model: str,
        max_tokens: int | None = None,
    ) -> TokenBudget:
        """
        Get the token budget for a specific model.

        Args:
            model: Model name
            max_tokens: Optional max tokens override

        Returns:
            TokenBudget instance
        """
        if max_tokens:
            return self._allocator.create_budget(max_tokens)
        return self._allocator.create_budget_for_model(model)

    async def count_tokens(
        self,
        content: str,
        model: str | None = None,
    ) -> int:
        """
        Count tokens in content.

        Args:
            content: Content to count
            model: Model for model-specific tokenization

        Returns:
            Token count
        """
        # Check cache first
        cached = await self._cache.get_token_count(content, model)
        if cached is not None:
            return cached

        count = await self._calculator.count_tokens(content, model)

        # Cache the result
        await self._cache.set_token_count(content, count, model)

        return count

    async def invalidate_cache(
        self,
        project_id: str | None = None,
        pattern: str | None = None,
    ) -> int:
        """
        Invalidate cache entries.

        Args:
            project_id: Invalidate all cache for a project
            pattern: Custom pattern to match

        Returns:
            Number of entries invalidated
        """
        if pattern:
            return await self._cache.invalidate(pattern)
        elif project_id:
            return await self._cache.invalidate(f"*{project_id}*")
        else:
            return await self._cache.clear_all()

    async def get_stats(self) -> dict[str, Any]:
        """
        Get engine statistics.

        Returns:
            Dictionary with engine stats
        """
        return {
            "cache": await self._cache.get_stats(),
            "settings": {
                "compression_threshold": self._settings.compression_threshold,
                "max_assembly_time_ms": self._settings.max_assembly_time_ms,
                "cache_enabled": self._settings.cache_enabled,
            },
        }


# Convenience factory function
def create_context_engine(
    mcp_manager: "MCPClientManager | None" = None,
    redis: "Redis | None" = None,
    settings: ContextSettings | None = None,
    memory_source: "MemoryContextSource | None" = None,
) -> ContextEngine:
    """
    Create a context engine instance.

    Args:
        mcp_manager: MCP client manager
        redis: Redis connection
        settings: Context settings
        memory_source: Optional memory context source

    Returns:
        Configured ContextEngine instance
    """
    return ContextEngine(
        mcp_manager=mcp_manager,
        redis=redis,
        settings=settings,
        memory_source=memory_source,
    )