forked from cardosofelipe/fast-next-template
feat(context): implement main ContextEngine with full integration (#85)
Phase 7 of Context Management Engine - Main Engine:
- Add ContextEngine as main orchestration class
- Integrate all components: calculator, scorer, ranker, compressor, cache
- Add high-level assemble_context() API with:
  - System prompt support
  - Task description support
  - Knowledge Base integration via MCP
  - Conversation history conversion
  - Tool results conversion
  - Custom contexts support
- Add helper methods:
  - get_budget_for_model()
  - count_tokens() with caching
  - invalidate_cache()
  - get_stats()
- Add create_context_engine() factory function

Tests: 26 new tests, 311 total context tests passing

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
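For reviewers, a minimal usage sketch of the API this commit adds (illustrative only, not part of the diff; the project/agent IDs and model name are placeholder values, everything else follows the new engine module and its tests):

import asyncio

from app.services.context.engine import create_context_engine


async def example() -> None:
    # No MCP manager or Redis wired in: knowledge fetching and caching are simply skipped.
    engine = create_context_engine()

    result = await engine.assemble_context(
        project_id="proj-123",
        agent_id="agent-456",
        query="implement user authentication",
        model="claude-3-sonnet",
        system_prompt="You are an expert developer.",
        task_description="Implement user authentication",
        use_cache=False,
    )
    print(result.content)
    print(f"Tokens: {result.total_tokens}")

    # New helper methods
    print(await engine.count_tokens("Hello world"))
    print(await engine.get_stats())


asyncio.run(example())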
@@ -88,6 +88,9 @@ from .adapters import (
 # Cache
 from .cache import ContextCache

+# Engine
+from .engine import ContextEngine, create_context_engine
+
 # Prioritization
 from .prioritization import (
     ContextRanker,
@@ -137,6 +140,9 @@ __all__ = [
     "TokenCalculator",
     # Cache
     "ContextCache",
+    # Engine
+    "ContextEngine",
+    "create_context_engine",
     # Compression
     "ContextCompressor",
     "TruncationResult",
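With these exports in place, callers can import the engine from the package root rather than the submodule (a usage note, assuming this hunk edits the context package __init__ at app/services/context):

from app.services.context import ContextEngine, create_context_engine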
backend/app/services/context/engine.py (new file, 470 lines)
@@ -0,0 +1,470 @@
"""
Context Management Engine.

Main orchestration layer for context assembly and optimization.
Provides a high-level API for assembling optimized context for LLM requests.
"""

import logging
from typing import TYPE_CHECKING, Any

from .adapters import get_adapter
from .assembly import ContextPipeline
from .budget import BudgetAllocator, TokenBudget, TokenCalculator
from .cache import ContextCache
from .compression import ContextCompressor
from .config import ContextSettings, get_context_settings
from .prioritization import ContextRanker
from .scoring import CompositeScorer
from .types import (
    AssembledContext,
    BaseContext,
    ConversationContext,
    KnowledgeContext,
    MessageRole,
    SystemContext,
    TaskContext,
    ToolContext,
)

if TYPE_CHECKING:
    from redis.asyncio import Redis

    from app.services.mcp.client_manager import MCPClientManager

logger = logging.getLogger(__name__)


class ContextEngine:
    """
    Main context management engine.

    Provides high-level API for context assembly and optimization.
    Integrates all components: scoring, ranking, compression, formatting, and caching.

    Usage:
        engine = ContextEngine(mcp_manager=mcp, redis=redis)

        # Assemble context for an LLM request
        result = await engine.assemble_context(
            project_id="proj-123",
            agent_id="agent-456",
            query="implement user authentication",
            model="claude-3-sonnet",
            system_prompt="You are an expert developer.",
            knowledge_query="authentication best practices",
        )

        # Use the assembled context
        print(result.content)
        print(f"Tokens: {result.total_tokens}")
    """

    def __init__(
        self,
        mcp_manager: "MCPClientManager | None" = None,
        redis: "Redis | None" = None,
        settings: ContextSettings | None = None,
    ) -> None:
        """
        Initialize the context engine.

        Args:
            mcp_manager: MCP client manager for LLM Gateway/Knowledge Base
            redis: Redis connection for caching
            settings: Context settings
        """
        self._mcp = mcp_manager
        self._settings = settings or get_context_settings()

        # Initialize components
        self._calculator = TokenCalculator(mcp_manager=mcp_manager)
        self._scorer = CompositeScorer(
            mcp_manager=mcp_manager, settings=self._settings
        )
        self._ranker = ContextRanker(
            scorer=self._scorer, calculator=self._calculator
        )
        self._compressor = ContextCompressor(calculator=self._calculator)
        self._allocator = BudgetAllocator(self._settings)
        self._cache = ContextCache(redis=redis, settings=self._settings)

        # Pipeline for assembly
        self._pipeline = ContextPipeline(
            mcp_manager=mcp_manager,
            settings=self._settings,
            calculator=self._calculator,
            scorer=self._scorer,
            ranker=self._ranker,
            compressor=self._compressor,
        )

    def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
        """
        Set MCP manager for all components.

        Args:
            mcp_manager: MCP client manager
        """
        self._mcp = mcp_manager
        self._calculator.set_mcp_manager(mcp_manager)
        self._scorer.set_mcp_manager(mcp_manager)
        self._pipeline.set_mcp_manager(mcp_manager)

    def set_redis(self, redis: "Redis") -> None:
        """
        Set Redis connection for caching.

        Args:
            redis: Redis connection
        """
        self._cache.set_redis(redis)

    async def assemble_context(
        self,
        project_id: str,
        agent_id: str,
        query: str,
        model: str,
        max_tokens: int | None = None,
        system_prompt: str | None = None,
        task_description: str | None = None,
        knowledge_query: str | None = None,
        knowledge_limit: int = 10,
        conversation_history: list[dict[str, str]] | None = None,
        tool_results: list[dict[str, Any]] | None = None,
        custom_contexts: list[BaseContext] | None = None,
        custom_budget: TokenBudget | None = None,
        compress: bool = True,
        format_output: bool = True,
        use_cache: bool = True,
    ) -> AssembledContext:
        """
        Assemble optimized context for an LLM request.

        This is the main entry point for context management.
        It gathers context from various sources, scores and ranks them,
        compresses if needed, and formats for the target model.

        Args:
            project_id: Project identifier
            agent_id: Agent identifier
            query: User's query or current request
            model: Target model name
            max_tokens: Maximum context tokens (uses model default if None)
            system_prompt: System prompt/instructions
            task_description: Current task description
            knowledge_query: Query for knowledge base search
            knowledge_limit: Max number of knowledge results
            conversation_history: List of {"role": str, "content": str}
            tool_results: List of tool results to include
            custom_contexts: Additional custom contexts
            custom_budget: Custom token budget
            compress: Whether to apply compression
            format_output: Whether to format for the model
            use_cache: Whether to use caching

        Returns:
            AssembledContext with optimized content

        Raises:
            AssemblyTimeoutError: If assembly exceeds timeout
            BudgetExceededError: If context exceeds budget
        """
        # Gather all contexts
        contexts: list[BaseContext] = []

        # 1. System context
        if system_prompt:
            contexts.append(
                SystemContext(
                    content=system_prompt,
                    source="system_prompt",
                )
            )

        # 2. Task context
        if task_description:
            contexts.append(
                TaskContext(
                    content=task_description,
                    source=f"task:{project_id}:{agent_id}",
                )
            )

        # 3. Knowledge context from Knowledge Base
        if knowledge_query and self._mcp:
            knowledge_contexts = await self._fetch_knowledge(
                project_id=project_id,
                agent_id=agent_id,
                query=knowledge_query,
                limit=knowledge_limit,
            )
            contexts.extend(knowledge_contexts)

        # 4. Conversation history
        if conversation_history:
            contexts.extend(self._convert_conversation(conversation_history))

        # 5. Tool results
        if tool_results:
            contexts.extend(self._convert_tool_results(tool_results))

        # 6. Custom contexts
        if custom_contexts:
            contexts.extend(custom_contexts)

        # Check cache if enabled
        if use_cache and self._cache.is_enabled:
            fingerprint = self._cache.compute_fingerprint(contexts, query, model)
            cached = await self._cache.get_assembled(fingerprint)
            if cached:
                logger.debug(f"Cache hit for context assembly: {fingerprint}")
                return cached

        # Run assembly pipeline
        result = await self._pipeline.assemble(
            contexts=contexts,
            query=query,
            model=model,
            max_tokens=max_tokens,
            custom_budget=custom_budget,
            compress=compress,
            format_output=format_output,
        )

        # Cache result if enabled
        if use_cache and self._cache.is_enabled:
            fingerprint = self._cache.compute_fingerprint(contexts, query, model)
            await self._cache.set_assembled(fingerprint, result)

        return result

    async def _fetch_knowledge(
        self,
        project_id: str,
        agent_id: str,
        query: str,
        limit: int = 10,
    ) -> list[KnowledgeContext]:
        """
        Fetch relevant knowledge from Knowledge Base via MCP.

        Args:
            project_id: Project identifier
            agent_id: Agent identifier
            query: Search query
            limit: Maximum results

        Returns:
            List of KnowledgeContext instances
        """
        if not self._mcp:
            return []

        try:
            result = await self._mcp.call_tool(
                "knowledge-base",
                "search_knowledge",
                {
                    "project_id": project_id,
                    "agent_id": agent_id,
                    "query": query,
                    "search_type": "hybrid",
                    "limit": limit,
                },
            )

            contexts = []
            for chunk in result.data.get("results", []):
                contexts.append(
                    KnowledgeContext(
                        content=chunk.get("content", ""),
                        source=chunk.get("source_path", "unknown"),
                        relevance_score=chunk.get("score", 0.0),
                        metadata={
                            "chunk_id": chunk.get("chunk_id"),
                            "document_id": chunk.get("document_id"),
                        },
                    )
                )

            logger.debug(f"Fetched {len(contexts)} knowledge chunks for query: {query}")
            return contexts

        except Exception as e:
            logger.warning(f"Failed to fetch knowledge: {e}")
            return []

    def _convert_conversation(
        self,
        history: list[dict[str, str]],
    ) -> list[ConversationContext]:
        """
        Convert conversation history to ConversationContext instances.

        Args:
            history: List of {"role": str, "content": str}

        Returns:
            List of ConversationContext instances
        """
        contexts = []
        for i, turn in enumerate(history):
            role_str = turn.get("role", "user").lower()
            role = MessageRole.ASSISTANT if role_str == "assistant" else MessageRole.USER

            contexts.append(
                ConversationContext(
                    content=turn.get("content", ""),
                    source=f"conversation:{i}",
                    role=role,
                    metadata={"role": role_str, "turn": i},
                )
            )

        return contexts

    def _convert_tool_results(
        self,
        results: list[dict[str, Any]],
    ) -> list[ToolContext]:
        """
        Convert tool results to ToolContext instances.

        Args:
            results: List of tool result dictionaries

        Returns:
            List of ToolContext instances
        """
        contexts = []
        for result in results:
            tool_name = result.get("tool_name", "unknown")
            content = result.get("content", result.get("result", ""))

            # Handle dict content
            if isinstance(content, dict):
                import json
                content = json.dumps(content, indent=2)

            contexts.append(
                ToolContext(
                    content=str(content),
                    source=f"tool:{tool_name}",
                    metadata={
                        "tool_name": tool_name,
                        "status": result.get("status", "success"),
                    },
                )
            )

        return contexts

    async def get_budget_for_model(
        self,
        model: str,
        max_tokens: int | None = None,
    ) -> TokenBudget:
        """
        Get the token budget for a specific model.

        Args:
            model: Model name
            max_tokens: Optional max tokens override

        Returns:
            TokenBudget instance
        """
        if max_tokens:
            return self._allocator.create_budget(max_tokens)
        return self._allocator.create_budget_for_model(model)

    async def count_tokens(
        self,
        content: str,
        model: str | None = None,
    ) -> int:
        """
        Count tokens in content.

        Args:
            content: Content to count
            model: Model for model-specific tokenization

        Returns:
            Token count
        """
        # Check cache first
        cached = await self._cache.get_token_count(content, model)
        if cached is not None:
            return cached

        count = await self._calculator.count_tokens(content, model)

        # Cache the result
        await self._cache.set_token_count(content, count, model)

        return count

    async def invalidate_cache(
        self,
        project_id: str | None = None,
        pattern: str | None = None,
    ) -> int:
        """
        Invalidate cache entries.

        Args:
            project_id: Invalidate all cache for a project
            pattern: Custom pattern to match

        Returns:
            Number of entries invalidated
        """
        if pattern:
            return await self._cache.invalidate(pattern)
        elif project_id:
            return await self._cache.invalidate(f"*{project_id}*")
        else:
            return await self._cache.clear_all()

    async def get_stats(self) -> dict[str, Any]:
        """
        Get engine statistics.

        Returns:
            Dictionary with engine stats
        """
        return {
            "cache": await self._cache.get_stats(),
            "settings": {
                "compression_threshold": self._settings.compression_threshold,
                "max_assembly_time_ms": self._settings.max_assembly_time_ms,
                "cache_enabled": self._settings.cache_enabled,
            },
        }


# Convenience factory function
def create_context_engine(
    mcp_manager: "MCPClientManager | None" = None,
    redis: "Redis | None" = None,
    settings: ContextSettings | None = None,
) -> ContextEngine:
    """
    Create a context engine instance.

    Args:
        mcp_manager: MCP client manager
        redis: Redis connection
        settings: Context settings

    Returns:
        Configured ContextEngine instance
    """
    return ContextEngine(
        mcp_manager=mcp_manager,
        redis=redis,
        settings=settings,
    )
backend/tests/services/context/test_engine.py (new file, 458 lines)
@@ -0,0 +1,458 @@
"""Tests for ContextEngine."""

from unittest.mock import AsyncMock, MagicMock

import pytest

from app.services.context.config import ContextSettings
from app.services.context.engine import ContextEngine, create_context_engine
from app.services.context.types import (
    AssembledContext,
    ConversationContext,
    KnowledgeContext,
    MessageRole,
    SystemContext,
    TaskContext,
    ToolContext,
)


class TestContextEngineCreation:
    """Tests for ContextEngine creation."""

    def test_creation_minimal(self) -> None:
        """Test creating engine with minimal config."""
        engine = ContextEngine()

        assert engine._mcp is None
        assert engine._settings is not None
        assert engine._calculator is not None
        assert engine._scorer is not None
        assert engine._ranker is not None
        assert engine._compressor is not None
        assert engine._cache is not None
        assert engine._pipeline is not None

    def test_creation_with_settings(self) -> None:
        """Test creating engine with custom settings."""
        settings = ContextSettings(
            compression_threshold=0.7,
            cache_enabled=False,
        )
        engine = ContextEngine(settings=settings)

        assert engine._settings.compression_threshold == 0.7
        assert engine._settings.cache_enabled is False

    def test_creation_with_redis(self) -> None:
        """Test creating engine with Redis."""
        mock_redis = MagicMock()
        settings = ContextSettings(cache_enabled=True)
        engine = ContextEngine(redis=mock_redis, settings=settings)

        assert engine._cache.is_enabled

    def test_set_mcp_manager(self) -> None:
        """Test setting MCP manager."""
        engine = ContextEngine()
        mock_mcp = MagicMock()

        engine.set_mcp_manager(mock_mcp)

        assert engine._mcp is mock_mcp

    def test_set_redis(self) -> None:
        """Test setting Redis connection."""
        engine = ContextEngine()
        mock_redis = MagicMock()

        engine.set_redis(mock_redis)

        assert engine._cache._redis is mock_redis


class TestContextEngineHelpers:
    """Tests for ContextEngine helper methods."""

    def test_convert_conversation(self) -> None:
        """Test converting conversation history."""
        engine = ContextEngine()

        history = [
            {"role": "user", "content": "Hello!"},
            {"role": "assistant", "content": "Hi there!"},
            {"role": "user", "content": "How are you?"},
        ]

        contexts = engine._convert_conversation(history)

        assert len(contexts) == 3
        assert all(isinstance(c, ConversationContext) for c in contexts)
        assert contexts[0].role == MessageRole.USER
        assert contexts[1].role == MessageRole.ASSISTANT
        assert contexts[0].content == "Hello!"
        assert contexts[0].metadata["turn"] == 0

    def test_convert_tool_results(self) -> None:
        """Test converting tool results."""
        engine = ContextEngine()

        results = [
            {"tool_name": "search", "content": "Result 1", "status": "success"},
            {"tool_name": "read", "result": {"file": "test.txt"}, "status": "success"},
        ]

        contexts = engine._convert_tool_results(results)

        assert len(contexts) == 2
        assert all(isinstance(c, ToolContext) for c in contexts)
        assert contexts[0].content == "Result 1"
        assert contexts[0].metadata["tool_name"] == "search"
        # Dict content should be JSON serialized
        assert "file" in contexts[1].content
        assert "test.txt" in contexts[1].content


class TestContextEngineAssembly:
    """Tests for context assembly."""

    @pytest.mark.asyncio
    async def test_assemble_minimal(self) -> None:
        """Test assembling with minimal inputs."""
        engine = ContextEngine()

        result = await engine.assemble_context(
            project_id="proj-123",
            agent_id="agent-456",
            query="test query",
            model="claude-3-sonnet",
            use_cache=False,  # Disable cache for test
        )

        assert isinstance(result, AssembledContext)
        assert result.context_count == 0  # No contexts provided

    @pytest.mark.asyncio
    async def test_assemble_with_system_prompt(self) -> None:
        """Test assembling with system prompt."""
        engine = ContextEngine()

        result = await engine.assemble_context(
            project_id="proj-123",
            agent_id="agent-456",
            query="test query",
            model="claude-3-sonnet",
            system_prompt="You are a helpful assistant.",
            use_cache=False,
        )

        assert result.context_count == 1
        assert "helpful assistant" in result.content

    @pytest.mark.asyncio
    async def test_assemble_with_task(self) -> None:
        """Test assembling with task description."""
        engine = ContextEngine()

        result = await engine.assemble_context(
            project_id="proj-123",
            agent_id="agent-456",
            query="implement feature",
            model="claude-3-sonnet",
            task_description="Implement user authentication",
            use_cache=False,
        )

        assert result.context_count == 1
        assert "authentication" in result.content

    @pytest.mark.asyncio
    async def test_assemble_with_conversation(self) -> None:
        """Test assembling with conversation history."""
        engine = ContextEngine()

        result = await engine.assemble_context(
            project_id="proj-123",
            agent_id="agent-456",
            query="continue",
            model="claude-3-sonnet",
            conversation_history=[
                {"role": "user", "content": "Hello!"},
                {"role": "assistant", "content": "Hi!"},
            ],
            use_cache=False,
        )

        assert result.context_count == 2
        assert "Hello" in result.content

    @pytest.mark.asyncio
    async def test_assemble_with_tool_results(self) -> None:
        """Test assembling with tool results."""
        engine = ContextEngine()

        result = await engine.assemble_context(
            project_id="proj-123",
            agent_id="agent-456",
            query="continue",
            model="claude-3-sonnet",
            tool_results=[
                {"tool_name": "search", "content": "Found 5 results"},
            ],
            use_cache=False,
        )

        assert result.context_count == 1
        assert "Found 5 results" in result.content

    @pytest.mark.asyncio
    async def test_assemble_with_custom_contexts(self) -> None:
        """Test assembling with custom contexts."""
        engine = ContextEngine()

        custom = [
            KnowledgeContext(
                content="Custom knowledge.",
                source="custom",
                relevance_score=0.9,
            )
        ]

        result = await engine.assemble_context(
            project_id="proj-123",
            agent_id="agent-456",
            query="test",
            model="claude-3-sonnet",
            custom_contexts=custom,
            use_cache=False,
        )

        assert result.context_count == 1
        assert "Custom knowledge" in result.content

    @pytest.mark.asyncio
    async def test_assemble_full_workflow(self) -> None:
        """Test full assembly workflow."""
        engine = ContextEngine()

        result = await engine.assemble_context(
            project_id="proj-123",
            agent_id="agent-456",
            query="implement login",
            model="claude-3-sonnet",
            system_prompt="You are an expert Python developer.",
            task_description="Implement user authentication.",
            conversation_history=[
                {"role": "user", "content": "Can you help me implement JWT auth?"},
            ],
            tool_results=[
                {"tool_name": "file_create", "content": "Created auth.py"},
            ],
            use_cache=False,
        )

        assert result.context_count >= 4
        assert result.total_tokens > 0
        assert result.model == "claude-3-sonnet"

        # Check for expected content
        assert "expert Python developer" in result.content
        assert "authentication" in result.content


class TestContextEngineKnowledge:
    """Tests for knowledge fetching."""

    @pytest.mark.asyncio
    async def test_fetch_knowledge_no_mcp(self) -> None:
        """Test fetching knowledge without MCP returns empty."""
        engine = ContextEngine()

        result = await engine._fetch_knowledge(
            project_id="proj-123",
            agent_id="agent-456",
            query="test",
        )

        assert result == []

    @pytest.mark.asyncio
    async def test_fetch_knowledge_with_mcp(self) -> None:
        """Test fetching knowledge with MCP."""
        mock_mcp = AsyncMock()
        mock_mcp.call_tool.return_value.data = {
            "results": [
                {
                    "content": "Document content",
                    "source_path": "docs/api.md",
                    "score": 0.9,
                    "chunk_id": "chunk-1",
                },
                {
                    "content": "Another document",
                    "source_path": "docs/auth.md",
                    "score": 0.8,
                },
            ]
        }

        engine = ContextEngine(mcp_manager=mock_mcp)

        result = await engine._fetch_knowledge(
            project_id="proj-123",
            agent_id="agent-456",
            query="authentication",
        )

        assert len(result) == 2
        assert all(isinstance(c, KnowledgeContext) for c in result)
        assert result[0].content == "Document content"
        assert result[0].source == "docs/api.md"
        assert result[0].relevance_score == 0.9

    @pytest.mark.asyncio
    async def test_fetch_knowledge_error_handling(self) -> None:
        """Test knowledge fetch error handling."""
        mock_mcp = AsyncMock()
        mock_mcp.call_tool.side_effect = Exception("MCP error")

        engine = ContextEngine(mcp_manager=mock_mcp)

        # Should not raise, returns empty
        result = await engine._fetch_knowledge(
            project_id="proj-123",
            agent_id="agent-456",
            query="test",
        )

        assert result == []


class TestContextEngineCaching:
    """Tests for caching behavior."""

    @pytest.mark.asyncio
    async def test_cache_disabled(self) -> None:
        """Test assembly with cache disabled."""
        engine = ContextEngine()

        result = await engine.assemble_context(
            project_id="proj-123",
            agent_id="agent-456",
            query="test",
            model="claude-3-sonnet",
            system_prompt="Test prompt",
            use_cache=False,
        )

        assert not result.cache_hit

    @pytest.mark.asyncio
    async def test_cache_hit(self) -> None:
        """Test cache hit."""
        mock_redis = AsyncMock()
        settings = ContextSettings(cache_enabled=True)
        engine = ContextEngine(redis=mock_redis, settings=settings)

        # First call - cache miss
        mock_redis.get.return_value = None

        result1 = await engine.assemble_context(
            project_id="proj-123",
            agent_id="agent-456",
            query="test",
            model="claude-3-sonnet",
            system_prompt="Test prompt",
        )

        # Second call - mock cache hit
        mock_redis.get.return_value = result1.to_json()

        result2 = await engine.assemble_context(
            project_id="proj-123",
            agent_id="agent-456",
            query="test",
            model="claude-3-sonnet",
            system_prompt="Test prompt",
        )

        assert result2.cache_hit


class TestContextEngineUtilities:
    """Tests for utility methods."""

    @pytest.mark.asyncio
    async def test_get_budget_for_model(self) -> None:
        """Test getting budget for model."""
        engine = ContextEngine()

        budget = await engine.get_budget_for_model("claude-3-sonnet")

        assert budget.total > 0
        assert budget.system > 0
        assert budget.knowledge > 0

    @pytest.mark.asyncio
    async def test_get_budget_with_max_tokens(self) -> None:
        """Test getting budget with max tokens."""
        engine = ContextEngine()

        budget = await engine.get_budget_for_model("claude-3-sonnet", max_tokens=5000)

        assert budget.total == 5000

    @pytest.mark.asyncio
    async def test_count_tokens(self) -> None:
        """Test token counting."""
        engine = ContextEngine()

        count = await engine.count_tokens("Hello world")

        assert count > 0

    @pytest.mark.asyncio
    async def test_invalidate_cache(self) -> None:
        """Test cache invalidation."""
        mock_redis = AsyncMock()

        async def mock_scan_iter(match=None):
            for key in ["ctx:1", "ctx:2"]:
                yield key

        mock_redis.scan_iter = mock_scan_iter

        settings = ContextSettings(cache_enabled=True)
        engine = ContextEngine(redis=mock_redis, settings=settings)

        deleted = await engine.invalidate_cache(pattern="*test*")

        assert deleted >= 0

    @pytest.mark.asyncio
    async def test_get_stats(self) -> None:
        """Test getting engine stats."""
        engine = ContextEngine()

        stats = await engine.get_stats()

        assert "cache" in stats
        assert "settings" in stats
        assert "compression_threshold" in stats["settings"]


class TestCreateContextEngine:
    """Tests for factory function."""

    def test_create_context_engine(self) -> None:
        """Test factory function."""
        engine = create_context_engine()

        assert isinstance(engine, ContextEngine)

    def test_create_context_engine_with_settings(self) -> None:
        """Test factory with settings."""
        settings = ContextSettings(cache_enabled=False)
        engine = create_context_engine(settings=settings)

        assert engine._settings.cache_enabled is False