fast-next-template/backend/app/services/context/engine.py
Felipe Cardoso 30e5c68304 feat(memory): integrate memory system with context engine (#97)
## Changes

### New Context Type
- Add MEMORY to ContextType enum for agent memory context
- Create MemoryContext class with subtypes (working, episodic, semantic, procedural)
- Factory methods: from_working_memory, from_episodic_memory, from_semantic_memory, from_procedural_memory (usage sketched below)
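
A rough usage sketch of the factory methods. Only the method names come from this change; the arguments are hypothetical, with `content`/`source` assumed from the `BaseContext` pattern used elsewhere in this file:

```python
from app.services.context.types import MemoryContext

# Hypothetical call; the real signature is defined in types.py.
episodic = MemoryContext.from_episodic_memory(
    content="Resolved the auth bug by rotating the JWT signing key.",
    source="memory:episodic",
)
```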

### Memory Context Source
- MemoryContextSource service fetches relevant memories for context assembly
- Configurable fetch limits per memory type
- Parallel fetching from all memory types (see the sketch below)
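
The parallel fetch follows roughly this pattern. This is an illustrative sketch only: the `fetch_*` helpers are hypothetical stand-ins for the real per-type queries, while the `MemoryFetchConfig` limit fields match the ones used by the engine below:

```python
import asyncio

async def fetch_all_memory_types(source, query, config):
    # One concurrent query per memory type; limits come from MemoryFetchConfig.
    working, episodic, semantic, procedural = await asyncio.gather(
        source.fetch_working(query, limit=config.working_limit),
        source.fetch_episodic(query, limit=config.episodic_limit),
        source.fetch_semantic(query, limit=config.semantic_limit),
        source.fetch_procedural(query, limit=config.procedural_limit),
    )
    return [*working, *episodic, *semantic, *procedural]
```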

### Agent Lifecycle Hooks
- AgentLifecycleManager handles spawn, pause, resume, terminate events
- spawn: Initialize working memory with optional initial state
- pause: Create checkpoint of working memory
- resume: Restore from checkpoint
- terminate: Consolidate working memory to episodic memory
- LifecycleHooks for custom extension points (hypothetical walkthrough below)
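
A hypothetical walkthrough of the lifecycle events; only the event names (spawn/pause/resume/terminate) come from this change, and the constructor and method signatures are assumptions:

```python
# Hypothetical API sketch; see AgentLifecycleManager for the real signatures.
manager = AgentLifecycleManager(hooks=LifecycleHooks())

await manager.spawn(agent_id, initial_state={"task": "triage"})  # seed working memory
await manager.pause(agent_id)      # checkpoint working memory
await manager.resume(agent_id)     # restore from the checkpoint
await manager.terminate(agent_id)  # consolidate working memory into episodic
```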

### Context Engine Integration
- Add memory_query parameter to assemble_context() (example call below)
- Add session_id and agent_type_id for memory scoping
- Memory budget allocation (15% by default)
- set_memory_source() for runtime configuration
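
Example call, with parameter names taken from `assemble_context()` below. The IDs are placeholders; note that `_fetch_memory()` parses `project_id`/`agent_id` with `UUID()`, so they should be UUID strings:

```python
import asyncio

from app.services.context.engine import create_context_engine

async def main() -> None:
    # Wire up mcp_manager/redis/memory_source in real use; without a
    # memory_source, the memory fetch is skipped.
    engine = create_context_engine()
    result = await engine.assemble_context(
        project_id="5f0c4e46-0d52-4c0a-9a4e-1f2b3c4d5e6f",  # placeholder UUID
        agent_id="c0a8012e-7f3b-4a6d-9c1e-2b3d4f5a6e7d",    # placeholder UUID
        query="implement user authentication",
        model="claude-3-sonnet",
        memory_query="previous authentication work",
        memory_limit=20,
        session_id="sess-789",  # enables working-memory fetch
    )
    print(result.total_tokens)

asyncio.run(main())
```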

### Tests
- 48 new tests for MemoryContext, MemoryContextSource, and lifecycle hooks
- All 108 memory-related tests passing
- mypy and ruff checks passing

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 03:49:22 +01:00


"""
Context Management Engine.
Main orchestration layer for context assembly and optimization.
Provides a high-level API for assembling optimized context for LLM requests.
"""
import logging
from typing import TYPE_CHECKING, Any
from uuid import UUID
from .assembly import ContextPipeline
from .budget import BudgetAllocator, TokenBudget, TokenCalculator
from .cache import ContextCache
from .compression import ContextCompressor
from .config import ContextSettings, get_context_settings
from .prioritization import ContextRanker
from .scoring import CompositeScorer
from .types import (
AssembledContext,
BaseContext,
ConversationContext,
KnowledgeContext,
MemoryContext,
MessageRole,
SystemContext,
TaskContext,
ToolContext,
)
if TYPE_CHECKING:
from redis.asyncio import Redis
from app.services.mcp.client_manager import MCPClientManager
from app.services.memory.integration import MemoryContextSource
logger = logging.getLogger(__name__)
class ContextEngine:
    """
    Main context management engine.

    Provides a high-level API for context assembly and optimization.
    Integrates all components: scoring, ranking, compression, formatting, and caching.

    Usage:
        engine = ContextEngine(mcp_manager=mcp, redis=redis)

        # Assemble context for an LLM request
        result = await engine.assemble_context(
            project_id="proj-123",
            agent_id="agent-456",
            query="implement user authentication",
            model="claude-3-sonnet",
            system_prompt="You are an expert developer.",
            knowledge_query="authentication best practices",
        )

        # Use the assembled context
        print(result.content)
        print(f"Tokens: {result.total_tokens}")
    """

    def __init__(
        self,
        mcp_manager: "MCPClientManager | None" = None,
        redis: "Redis | None" = None,
        settings: ContextSettings | None = None,
        memory_source: "MemoryContextSource | None" = None,
    ) -> None:
        """
        Initialize the context engine.

        Args:
            mcp_manager: MCP client manager for LLM Gateway/Knowledge Base
            redis: Redis connection for caching
            settings: Context settings
            memory_source: Optional memory context source for agent memory
        """
        self._mcp = mcp_manager
        self._settings = settings or get_context_settings()
        self._memory_source = memory_source

        # Initialize components
        self._calculator = TokenCalculator(mcp_manager=mcp_manager)
        self._scorer = CompositeScorer(mcp_manager=mcp_manager, settings=self._settings)
        self._ranker = ContextRanker(scorer=self._scorer, calculator=self._calculator)
        self._compressor = ContextCompressor(calculator=self._calculator)
        self._allocator = BudgetAllocator(self._settings)
        self._cache = ContextCache(redis=redis, settings=self._settings)

        # Pipeline for assembly
        self._pipeline = ContextPipeline(
            mcp_manager=mcp_manager,
            settings=self._settings,
            calculator=self._calculator,
            scorer=self._scorer,
            ranker=self._ranker,
            compressor=self._compressor,
        )

    def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
        """
        Set MCP manager for all components.

        Args:
            mcp_manager: MCP client manager
        """
        self._mcp = mcp_manager
        self._calculator.set_mcp_manager(mcp_manager)
        self._scorer.set_mcp_manager(mcp_manager)
        self._pipeline.set_mcp_manager(mcp_manager)

    def set_redis(self, redis: "Redis") -> None:
        """
        Set Redis connection for caching.

        Args:
            redis: Redis connection
        """
        self._cache.set_redis(redis)

    def set_memory_source(self, memory_source: "MemoryContextSource") -> None:
        """
        Set memory context source for agent memory integration.

        Args:
            memory_source: Memory context source
        """
        self._memory_source = memory_source

    async def assemble_context(
        self,
        project_id: str,
        agent_id: str,
        query: str,
        model: str,
        max_tokens: int | None = None,
        system_prompt: str | None = None,
        task_description: str | None = None,
        knowledge_query: str | None = None,
        knowledge_limit: int = 10,
        memory_query: str | None = None,
        memory_limit: int = 20,
        session_id: str | None = None,
        agent_type_id: str | None = None,
        conversation_history: list[dict[str, str]] | None = None,
        tool_results: list[dict[str, Any]] | None = None,
        custom_contexts: list[BaseContext] | None = None,
        custom_budget: TokenBudget | None = None,
        compress: bool = True,
        format_output: bool = True,
        use_cache: bool = True,
    ) -> AssembledContext:
        """
        Assemble optimized context for an LLM request.

        This is the main entry point for context management.
        It gathers context from various sources, scores and ranks them,
        compresses if needed, and formats for the target model.

        Args:
            project_id: Project identifier
            agent_id: Agent identifier
            query: User's query or current request
            model: Target model name
            max_tokens: Maximum context tokens (uses model default if None)
            system_prompt: System prompt/instructions
            task_description: Current task description
            knowledge_query: Query for knowledge base search
            knowledge_limit: Max number of knowledge results
            memory_query: Query for agent memory search
            memory_limit: Max number of memory results
            session_id: Session ID for working memory access
            agent_type_id: Agent type ID for procedural memory
            conversation_history: List of {"role": str, "content": str}
            tool_results: List of tool results to include
            custom_contexts: Additional custom contexts
            custom_budget: Custom token budget
            compress: Whether to apply compression
            format_output: Whether to format for the model
            use_cache: Whether to use caching

        Returns:
            AssembledContext with optimized content

        Raises:
            AssemblyTimeoutError: If assembly exceeds timeout
            BudgetExceededError: If context exceeds budget
        """
        # Gather all contexts
        contexts: list[BaseContext] = []

        # 1. System context
        if system_prompt:
            contexts.append(
                SystemContext(
                    content=system_prompt,
                    source="system_prompt",
                )
            )

        # 2. Task context
        if task_description:
            contexts.append(
                TaskContext(
                    content=task_description,
                    source=f"task:{project_id}:{agent_id}",
                )
            )

        # 3. Knowledge context from Knowledge Base
        if knowledge_query and self._mcp:
            knowledge_contexts = await self._fetch_knowledge(
                project_id=project_id,
                agent_id=agent_id,
                query=knowledge_query,
                limit=knowledge_limit,
            )
            contexts.extend(knowledge_contexts)

        # 4. Memory context from Agent Memory System
        if memory_query and self._memory_source:
            memory_contexts = await self._fetch_memory(
                project_id=project_id,
                agent_id=agent_id,
                query=memory_query,
                limit=memory_limit,
                session_id=session_id,
                agent_type_id=agent_type_id,
            )
            contexts.extend(memory_contexts)

        # 5. Conversation history
        if conversation_history:
            contexts.extend(self._convert_conversation(conversation_history))

        # 6. Tool results
        if tool_results:
            contexts.extend(self._convert_tool_results(tool_results))

        # 7. Custom contexts
        if custom_contexts:
            contexts.extend(custom_contexts)

        # Check cache if enabled
        fingerprint: str | None = None
        if use_cache and self._cache.is_enabled:
            # Include project_id and agent_id for tenant isolation
            fingerprint = self._cache.compute_fingerprint(
                contexts, query, model, project_id=project_id, agent_id=agent_id
            )
            cached = await self._cache.get_assembled(fingerprint)
            if cached:
                logger.debug(f"Cache hit for context assembly: {fingerprint}")
                return cached

        # Run assembly pipeline
        result = await self._pipeline.assemble(
            contexts=contexts,
            query=query,
            model=model,
            max_tokens=max_tokens,
            custom_budget=custom_budget,
            compress=compress,
            format_output=format_output,
        )

        # Cache result if enabled (reuse fingerprint computed above)
        if use_cache and self._cache.is_enabled and fingerprint is not None:
            await self._cache.set_assembled(fingerprint, result)

        return result

    async def _fetch_knowledge(
        self,
        project_id: str,
        agent_id: str,
        query: str,
        limit: int = 10,
    ) -> list[KnowledgeContext]:
        """
        Fetch relevant knowledge from Knowledge Base via MCP.

        Args:
            project_id: Project identifier
            agent_id: Agent identifier
            query: Search query
            limit: Maximum results

        Returns:
            List of KnowledgeContext instances
        """
        if not self._mcp:
            return []

        try:
            result = await self._mcp.call_tool(
                "knowledge-base",
                "search_knowledge",
                {
                    "project_id": project_id,
                    "agent_id": agent_id,
                    "query": query,
                    "search_type": "hybrid",
                    "limit": limit,
                },
            )

            # Check both ToolResult.success AND response success
            if not result.success:
                logger.warning(f"Knowledge search failed: {result.error}")
                return []
            if not isinstance(result.data, dict) or not result.data.get(
                "success", True
            ):
                logger.warning("Knowledge search returned unsuccessful response")
                return []

            contexts = []
            results = result.data.get("results", [])
            for chunk in results:
                contexts.append(
                    KnowledgeContext(
                        content=chunk.get("content", ""),
                        source=chunk.get("source_path", "unknown"),
                        relevance_score=chunk.get("score", 0.0),
                        metadata={
                            "chunk_id": chunk.get("id"),  # Server returns 'id', not 'chunk_id'
                            "document_id": chunk.get("document_id"),
                        },
                    )
                )

            logger.debug(f"Fetched {len(contexts)} knowledge chunks for query: {query}")
            return contexts
        except Exception as e:
            logger.warning(f"Failed to fetch knowledge: {e}")
            return []

    async def _fetch_memory(
        self,
        project_id: str,
        agent_id: str,
        query: str,
        limit: int = 20,
        session_id: str | None = None,
        agent_type_id: str | None = None,
    ) -> list[MemoryContext]:
        """
        Fetch relevant memories from Agent Memory System.

        Args:
            project_id: Project identifier
            agent_id: Agent identifier
            query: Search query
            limit: Maximum results
            session_id: Session ID for working memory
            agent_type_id: Agent type ID for procedural memory

        Returns:
            List of MemoryContext instances
        """
        if not self._memory_source:
            return []

        try:
            # Import here to avoid circular imports
            from app.services.memory.integration.context_source import MemoryFetchConfig

            # Configure per-type fetch limits: episodic and semantic memories each
            # get roughly half of the overall limit (capped at 10), working and
            # procedural a quarter each (capped at 5). Working memory is only
            # included when a session is available.
            config = MemoryFetchConfig(
                working_limit=min(limit // 4, 5),
                episodic_limit=min(limit // 2, 10),
                semantic_limit=min(limit // 2, 10),
                procedural_limit=min(limit // 4, 5),
                include_working=session_id is not None,
            )

            result = await self._memory_source.fetch_context(
                query=query,
                project_id=UUID(project_id),
                agent_instance_id=UUID(agent_id) if agent_id else None,
                agent_type_id=UUID(agent_type_id) if agent_type_id else None,
                session_id=session_id,
                config=config,
            )

            logger.debug(
                f"Fetched {len(result.contexts)} memory contexts for query: {query}, "
                f"by_type: {result.by_type}"
            )
            return result.contexts[:limit]
        except Exception as e:
            logger.warning(f"Failed to fetch memory: {e}")
            return []

    def _convert_conversation(
        self,
        history: list[dict[str, str]],
    ) -> list[ConversationContext]:
        """
        Convert conversation history to ConversationContext instances.

        Args:
            history: List of {"role": str, "content": str}

        Returns:
            List of ConversationContext instances
        """
        contexts = []
        for i, turn in enumerate(history):
            role_str = turn.get("role", "user").lower()
            role = (
                MessageRole.ASSISTANT if role_str == "assistant" else MessageRole.USER
            )
            contexts.append(
                ConversationContext(
                    content=turn.get("content", ""),
                    source=f"conversation:{i}",
                    role=role,
                    metadata={"role": role_str, "turn": i},
                )
            )
        return contexts

    def _convert_tool_results(
        self,
        results: list[dict[str, Any]],
    ) -> list[ToolContext]:
        """
        Convert tool results to ToolContext instances.

        Args:
            results: List of tool result dictionaries

        Returns:
            List of ToolContext instances
        """
        contexts = []
        for result in results:
            tool_name = result.get("tool_name", "unknown")
            content = result.get("content", result.get("result", ""))

            # Serialize dict content (json is imported at module level)
            if isinstance(content, dict):
                content = json.dumps(content, indent=2)

            contexts.append(
                ToolContext(
                    content=str(content),
                    source=f"tool:{tool_name}",
                    metadata={
                        "tool_name": tool_name,
                        "status": result.get("status", "success"),
                    },
                )
            )
        return contexts

    async def get_budget_for_model(
        self,
        model: str,
        max_tokens: int | None = None,
    ) -> TokenBudget:
        """
        Get the token budget for a specific model.

        Args:
            model: Model name
            max_tokens: Optional max tokens override

        Returns:
            TokenBudget instance
        """
        if max_tokens:
            return self._allocator.create_budget(max_tokens)
        return self._allocator.create_budget_for_model(model)

    async def count_tokens(
        self,
        content: str,
        model: str | None = None,
    ) -> int:
        """
        Count tokens in content.

        Args:
            content: Content to count
            model: Model for model-specific tokenization

        Returns:
            Token count
        """
        # Check cache first
        cached = await self._cache.get_token_count(content, model)
        if cached is not None:
            return cached

        count = await self._calculator.count_tokens(content, model)

        # Cache the result
        await self._cache.set_token_count(content, count, model)
        return count

    async def invalidate_cache(
        self,
        project_id: str | None = None,
        pattern: str | None = None,
    ) -> int:
        """
        Invalidate cache entries.

        Args:
            project_id: Invalidate all cache for a project
            pattern: Custom pattern to match

        Returns:
            Number of entries invalidated
        """
        if pattern:
            return await self._cache.invalidate(pattern)
        elif project_id:
            return await self._cache.invalidate(f"*{project_id}*")
        else:
            return await self._cache.clear_all()

    async def get_stats(self) -> dict[str, Any]:
        """
        Get engine statistics.

        Returns:
            Dictionary with engine stats
        """
        return {
            "cache": await self._cache.get_stats(),
            "settings": {
                "compression_threshold": self._settings.compression_threshold,
                "max_assembly_time_ms": self._settings.max_assembly_time_ms,
                "cache_enabled": self._settings.cache_enabled,
            },
        }


# Convenience factory function
def create_context_engine(
    mcp_manager: "MCPClientManager | None" = None,
    redis: "Redis | None" = None,
    settings: ContextSettings | None = None,
    memory_source: "MemoryContextSource | None" = None,
) -> ContextEngine:
    """
    Create a context engine instance.

    Args:
        mcp_manager: MCP client manager
        redis: Redis connection
        settings: Context settings
        memory_source: Optional memory context source

    Returns:
        Configured ContextEngine instance
    """
    return ContextEngine(
        mcp_manager=mcp_manager,
        redis=redis,
        settings=settings,
        memory_source=memory_source,
    )