forked from cardosofelipe/fast-next-template
feat(context): implement main ContextEngine with full integration (#85)
Phase 7 of Context Management Engine - Main Engine:
- Add ContextEngine as main orchestration class
- Integrate all components: calculator, scorer, ranker, compressor, cache
- Add high-level assemble_context() API with:
  - System prompt support
  - Task description support
  - Knowledge Base integration via MCP
  - Conversation history conversion
  - Tool results conversion
  - Custom contexts support
- Add helper methods:
  - get_budget_for_model()
  - count_tokens() with caching
  - invalidate_cache()
  - get_stats()
- Add create_context_engine() factory function

Tests: 26 new tests, 311 total context tests passing

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -88,6 +88,9 @@ from .adapters import (
|
||||
# Cache
|
||||
from .cache import ContextCache
|
||||
|
||||
# Engine
|
||||
from .engine import ContextEngine, create_context_engine
|
||||
|
||||
# Prioritization
|
||||
from .prioritization import (
|
||||
ContextRanker,
|
||||
@@ -137,6 +140,9 @@ __all__ = [
|
||||
"TokenCalculator",
|
||||
# Cache
|
||||
"ContextCache",
|
||||
# Engine
|
||||
"ContextEngine",
|
||||
"create_context_engine",
|
||||
# Compression
|
||||
"ContextCompressor",
|
||||
"TruncationResult",
|
||||
|
||||
470
backend/app/services/context/engine.py
Normal file
470
backend/app/services/context/engine.py
Normal file
@@ -0,0 +1,470 @@
|
||||
"""
|
||||
Context Management Engine.
|
||||
|
||||
Main orchestration layer for context assembly and optimization.
|
||||
Provides a high-level API for assembling optimized context for LLM requests.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from .adapters import get_adapter
|
||||
from .assembly import ContextPipeline
|
||||
from .budget import BudgetAllocator, TokenBudget, TokenCalculator
|
||||
from .cache import ContextCache
|
||||
from .compression import ContextCompressor
|
||||
from .config import ContextSettings, get_context_settings
|
||||
from .prioritization import ContextRanker
|
||||
from .scoring import CompositeScorer
|
||||
from .types import (
|
||||
AssembledContext,
|
||||
BaseContext,
|
||||
ConversationContext,
|
||||
KnowledgeContext,
|
||||
MessageRole,
|
||||
SystemContext,
|
||||
TaskContext,
|
||||
ToolContext,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from app.services.mcp.client_manager import MCPClientManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContextEngine:
    """
    Main context management engine.

    Provides high-level API for context assembly and optimization.
    Integrates all components: scoring, ranking, compression, formatting,
    and caching.

    Usage:
        engine = ContextEngine(mcp_manager=mcp, redis=redis)

        # Assemble context for an LLM request
        result = await engine.assemble_context(
            project_id="proj-123",
            agent_id="agent-456",
            query="implement user authentication",
            model="claude-3-sonnet",
            system_prompt="You are an expert developer.",
            knowledge_query="authentication best practices",
        )

        # Use the assembled context
        print(result.content)
        print(f"Tokens: {result.total_tokens}")
    """

    def __init__(
        self,
        mcp_manager: "MCPClientManager | None" = None,
        redis: "Redis | None" = None,
        settings: ContextSettings | None = None,
    ) -> None:
        """
        Initialize the context engine.

        Args:
            mcp_manager: MCP client manager for LLM Gateway/Knowledge Base
            redis: Redis connection for caching
            settings: Context settings (falls back to get_context_settings())
        """
        self._mcp = mcp_manager
        self._settings = settings or get_context_settings()

        # Initialize components. The calculator and scorer are shared with
        # the ranker, compressor and pipeline so they all use the same
        # instances.
        self._calculator = TokenCalculator(mcp_manager=mcp_manager)
        self._scorer = CompositeScorer(
            mcp_manager=mcp_manager, settings=self._settings
        )
        self._ranker = ContextRanker(
            scorer=self._scorer, calculator=self._calculator
        )
        self._compressor = ContextCompressor(calculator=self._calculator)
        self._allocator = BudgetAllocator(self._settings)
        self._cache = ContextCache(redis=redis, settings=self._settings)

        # Pipeline for assembly
        self._pipeline = ContextPipeline(
            mcp_manager=mcp_manager,
            settings=self._settings,
            calculator=self._calculator,
            scorer=self._scorer,
            ranker=self._ranker,
            compressor=self._compressor,
        )

    def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
        """
        Set MCP manager for all components.

        The manager is forwarded to the calculator, scorer and pipeline.

        Args:
            mcp_manager: MCP client manager
        """
        self._mcp = mcp_manager
        self._calculator.set_mcp_manager(mcp_manager)
        self._scorer.set_mcp_manager(mcp_manager)
        self._pipeline.set_mcp_manager(mcp_manager)

    def set_redis(self, redis: "Redis") -> None:
        """
        Set Redis connection for caching.

        Args:
            redis: Redis connection
        """
        self._cache.set_redis(redis)

    async def assemble_context(
        self,
        project_id: str,
        agent_id: str,
        query: str,
        model: str,
        max_tokens: int | None = None,
        system_prompt: str | None = None,
        task_description: str | None = None,
        knowledge_query: str | None = None,
        knowledge_limit: int = 10,
        conversation_history: list[dict[str, str]] | None = None,
        tool_results: list[dict[str, Any]] | None = None,
        custom_contexts: list[BaseContext] | None = None,
        custom_budget: TokenBudget | None = None,
        compress: bool = True,
        format_output: bool = True,
        use_cache: bool = True,
    ) -> AssembledContext:
        """
        Assemble optimized context for an LLM request.

        This is the main entry point for context management.
        It gathers context from various sources, scores and ranks them,
        compresses if needed, and formats for the target model.

        Args:
            project_id: Project identifier
            agent_id: Agent identifier
            query: User's query or current request
            model: Target model name
            max_tokens: Maximum context tokens (uses model default if None)
            system_prompt: System prompt/instructions
            task_description: Current task description
            knowledge_query: Query for knowledge base search (only used
                when an MCP manager is configured)
            knowledge_limit: Max number of knowledge results
            conversation_history: List of {"role": str, "content": str}
            tool_results: List of tool results to include
            custom_contexts: Additional custom contexts
            custom_budget: Custom token budget
            compress: Whether to apply compression
            format_output: Whether to format for the model
            use_cache: Whether to use caching

        Returns:
            AssembledContext with optimized content

        Raises:
            AssemblyTimeoutError: If assembly exceeds timeout
            BudgetExceededError: If context exceeds budget
        """
        # Gather all contexts
        contexts: list[BaseContext] = []

        # 1. System context
        if system_prompt:
            contexts.append(
                SystemContext(
                    content=system_prompt,
                    source="system_prompt",
                )
            )

        # 2. Task context
        if task_description:
            contexts.append(
                TaskContext(
                    content=task_description,
                    source=f"task:{project_id}:{agent_id}",
                )
            )

        # 3. Knowledge context from Knowledge Base
        if knowledge_query and self._mcp:
            knowledge_contexts = await self._fetch_knowledge(
                project_id=project_id,
                agent_id=agent_id,
                query=knowledge_query,
                limit=knowledge_limit,
            )
            contexts.extend(knowledge_contexts)

        # 4. Conversation history
        if conversation_history:
            contexts.extend(self._convert_conversation(conversation_history))

        # 5. Tool results
        if tool_results:
            contexts.extend(self._convert_tool_results(tool_results))

        # 6. Custom contexts
        if custom_contexts:
            contexts.extend(custom_contexts)

        # Compute the cache key exactly once, before the pipeline runs, so
        # the lookup key and the store key are guaranteed to be identical
        # even if the pipeline touches the context objects while assembling.
        # NOTE(review): the key is derived from (contexts, query, model)
        # only; max_tokens / custom_budget / compress / format_output are
        # not passed to compute_fingerprint here - confirm whether results
        # assembled with different options may share a cache entry.
        cache_active = use_cache and self._cache.is_enabled
        fingerprint = None
        if cache_active:
            fingerprint = self._cache.compute_fingerprint(contexts, query, model)
            cached = await self._cache.get_assembled(fingerprint)
            if cached is not None:
                logger.debug("Cache hit for context assembly: %s", fingerprint)
                return cached

        # Run assembly pipeline
        result = await self._pipeline.assemble(
            contexts=contexts,
            query=query,
            model=model,
            max_tokens=max_tokens,
            custom_budget=custom_budget,
            compress=compress,
            format_output=format_output,
        )

        # Cache result if enabled
        if cache_active:
            await self._cache.set_assembled(fingerprint, result)

        return result

    async def _fetch_knowledge(
        self,
        project_id: str,
        agent_id: str,
        query: str,
        limit: int = 10,
    ) -> list[KnowledgeContext]:
        """
        Fetch relevant knowledge from Knowledge Base via MCP.

        This is best-effort: any failure while calling the tool or building
        the contexts is logged as a warning and an empty list is returned.

        Args:
            project_id: Project identifier
            agent_id: Agent identifier
            query: Search query
            limit: Maximum results

        Returns:
            List of KnowledgeContext instances (empty when no MCP manager
            is configured or when the lookup fails)
        """
        if not self._mcp:
            return []

        try:
            result = await self._mcp.call_tool(
                "knowledge-base",
                "search_knowledge",
                {
                    "project_id": project_id,
                    "agent_id": agent_id,
                    "query": query,
                    "search_type": "hybrid",
                    "limit": limit,
                },
            )

            contexts = [
                KnowledgeContext(
                    content=chunk.get("content", ""),
                    source=chunk.get("source_path", "unknown"),
                    relevance_score=chunk.get("score", 0.0),
                    metadata={
                        "chunk_id": chunk.get("chunk_id"),
                        "document_id": chunk.get("document_id"),
                    },
                )
                for chunk in result.data.get("results", [])
            ]

            logger.debug(
                "Fetched %s knowledge chunks for query: %s", len(contexts), query
            )
            return contexts

        except Exception as e:
            # Deliberate best-effort boundary: knowledge is optional input,
            # so a failing Knowledge Base must not abort context assembly.
            logger.warning("Failed to fetch knowledge: %s", e)
            return []

    def _convert_conversation(
        self,
        history: list[dict[str, str]],
    ) -> list[ConversationContext]:
        """
        Convert conversation history to ConversationContext instances.

        Only the "assistant" role (case-insensitive) maps to
        MessageRole.ASSISTANT; every other role, including a missing or
        None role, is treated as MessageRole.USER.

        Args:
            history: List of {"role": str, "content": str}

        Returns:
            List of ConversationContext instances, one per turn, in order
        """
        contexts: list[ConversationContext] = []
        for i, turn in enumerate(history):
            # `or "user"` also covers an explicit None role, which would
            # otherwise fail on .lower().
            role_str = str(turn.get("role") or "user").lower()
            role = (
                MessageRole.ASSISTANT
                if role_str == "assistant"
                else MessageRole.USER
            )

            contexts.append(
                ConversationContext(
                    content=turn.get("content", ""),
                    source=f"conversation:{i}",
                    role=role,
                    metadata={"role": role_str, "turn": i},
                )
            )

        return contexts

    def _convert_tool_results(
        self,
        results: list[dict[str, Any]],
    ) -> list[ToolContext]:
        """
        Convert tool results to ToolContext instances.

        The payload is read from the "content" key, falling back to
        "result". Dict and list payloads are rendered as indented JSON;
        anything else is converted with str().

        Args:
            results: List of tool result dictionaries

        Returns:
            List of ToolContext instances
        """
        # Local import (the module does not import json at file level),
        # hoisted out of the loop so it runs once per call.
        import json

        contexts: list[ToolContext] = []
        for result in results:
            tool_name = result.get("tool_name", "unknown")
            content = result.get("content", result.get("result", ""))

            # Structured payloads are serialized as JSON rather than a
            # Python repr.
            if isinstance(content, (dict, list)):
                content = json.dumps(content, indent=2)

            contexts.append(
                ToolContext(
                    content=str(content),
                    source=f"tool:{tool_name}",
                    metadata={
                        "tool_name": tool_name,
                        "status": result.get("status", "success"),
                    },
                )
            )

        return contexts

    async def get_budget_for_model(
        self,
        model: str,
        max_tokens: int | None = None,
    ) -> TokenBudget:
        """
        Get the token budget for a specific model.

        Args:
            model: Model name
            max_tokens: Optional max tokens override; a falsy value
                (None or 0) selects the model-based budget

        Returns:
            TokenBudget instance
        """
        if max_tokens:
            return self._allocator.create_budget(max_tokens)
        return self._allocator.create_budget_for_model(model)

    async def count_tokens(
        self,
        content: str,
        model: str | None = None,
    ) -> int:
        """
        Count tokens in content.

        Args:
            content: Content to count
            model: Model for model-specific tokenization

        Returns:
            Token count
        """
        # Check cache first; `is not None` so a cached count of 0 is a hit.
        cached = await self._cache.get_token_count(content, model)
        if cached is not None:
            return cached

        count = await self._calculator.count_tokens(content, model)

        # Cache the result
        await self._cache.set_token_count(content, count, model)

        return count

    async def invalidate_cache(
        self,
        project_id: str | None = None,
        pattern: str | None = None,
    ) -> int:
        """
        Invalidate cache entries.

        `pattern` takes precedence over `project_id`; with neither given
        the whole cache is cleared.

        Args:
            project_id: Invalidate all cache for a project
            pattern: Custom pattern to match

        Returns:
            Number of entries invalidated
        """
        if pattern:
            return await self._cache.invalidate(pattern)
        elif project_id:
            # NOTE(review): this assumes cache keys contain the project id
            # verbatim - confirm against the key format used by the cache.
            return await self._cache.invalidate(f"*{project_id}*")
        else:
            return await self._cache.clear_all()

    async def get_stats(self) -> dict[str, Any]:
        """
        Get engine statistics.

        Returns:
            Dictionary with the cache stats under "cache" and selected
            settings values under "settings"
        """
        return {
            "cache": await self._cache.get_stats(),
            "settings": {
                "compression_threshold": self._settings.compression_threshold,
                "max_assembly_time_ms": self._settings.max_assembly_time_ms,
                "cache_enabled": self._settings.cache_enabled,
            },
        }
|
||||
|
||||
|
||||
# Convenience factory function
|
||||
def create_context_engine(
    mcp_manager: "MCPClientManager | None" = None,
    redis: "Redis | None" = None,
    settings: ContextSettings | None = None,
) -> ContextEngine:
    """
    Build a ContextEngine from the given collaborators.

    All three arguments are passed to the ContextEngine constructor
    unchanged.

    Args:
        mcp_manager: MCP client manager
        redis: Redis connection
        settings: Context settings

    Returns:
        Configured ContextEngine instance
    """
    engine_kwargs = {
        "mcp_manager": mcp_manager,
        "redis": redis,
        "settings": settings,
    }
    return ContextEngine(**engine_kwargs)
|
||||
458
backend/tests/services/context/test_engine.py
Normal file
458
backend/tests/services/context/test_engine.py
Normal file
@@ -0,0 +1,458 @@
|
||||
"""Tests for ContextEngine."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.context.config import ContextSettings
|
||||
from app.services.context.engine import ContextEngine, create_context_engine
|
||||
from app.services.context.types import (
|
||||
AssembledContext,
|
||||
ConversationContext,
|
||||
KnowledgeContext,
|
||||
MessageRole,
|
||||
SystemContext,
|
||||
TaskContext,
|
||||
ToolContext,
|
||||
)
|
||||
|
||||
|
||||
class TestContextEngineCreation:
    """Construction and dependency-injection tests for ContextEngine."""

    def test_creation_minimal(self) -> None:
        """An engine built without arguments still has every component set."""
        engine = ContextEngine()

        assert engine._mcp is None
        for attribute in (
            "_settings",
            "_calculator",
            "_scorer",
            "_ranker",
            "_compressor",
            "_cache",
            "_pipeline",
        ):
            assert getattr(engine, attribute) is not None

    def test_creation_with_settings(self) -> None:
        """Explicit settings are the ones the engine keeps."""
        custom = ContextSettings(
            compression_threshold=0.7,
            cache_enabled=False,
        )

        applied = ContextEngine(settings=custom)._settings

        assert applied.compression_threshold == 0.7
        assert applied.cache_enabled is False

    def test_creation_with_redis(self) -> None:
        """A Redis handle plus cache_enabled=True yields an enabled cache."""
        engine = ContextEngine(
            redis=MagicMock(),
            settings=ContextSettings(cache_enabled=True),
        )

        assert engine._cache.is_enabled

    def test_set_mcp_manager(self) -> None:
        """set_mcp_manager stores the manager on the engine."""
        manager = MagicMock()
        engine = ContextEngine()

        engine.set_mcp_manager(manager)

        assert engine._mcp is manager

    def test_set_redis(self) -> None:
        """set_redis hands the connection to the cache."""
        connection = MagicMock()
        engine = ContextEngine()

        engine.set_redis(connection)

        assert engine._cache._redis is connection
|
||||
|
||||
|
||||
class TestContextEngineHelpers:
    """Tests for the private conversion helpers of ContextEngine."""

    def test_convert_conversation(self) -> None:
        """Each history turn becomes a ConversationContext with its role."""
        turns = [
            {"role": "user", "content": "Hello!"},
            {"role": "assistant", "content": "Hi there!"},
            {"role": "user", "content": "How are you?"},
        ]

        converted = ContextEngine()._convert_conversation(turns)

        assert len(converted) == 3
        assert all(isinstance(item, ConversationContext) for item in converted)

        first, second = converted[0], converted[1]
        assert first.role == MessageRole.USER
        assert second.role == MessageRole.ASSISTANT
        assert first.content == "Hello!"
        assert first.metadata["turn"] == 0

    def test_convert_tool_results(self) -> None:
        """Tool results become ToolContext objects; dict payloads are serialized."""
        raw_results = [
            {"tool_name": "search", "content": "Result 1", "status": "success"},
            {"tool_name": "read", "result": {"file": "test.txt"}, "status": "success"},
        ]

        converted = ContextEngine()._convert_tool_results(raw_results)

        assert len(converted) == 2
        assert all(isinstance(item, ToolContext) for item in converted)

        plain, structured = converted
        assert plain.content == "Result 1"
        assert plain.metadata["tool_name"] == "search"
        # Dict content should be JSON serialized
        assert "file" in structured.content
        assert "test.txt" in structured.content
|
||||
|
||||
|
||||
class TestContextEngineAssembly:
    """End-to-end tests for ContextEngine.assemble_context without MCP/Redis."""

    @staticmethod
    async def _assemble(query: str, **sources: object) -> AssembledContext:
        """Run assemble_context on a fresh engine with the shared ids, model and cache off."""
        engine = ContextEngine()
        return await engine.assemble_context(
            project_id="proj-123",
            agent_id="agent-456",
            query=query,
            model="claude-3-sonnet",
            use_cache=False,  # Disable cache for test
            **sources,
        )

    @pytest.mark.asyncio
    async def test_assemble_minimal(self) -> None:
        """With no sources at all the result is an empty AssembledContext."""
        result = await self._assemble("test query")

        assert isinstance(result, AssembledContext)
        assert result.context_count == 0  # No contexts provided

    @pytest.mark.asyncio
    async def test_assemble_with_system_prompt(self) -> None:
        """A system prompt contributes exactly one context."""
        result = await self._assemble(
            "test query",
            system_prompt="You are a helpful assistant.",
        )

        assert result.context_count == 1
        assert "helpful assistant" in result.content

    @pytest.mark.asyncio
    async def test_assemble_with_task(self) -> None:
        """A task description contributes exactly one context."""
        result = await self._assemble(
            "implement feature",
            task_description="Implement user authentication",
        )

        assert result.context_count == 1
        assert "authentication" in result.content

    @pytest.mark.asyncio
    async def test_assemble_with_conversation(self) -> None:
        """Every conversation turn contributes one context."""
        turns = [
            {"role": "user", "content": "Hello!"},
            {"role": "assistant", "content": "Hi!"},
        ]

        result = await self._assemble("continue", conversation_history=turns)

        assert result.context_count == 2
        assert "Hello" in result.content

    @pytest.mark.asyncio
    async def test_assemble_with_tool_results(self) -> None:
        """A tool result contributes one context carrying its content."""
        outputs = [
            {"tool_name": "search", "content": "Found 5 results"},
        ]

        result = await self._assemble("continue", tool_results=outputs)

        assert result.context_count == 1
        assert "Found 5 results" in result.content

    @pytest.mark.asyncio
    async def test_assemble_with_custom_contexts(self) -> None:
        """Caller-supplied contexts are included in the assembly."""
        extra = KnowledgeContext(
            content="Custom knowledge.",
            source="custom",
            relevance_score=0.9,
        )

        result = await self._assemble("test", custom_contexts=[extra])

        assert result.context_count == 1
        assert "Custom knowledge" in result.content

    @pytest.mark.asyncio
    async def test_assemble_full_workflow(self) -> None:
        """System prompt, task, conversation and tool output assemble together."""
        result = await self._assemble(
            "implement login",
            system_prompt="You are an expert Python developer.",
            task_description="Implement user authentication.",
            conversation_history=[
                {"role": "user", "content": "Can you help me implement JWT auth?"},
            ],
            tool_results=[
                {"tool_name": "file_create", "content": "Created auth.py"},
            ],
        )

        assert result.context_count >= 4
        assert result.total_tokens > 0
        assert result.model == "claude-3-sonnet"

        # Check for expected content
        for expected in ("expert Python developer", "authentication"):
            assert expected in result.content
|
||||
|
||||
|
||||
class TestContextEngineKnowledge:
    """Tests for ContextEngine._fetch_knowledge."""

    @pytest.mark.asyncio
    async def test_fetch_knowledge_no_mcp(self) -> None:
        """Without an MCP manager nothing is fetched."""
        engine = ContextEngine()

        fetched = await engine._fetch_knowledge(
            project_id="proj-123",
            agent_id="agent-456",
            query="test",
        )

        assert fetched == []

    @pytest.mark.asyncio
    async def test_fetch_knowledge_with_mcp(self) -> None:
        """Search results returned by MCP are mapped to KnowledgeContext."""
        search_hits = [
            {
                "content": "Document content",
                "source_path": "docs/api.md",
                "score": 0.9,
                "chunk_id": "chunk-1",
            },
            {
                "content": "Another document",
                "source_path": "docs/auth.md",
                "score": 0.8,
            },
        ]
        mcp = AsyncMock()
        mcp.call_tool.return_value.data = {"results": search_hits}
        engine = ContextEngine(mcp_manager=mcp)

        fetched = await engine._fetch_knowledge(
            project_id="proj-123",
            agent_id="agent-456",
            query="authentication",
        )

        assert len(fetched) == 2
        assert all(isinstance(item, KnowledgeContext) for item in fetched)

        top = fetched[0]
        assert top.content == "Document content"
        assert top.source == "docs/api.md"
        assert top.relevance_score == 0.9

    @pytest.mark.asyncio
    async def test_fetch_knowledge_error_handling(self) -> None:
        """An MCP failure is swallowed and reported as no results."""
        mcp = AsyncMock()
        mcp.call_tool.side_effect = Exception("MCP error")
        engine = ContextEngine(mcp_manager=mcp)

        # Should not raise, returns empty
        fetched = await engine._fetch_knowledge(
            project_id="proj-123",
            agent_id="agent-456",
            query="test",
        )

        assert fetched == []
|
||||
|
||||
|
||||
class TestContextEngineCaching:
    """Tests for the caching path of assemble_context."""

    @pytest.mark.asyncio
    async def test_cache_disabled(self) -> None:
        """With use_cache=False the result is never flagged as a cache hit."""
        engine = ContextEngine()

        assembled = await engine.assemble_context(
            project_id="proj-123",
            agent_id="agent-456",
            query="test",
            model="claude-3-sonnet",
            system_prompt="Test prompt",
            use_cache=False,
        )

        assert not assembled.cache_hit

    @pytest.mark.asyncio
    async def test_cache_hit(self) -> None:
        """A second identical request is served from the (mocked) cache."""
        redis = AsyncMock()
        engine = ContextEngine(
            redis=redis,
            settings=ContextSettings(cache_enabled=True),
        )
        request = {
            "project_id": "proj-123",
            "agent_id": "agent-456",
            "query": "test",
            "model": "claude-3-sonnet",
            "system_prompt": "Test prompt",
        }

        # First call - cache miss
        redis.get.return_value = None
        first = await engine.assemble_context(**request)

        # Second call - mock cache hit
        redis.get.return_value = first.to_json()
        second = await engine.assemble_context(**request)

        assert second.cache_hit
|
||||
|
||||
|
||||
class TestContextEngineUtilities:
    """Tests for the budget, token-count, invalidation and stats helpers."""

    @pytest.mark.asyncio
    async def test_get_budget_for_model(self) -> None:
        """The model-based budget has positive total, system and knowledge parts."""
        engine = ContextEngine()

        budget = await engine.get_budget_for_model("claude-3-sonnet")

        for portion in (budget.total, budget.system, budget.knowledge):
            assert portion > 0

    @pytest.mark.asyncio
    async def test_get_budget_with_max_tokens(self) -> None:
        """An explicit max_tokens becomes the budget total."""
        engine = ContextEngine()

        budget = await engine.get_budget_for_model("claude-3-sonnet", max_tokens=5000)

        assert budget.total == 5000

    @pytest.mark.asyncio
    async def test_count_tokens(self) -> None:
        """Non-empty text yields a positive token count."""
        engine = ContextEngine()

        token_total = await engine.count_tokens("Hello world")

        assert token_total > 0

    @pytest.mark.asyncio
    async def test_invalidate_cache(self) -> None:
        """Pattern invalidation returns a non-negative count."""

        async def fake_scan_iter(match=None):
            # Stand-in for redis.scan_iter yielding two fixed keys.
            for key in ["ctx:1", "ctx:2"]:
                yield key

        redis = AsyncMock()
        redis.scan_iter = fake_scan_iter
        engine = ContextEngine(
            redis=redis,
            settings=ContextSettings(cache_enabled=True),
        )

        removed = await engine.invalidate_cache(pattern="*test*")

        assert removed >= 0

    @pytest.mark.asyncio
    async def test_get_stats(self) -> None:
        """Stats expose the cache and settings sections."""
        engine = ContextEngine()

        stats = await engine.get_stats()

        for section in ("cache", "settings"):
            assert section in stats
        assert "compression_threshold" in stats["settings"]
|
||||
|
||||
|
||||
class TestCreateContextEngine:
    """Tests for the create_context_engine factory."""

    def test_create_context_engine(self) -> None:
        """The factory returns a ContextEngine when called without arguments."""
        built = create_context_engine()

        assert isinstance(built, ContextEngine)

    def test_create_context_engine_with_settings(self) -> None:
        """Settings handed to the factory end up on the engine."""
        disabled_cache = ContextSettings(cache_enabled=False)

        built = create_context_engine(settings=disabled_cache)

        assert built._settings.cache_enabled is False
|
||||
Reference in New Issue
Block a user