feat(context): implement main ContextEngine with full integration (#85)

Phase 7 of Context Management Engine - Main Engine:

- Add ContextEngine as main orchestration class
- Integrate all components: calculator, scorer, ranker, compressor, cache
- Add high-level assemble_context() API with:
  - System prompt support
  - Task description support
  - Knowledge Base integration via MCP
  - Conversation history conversion
  - Tool results conversion
  - Custom contexts support
- Add helper methods:
  - get_budget_for_model()
  - count_tokens() with caching
  - invalidate_cache()
  - get_stats()
- Add create_context_engine() factory function

Tests: 26 new tests, 311 total context tests passing

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-04 02:44:40 +01:00
parent c2466ab401
commit 027ebfc332
3 changed files with 934 additions and 0 deletions

View File

@@ -88,6 +88,9 @@ from .adapters import (
# Cache
from .cache import ContextCache
# Engine
from .engine import ContextEngine, create_context_engine
# Prioritization
from .prioritization import (
ContextRanker,
@@ -137,6 +140,9 @@ __all__ = [
"TokenCalculator",
# Cache
"ContextCache",
# Engine
"ContextEngine",
"create_context_engine",
# Compression
"ContextCompressor",
"TruncationResult",

View File

@@ -0,0 +1,470 @@
"""
Context Management Engine.
Main orchestration layer for context assembly and optimization.
Provides a high-level API for assembling optimized context for LLM requests.
"""
import logging
from typing import TYPE_CHECKING, Any
from .adapters import get_adapter
from .assembly import ContextPipeline
from .budget import BudgetAllocator, TokenBudget, TokenCalculator
from .cache import ContextCache
from .compression import ContextCompressor
from .config import ContextSettings, get_context_settings
from .prioritization import ContextRanker
from .scoring import CompositeScorer
from .types import (
AssembledContext,
BaseContext,
ConversationContext,
KnowledgeContext,
MessageRole,
SystemContext,
TaskContext,
ToolContext,
)
if TYPE_CHECKING:
from redis.asyncio import Redis
from app.services.mcp.client_manager import MCPClientManager
logger = logging.getLogger(__name__)
class ContextEngine:
"""
Main context management engine.
Provides high-level API for context assembly and optimization.
Integrates all components: scoring, ranking, compression, formatting, and caching.
Usage:
engine = ContextEngine(mcp_manager=mcp, redis=redis)
# Assemble context for an LLM request
result = await engine.assemble_context(
project_id="proj-123",
agent_id="agent-456",
query="implement user authentication",
model="claude-3-sonnet",
system_prompt="You are an expert developer.",
knowledge_query="authentication best practices",
)
# Use the assembled context
print(result.content)
print(f"Tokens: {result.total_tokens}")
"""
def __init__(
self,
mcp_manager: "MCPClientManager | None" = None,
redis: "Redis | None" = None,
settings: ContextSettings | None = None,
) -> None:
"""
Initialize the context engine.
Args:
mcp_manager: MCP client manager for LLM Gateway/Knowledge Base
redis: Redis connection for caching
settings: Context settings
"""
self._mcp = mcp_manager
self._settings = settings or get_context_settings()
# Initialize components
self._calculator = TokenCalculator(mcp_manager=mcp_manager)
self._scorer = CompositeScorer(
mcp_manager=mcp_manager, settings=self._settings
)
self._ranker = ContextRanker(
scorer=self._scorer, calculator=self._calculator
)
self._compressor = ContextCompressor(calculator=self._calculator)
self._allocator = BudgetAllocator(self._settings)
self._cache = ContextCache(redis=redis, settings=self._settings)
# Pipeline for assembly
self._pipeline = ContextPipeline(
mcp_manager=mcp_manager,
settings=self._settings,
calculator=self._calculator,
scorer=self._scorer,
ranker=self._ranker,
compressor=self._compressor,
)
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
"""
Set MCP manager for all components.
Args:
mcp_manager: MCP client manager
"""
self._mcp = mcp_manager
self._calculator.set_mcp_manager(mcp_manager)
self._scorer.set_mcp_manager(mcp_manager)
self._pipeline.set_mcp_manager(mcp_manager)
def set_redis(self, redis: "Redis") -> None:
"""
Set Redis connection for caching.
Args:
redis: Redis connection
"""
self._cache.set_redis(redis)
async def assemble_context(
self,
project_id: str,
agent_id: str,
query: str,
model: str,
max_tokens: int | None = None,
system_prompt: str | None = None,
task_description: str | None = None,
knowledge_query: str | None = None,
knowledge_limit: int = 10,
conversation_history: list[dict[str, str]] | None = None,
tool_results: list[dict[str, Any]] | None = None,
custom_contexts: list[BaseContext] | None = None,
custom_budget: TokenBudget | None = None,
compress: bool = True,
format_output: bool = True,
use_cache: bool = True,
) -> AssembledContext:
"""
Assemble optimized context for an LLM request.
This is the main entry point for context management.
It gathers context from various sources, scores and ranks them,
compresses if needed, and formats for the target model.
Args:
project_id: Project identifier
agent_id: Agent identifier
query: User's query or current request
model: Target model name
max_tokens: Maximum context tokens (uses model default if None)
system_prompt: System prompt/instructions
task_description: Current task description
knowledge_query: Query for knowledge base search
knowledge_limit: Max number of knowledge results
conversation_history: List of {"role": str, "content": str}
tool_results: List of tool results to include
custom_contexts: Additional custom contexts
custom_budget: Custom token budget
compress: Whether to apply compression
format_output: Whether to format for the model
use_cache: Whether to use caching
Returns:
AssembledContext with optimized content
Raises:
AssemblyTimeoutError: If assembly exceeds timeout
BudgetExceededError: If context exceeds budget
"""
# Gather all contexts
contexts: list[BaseContext] = []
# 1. System context
if system_prompt:
contexts.append(
SystemContext(
content=system_prompt,
source="system_prompt",
)
)
# 2. Task context
if task_description:
contexts.append(
TaskContext(
content=task_description,
source=f"task:{project_id}:{agent_id}",
)
)
# 3. Knowledge context from Knowledge Base
if knowledge_query and self._mcp:
knowledge_contexts = await self._fetch_knowledge(
project_id=project_id,
agent_id=agent_id,
query=knowledge_query,
limit=knowledge_limit,
)
contexts.extend(knowledge_contexts)
# 4. Conversation history
if conversation_history:
contexts.extend(self._convert_conversation(conversation_history))
# 5. Tool results
if tool_results:
contexts.extend(self._convert_tool_results(tool_results))
# 6. Custom contexts
if custom_contexts:
contexts.extend(custom_contexts)
# Check cache if enabled
if use_cache and self._cache.is_enabled:
fingerprint = self._cache.compute_fingerprint(contexts, query, model)
cached = await self._cache.get_assembled(fingerprint)
if cached:
logger.debug(f"Cache hit for context assembly: {fingerprint}")
return cached
# Run assembly pipeline
result = await self._pipeline.assemble(
contexts=contexts,
query=query,
model=model,
max_tokens=max_tokens,
custom_budget=custom_budget,
compress=compress,
format_output=format_output,
)
# Cache result if enabled
if use_cache and self._cache.is_enabled:
fingerprint = self._cache.compute_fingerprint(contexts, query, model)
await self._cache.set_assembled(fingerprint, result)
return result
async def _fetch_knowledge(
self,
project_id: str,
agent_id: str,
query: str,
limit: int = 10,
) -> list[KnowledgeContext]:
"""
Fetch relevant knowledge from Knowledge Base via MCP.
Args:
project_id: Project identifier
agent_id: Agent identifier
query: Search query
limit: Maximum results
Returns:
List of KnowledgeContext instances
"""
if not self._mcp:
return []
try:
result = await self._mcp.call_tool(
"knowledge-base",
"search_knowledge",
{
"project_id": project_id,
"agent_id": agent_id,
"query": query,
"search_type": "hybrid",
"limit": limit,
},
)
contexts = []
for chunk in result.data.get("results", []):
contexts.append(
KnowledgeContext(
content=chunk.get("content", ""),
source=chunk.get("source_path", "unknown"),
relevance_score=chunk.get("score", 0.0),
metadata={
"chunk_id": chunk.get("chunk_id"),
"document_id": chunk.get("document_id"),
},
)
)
logger.debug(f"Fetched {len(contexts)} knowledge chunks for query: {query}")
return contexts
except Exception as e:
logger.warning(f"Failed to fetch knowledge: {e}")
return []
def _convert_conversation(
self,
history: list[dict[str, str]],
) -> list[ConversationContext]:
"""
Convert conversation history to ConversationContext instances.
Args:
history: List of {"role": str, "content": str}
Returns:
List of ConversationContext instances
"""
contexts = []
for i, turn in enumerate(history):
role_str = turn.get("role", "user").lower()
role = MessageRole.ASSISTANT if role_str == "assistant" else MessageRole.USER
contexts.append(
ConversationContext(
content=turn.get("content", ""),
source=f"conversation:{i}",
role=role,
metadata={"role": role_str, "turn": i},
)
)
return contexts
def _convert_tool_results(
self,
results: list[dict[str, Any]],
) -> list[ToolContext]:
"""
Convert tool results to ToolContext instances.
Args:
results: List of tool result dictionaries
Returns:
List of ToolContext instances
"""
contexts = []
for result in results:
tool_name = result.get("tool_name", "unknown")
content = result.get("content", result.get("result", ""))
# Handle dict content
if isinstance(content, dict):
import json
content = json.dumps(content, indent=2)
contexts.append(
ToolContext(
content=str(content),
source=f"tool:{tool_name}",
metadata={
"tool_name": tool_name,
"status": result.get("status", "success"),
},
)
)
return contexts
async def get_budget_for_model(
self,
model: str,
max_tokens: int | None = None,
) -> TokenBudget:
"""
Get the token budget for a specific model.
Args:
model: Model name
max_tokens: Optional max tokens override
Returns:
TokenBudget instance
"""
if max_tokens:
return self._allocator.create_budget(max_tokens)
return self._allocator.create_budget_for_model(model)
async def count_tokens(
self,
content: str,
model: str | None = None,
) -> int:
"""
Count tokens in content.
Args:
content: Content to count
model: Model for model-specific tokenization
Returns:
Token count
"""
# Check cache first
cached = await self._cache.get_token_count(content, model)
if cached is not None:
return cached
count = await self._calculator.count_tokens(content, model)
# Cache the result
await self._cache.set_token_count(content, count, model)
return count
async def invalidate_cache(
self,
project_id: str | None = None,
pattern: str | None = None,
) -> int:
"""
Invalidate cache entries.
Args:
project_id: Invalidate all cache for a project
pattern: Custom pattern to match
Returns:
Number of entries invalidated
"""
if pattern:
return await self._cache.invalidate(pattern)
elif project_id:
return await self._cache.invalidate(f"*{project_id}*")
else:
return await self._cache.clear_all()
async def get_stats(self) -> dict[str, Any]:
"""
Get engine statistics.
Returns:
Dictionary with engine stats
"""
return {
"cache": await self._cache.get_stats(),
"settings": {
"compression_threshold": self._settings.compression_threshold,
"max_assembly_time_ms": self._settings.max_assembly_time_ms,
"cache_enabled": self._settings.cache_enabled,
},
}
# Convenience factory function
def create_context_engine(
mcp_manager: "MCPClientManager | None" = None,
redis: "Redis | None" = None,
settings: ContextSettings | None = None,
) -> ContextEngine:
"""
Create a context engine instance.
Args:
mcp_manager: MCP client manager
redis: Redis connection
settings: Context settings
Returns:
Configured ContextEngine instance
"""
return ContextEngine(
mcp_manager=mcp_manager,
redis=redis,
settings=settings,
)

View File

@@ -0,0 +1,458 @@
"""Tests for ContextEngine."""
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.services.context.config import ContextSettings
from app.services.context.engine import ContextEngine, create_context_engine
from app.services.context.types import (
AssembledContext,
ConversationContext,
KnowledgeContext,
MessageRole,
SystemContext,
TaskContext,
ToolContext,
)
class TestContextEngineCreation:
"""Tests for ContextEngine creation."""
def test_creation_minimal(self) -> None:
"""Test creating engine with minimal config."""
engine = ContextEngine()
assert engine._mcp is None
assert engine._settings is not None
assert engine._calculator is not None
assert engine._scorer is not None
assert engine._ranker is not None
assert engine._compressor is not None
assert engine._cache is not None
assert engine._pipeline is not None
def test_creation_with_settings(self) -> None:
"""Test creating engine with custom settings."""
settings = ContextSettings(
compression_threshold=0.7,
cache_enabled=False,
)
engine = ContextEngine(settings=settings)
assert engine._settings.compression_threshold == 0.7
assert engine._settings.cache_enabled is False
def test_creation_with_redis(self) -> None:
"""Test creating engine with Redis."""
mock_redis = MagicMock()
settings = ContextSettings(cache_enabled=True)
engine = ContextEngine(redis=mock_redis, settings=settings)
assert engine._cache.is_enabled
def test_set_mcp_manager(self) -> None:
"""Test setting MCP manager."""
engine = ContextEngine()
mock_mcp = MagicMock()
engine.set_mcp_manager(mock_mcp)
assert engine._mcp is mock_mcp
def test_set_redis(self) -> None:
"""Test setting Redis connection."""
engine = ContextEngine()
mock_redis = MagicMock()
engine.set_redis(mock_redis)
assert engine._cache._redis is mock_redis
class TestContextEngineHelpers:
"""Tests for ContextEngine helper methods."""
def test_convert_conversation(self) -> None:
"""Test converting conversation history."""
engine = ContextEngine()
history = [
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "How are you?"},
]
contexts = engine._convert_conversation(history)
assert len(contexts) == 3
assert all(isinstance(c, ConversationContext) for c in contexts)
assert contexts[0].role == MessageRole.USER
assert contexts[1].role == MessageRole.ASSISTANT
assert contexts[0].content == "Hello!"
assert contexts[0].metadata["turn"] == 0
def test_convert_tool_results(self) -> None:
"""Test converting tool results."""
engine = ContextEngine()
results = [
{"tool_name": "search", "content": "Result 1", "status": "success"},
{"tool_name": "read", "result": {"file": "test.txt"}, "status": "success"},
]
contexts = engine._convert_tool_results(results)
assert len(contexts) == 2
assert all(isinstance(c, ToolContext) for c in contexts)
assert contexts[0].content == "Result 1"
assert contexts[0].metadata["tool_name"] == "search"
# Dict content should be JSON serialized
assert "file" in contexts[1].content
assert "test.txt" in contexts[1].content
class TestContextEngineAssembly:
"""Tests for context assembly."""
@pytest.mark.asyncio
async def test_assemble_minimal(self) -> None:
"""Test assembling with minimal inputs."""
engine = ContextEngine()
result = await engine.assemble_context(
project_id="proj-123",
agent_id="agent-456",
query="test query",
model="claude-3-sonnet",
use_cache=False, # Disable cache for test
)
assert isinstance(result, AssembledContext)
assert result.context_count == 0 # No contexts provided
@pytest.mark.asyncio
async def test_assemble_with_system_prompt(self) -> None:
"""Test assembling with system prompt."""
engine = ContextEngine()
result = await engine.assemble_context(
project_id="proj-123",
agent_id="agent-456",
query="test query",
model="claude-3-sonnet",
system_prompt="You are a helpful assistant.",
use_cache=False,
)
assert result.context_count == 1
assert "helpful assistant" in result.content
@pytest.mark.asyncio
async def test_assemble_with_task(self) -> None:
"""Test assembling with task description."""
engine = ContextEngine()
result = await engine.assemble_context(
project_id="proj-123",
agent_id="agent-456",
query="implement feature",
model="claude-3-sonnet",
task_description="Implement user authentication",
use_cache=False,
)
assert result.context_count == 1
assert "authentication" in result.content
@pytest.mark.asyncio
async def test_assemble_with_conversation(self) -> None:
"""Test assembling with conversation history."""
engine = ContextEngine()
result = await engine.assemble_context(
project_id="proj-123",
agent_id="agent-456",
query="continue",
model="claude-3-sonnet",
conversation_history=[
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Hi!"},
],
use_cache=False,
)
assert result.context_count == 2
assert "Hello" in result.content
@pytest.mark.asyncio
async def test_assemble_with_tool_results(self) -> None:
"""Test assembling with tool results."""
engine = ContextEngine()
result = await engine.assemble_context(
project_id="proj-123",
agent_id="agent-456",
query="continue",
model="claude-3-sonnet",
tool_results=[
{"tool_name": "search", "content": "Found 5 results"},
],
use_cache=False,
)
assert result.context_count == 1
assert "Found 5 results" in result.content
@pytest.mark.asyncio
async def test_assemble_with_custom_contexts(self) -> None:
"""Test assembling with custom contexts."""
engine = ContextEngine()
custom = [
KnowledgeContext(
content="Custom knowledge.",
source="custom",
relevance_score=0.9,
)
]
result = await engine.assemble_context(
project_id="proj-123",
agent_id="agent-456",
query="test",
model="claude-3-sonnet",
custom_contexts=custom,
use_cache=False,
)
assert result.context_count == 1
assert "Custom knowledge" in result.content
@pytest.mark.asyncio
async def test_assemble_full_workflow(self) -> None:
"""Test full assembly workflow."""
engine = ContextEngine()
result = await engine.assemble_context(
project_id="proj-123",
agent_id="agent-456",
query="implement login",
model="claude-3-sonnet",
system_prompt="You are an expert Python developer.",
task_description="Implement user authentication.",
conversation_history=[
{"role": "user", "content": "Can you help me implement JWT auth?"},
],
tool_results=[
{"tool_name": "file_create", "content": "Created auth.py"},
],
use_cache=False,
)
assert result.context_count >= 4
assert result.total_tokens > 0
assert result.model == "claude-3-sonnet"
# Check for expected content
assert "expert Python developer" in result.content
assert "authentication" in result.content
class TestContextEngineKnowledge:
"""Tests for knowledge fetching."""
@pytest.mark.asyncio
async def test_fetch_knowledge_no_mcp(self) -> None:
"""Test fetching knowledge without MCP returns empty."""
engine = ContextEngine()
result = await engine._fetch_knowledge(
project_id="proj-123",
agent_id="agent-456",
query="test",
)
assert result == []
@pytest.mark.asyncio
async def test_fetch_knowledge_with_mcp(self) -> None:
"""Test fetching knowledge with MCP."""
mock_mcp = AsyncMock()
mock_mcp.call_tool.return_value.data = {
"results": [
{
"content": "Document content",
"source_path": "docs/api.md",
"score": 0.9,
"chunk_id": "chunk-1",
},
{
"content": "Another document",
"source_path": "docs/auth.md",
"score": 0.8,
},
]
}
engine = ContextEngine(mcp_manager=mock_mcp)
result = await engine._fetch_knowledge(
project_id="proj-123",
agent_id="agent-456",
query="authentication",
)
assert len(result) == 2
assert all(isinstance(c, KnowledgeContext) for c in result)
assert result[0].content == "Document content"
assert result[0].source == "docs/api.md"
assert result[0].relevance_score == 0.9
@pytest.mark.asyncio
async def test_fetch_knowledge_error_handling(self) -> None:
"""Test knowledge fetch error handling."""
mock_mcp = AsyncMock()
mock_mcp.call_tool.side_effect = Exception("MCP error")
engine = ContextEngine(mcp_manager=mock_mcp)
# Should not raise, returns empty
result = await engine._fetch_knowledge(
project_id="proj-123",
agent_id="agent-456",
query="test",
)
assert result == []
class TestContextEngineCaching:
"""Tests for caching behavior."""
@pytest.mark.asyncio
async def test_cache_disabled(self) -> None:
"""Test assembly with cache disabled."""
engine = ContextEngine()
result = await engine.assemble_context(
project_id="proj-123",
agent_id="agent-456",
query="test",
model="claude-3-sonnet",
system_prompt="Test prompt",
use_cache=False,
)
assert not result.cache_hit
@pytest.mark.asyncio
async def test_cache_hit(self) -> None:
"""Test cache hit."""
mock_redis = AsyncMock()
settings = ContextSettings(cache_enabled=True)
engine = ContextEngine(redis=mock_redis, settings=settings)
# First call - cache miss
mock_redis.get.return_value = None
result1 = await engine.assemble_context(
project_id="proj-123",
agent_id="agent-456",
query="test",
model="claude-3-sonnet",
system_prompt="Test prompt",
)
# Second call - mock cache hit
mock_redis.get.return_value = result1.to_json()
result2 = await engine.assemble_context(
project_id="proj-123",
agent_id="agent-456",
query="test",
model="claude-3-sonnet",
system_prompt="Test prompt",
)
assert result2.cache_hit
class TestContextEngineUtilities:
"""Tests for utility methods."""
@pytest.mark.asyncio
async def test_get_budget_for_model(self) -> None:
"""Test getting budget for model."""
engine = ContextEngine()
budget = await engine.get_budget_for_model("claude-3-sonnet")
assert budget.total > 0
assert budget.system > 0
assert budget.knowledge > 0
@pytest.mark.asyncio
async def test_get_budget_with_max_tokens(self) -> None:
"""Test getting budget with max tokens."""
engine = ContextEngine()
budget = await engine.get_budget_for_model("claude-3-sonnet", max_tokens=5000)
assert budget.total == 5000
@pytest.mark.asyncio
async def test_count_tokens(self) -> None:
"""Test token counting."""
engine = ContextEngine()
count = await engine.count_tokens("Hello world")
assert count > 0
@pytest.mark.asyncio
async def test_invalidate_cache(self) -> None:
"""Test cache invalidation."""
mock_redis = AsyncMock()
async def mock_scan_iter(match=None):
for key in ["ctx:1", "ctx:2"]:
yield key
mock_redis.scan_iter = mock_scan_iter
settings = ContextSettings(cache_enabled=True)
engine = ContextEngine(redis=mock_redis, settings=settings)
deleted = await engine.invalidate_cache(pattern="*test*")
assert deleted >= 0
@pytest.mark.asyncio
async def test_get_stats(self) -> None:
"""Test getting engine stats."""
engine = ContextEngine()
stats = await engine.get_stats()
assert "cache" in stats
assert "settings" in stats
assert "compression_threshold" in stats["settings"]
class TestCreateContextEngine:
"""Tests for factory function."""
def test_create_context_engine(self) -> None:
"""Test factory function."""
engine = create_context_engine()
assert isinstance(engine, ContextEngine)
def test_create_context_engine_with_settings(self) -> None:
"""Test factory with settings."""
settings = ContextSettings(cache_enabled=False)
engine = create_context_engine(settings=settings)
assert engine._settings.cache_enabled is False