- Cleaned up unnecessary comments in `__all__` definitions for better readability. - Adjusted indentation and formatting across modules for improved clarity (e.g., long lines, logical grouping). - Simplified conditional expressions and inline comments for context scoring and ranking. - Replaced some hard-coded values with type-safe annotations (e.g., `ClassVar`). - Removed unused imports and ensured consistent usage across test files. - Updated `test_score_not_cached_on_context` to clarify caching behavior. - Improved truncation strategy logic and marker handling.
457 lines
14 KiB
Python
457 lines
14 KiB
Python
"""Tests for ContextEngine."""
|
|
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
|
|
from app.services.context.config import ContextSettings
|
|
from app.services.context.engine import ContextEngine, create_context_engine
|
|
from app.services.context.types import (
|
|
AssembledContext,
|
|
ConversationContext,
|
|
KnowledgeContext,
|
|
MessageRole,
|
|
ToolContext,
|
|
)
|
|
|
|
|
|
class TestContextEngineCreation:
    """Construction and late wiring of ContextEngine collaborators."""

    def test_creation_minimal(self) -> None:
        """A bare engine has no MCP manager but builds every internal component."""
        engine = ContextEngine()

        assert engine._mcp is None
        # Every collaborator should be constructed eagerly by __init__.
        for component in (
            "_settings",
            "_calculator",
            "_scorer",
            "_ranker",
            "_compressor",
            "_cache",
            "_pipeline",
        ):
            assert getattr(engine, component) is not None, component

    def test_creation_with_settings(self) -> None:
        """Settings handed to the constructor are the ones the engine keeps."""
        custom = ContextSettings(
            compression_threshold=0.7,
            cache_enabled=False,
        )
        engine = ContextEngine(settings=custom)

        stored = engine._settings
        assert stored.compression_threshold == 0.7
        assert stored.cache_enabled is False

    def test_creation_with_redis(self) -> None:
        """A Redis client plus cache_enabled=True yields an enabled cache."""
        engine = ContextEngine(
            redis=MagicMock(),
            settings=ContextSettings(cache_enabled=True),
        )

        assert engine._cache.is_enabled

    def test_set_mcp_manager(self) -> None:
        """set_mcp_manager() stores exactly the manager object it is given."""
        engine = ContextEngine()
        manager = MagicMock()

        engine.set_mcp_manager(manager)

        assert engine._mcp is manager

    def test_set_redis(self) -> None:
        """set_redis() hands the connection through to the cache layer."""
        engine = ContextEngine()
        client = MagicMock()

        engine.set_redis(client)

        assert engine._cache._redis is client
|
|
|
|
|
|
class TestContextEngineHelpers:
    """Tests for ContextEngine helpers that normalise raw inputs into contexts."""

    def test_convert_conversation(self) -> None:
        """Raw history dicts become ConversationContext objects, in order."""
        engine = ContextEngine()

        history = [
            {"role": "user", "content": "Hello!"},
            {"role": "assistant", "content": "Hi there!"},
            {"role": "user", "content": "How are you?"},
        ]

        contexts = engine._convert_conversation(history)

        assert len(contexts) == 3
        assert all(isinstance(c, ConversationContext) for c in contexts)
        # Check every message, including the last one, keeps its role and text.
        assert [c.role for c in contexts] == [
            MessageRole.USER,
            MessageRole.ASSISTANT,
            MessageRole.USER,
        ]
        assert [c.content for c in contexts] == [m["content"] for m in history]
        assert contexts[0].metadata["turn"] == 0

    def test_convert_tool_results(self) -> None:
        """Raw tool results become ToolContext objects carrying the tool name."""
        engine = ContextEngine()

        results = [
            {"tool_name": "search", "content": "Result 1", "status": "success"},
            {"tool_name": "read", "result": {"file": "test.txt"}, "status": "success"},
        ]

        contexts = engine._convert_tool_results(results)

        assert len(contexts) == 2
        assert all(isinstance(c, ToolContext) for c in contexts)
        assert contexts[0].content == "Result 1"
        # Both results must record which tool produced them.
        assert [c.metadata["tool_name"] for c in contexts] == ["search", "read"]
        # Dict content should be JSON serialized
        assert "file" in contexts[1].content
        assert "test.txt" in contexts[1].content
|
|
|
|
|
|
class TestContextEngineAssembly:
    """Tests for context assembly.

    Every test disables the cache so results depend only on the inputs.
    """

    @pytest.mark.asyncio
    async def test_assemble_minimal(self) -> None:
        """Test assembling with minimal inputs."""
        engine = ContextEngine()

        result = await engine.assemble_context(
            project_id="proj-123",
            agent_id="agent-456",
            query="test query",
            model="claude-3-sonnet",
            use_cache=False,  # Disable cache for test
        )

        assert isinstance(result, AssembledContext)
        assert result.context_count == 0  # No contexts provided

    @pytest.mark.asyncio
    async def test_assemble_with_system_prompt(self) -> None:
        """Test assembling with system prompt."""
        engine = ContextEngine()

        result = await engine.assemble_context(
            project_id="proj-123",
            agent_id="agent-456",
            query="test query",
            model="claude-3-sonnet",
            system_prompt="You are a helpful assistant.",
            use_cache=False,
        )

        assert result.context_count == 1
        assert "helpful assistant" in result.content

    @pytest.mark.asyncio
    async def test_assemble_with_task(self) -> None:
        """Test assembling with task description."""
        engine = ContextEngine()

        result = await engine.assemble_context(
            project_id="proj-123",
            agent_id="agent-456",
            query="implement feature",
            model="claude-3-sonnet",
            task_description="Implement user authentication",
            use_cache=False,
        )

        assert result.context_count == 1
        assert "authentication" in result.content

    @pytest.mark.asyncio
    async def test_assemble_with_conversation(self) -> None:
        """Test assembling with conversation history."""
        engine = ContextEngine()

        result = await engine.assemble_context(
            project_id="proj-123",
            agent_id="agent-456",
            query="continue",
            model="claude-3-sonnet",
            conversation_history=[
                {"role": "user", "content": "Hello!"},
                {"role": "assistant", "content": "Hi!"},
            ],
            use_cache=False,
        )

        assert result.context_count == 2
        # Both sides of the exchange must reach the assembled output.
        assert "Hello" in result.content
        assert "Hi!" in result.content

    @pytest.mark.asyncio
    async def test_assemble_with_tool_results(self) -> None:
        """Test assembling with tool results."""
        engine = ContextEngine()

        result = await engine.assemble_context(
            project_id="proj-123",
            agent_id="agent-456",
            query="continue",
            model="claude-3-sonnet",
            tool_results=[
                {"tool_name": "search", "content": "Found 5 results"},
            ],
            use_cache=False,
        )

        assert result.context_count == 1
        assert "Found 5 results" in result.content

    @pytest.mark.asyncio
    async def test_assemble_with_custom_contexts(self) -> None:
        """Test assembling with custom contexts."""
        engine = ContextEngine()

        custom = [
            KnowledgeContext(
                content="Custom knowledge.",
                source="custom",
                relevance_score=0.9,
            )
        ]

        result = await engine.assemble_context(
            project_id="proj-123",
            agent_id="agent-456",
            query="test",
            model="claude-3-sonnet",
            custom_contexts=custom,
            use_cache=False,
        )

        assert result.context_count == 1
        assert "Custom knowledge" in result.content

    @pytest.mark.asyncio
    async def test_assemble_full_workflow(self) -> None:
        """Every kind of input is combined into a single assembled context."""
        engine = ContextEngine()

        result = await engine.assemble_context(
            project_id="proj-123",
            agent_id="agent-456",
            query="implement login",
            model="claude-3-sonnet",
            system_prompt="You are an expert Python developer.",
            task_description="Implement user authentication.",
            conversation_history=[
                {"role": "user", "content": "Can you help me implement JWT auth?"},
            ],
            tool_results=[
                {"tool_name": "file_create", "content": "Created auth.py"},
            ],
            use_cache=False,
        )

        # The prompt, task, one message and one tool result give at least four contexts.
        assert result.context_count >= 4
        assert result.total_tokens > 0
        assert result.model == "claude-3-sonnet"

        # Check that content from every input type survived assembly.
        assert "expert Python developer" in result.content
        assert "authentication" in result.content
        assert "JWT auth" in result.content
        assert "Created auth.py" in result.content
|
|
|
|
|
|
class TestContextEngineKnowledge:
    """Tests for knowledge fetching through the MCP manager."""

    @pytest.mark.asyncio
    async def test_fetch_knowledge_no_mcp(self) -> None:
        """Test fetching knowledge without MCP returns empty."""
        engine = ContextEngine()

        result = await engine._fetch_knowledge(
            project_id="proj-123",
            agent_id="agent-456",
            query="test",
        )

        assert result == []

    @pytest.mark.asyncio
    async def test_fetch_knowledge_with_mcp(self) -> None:
        """Search hits returned via MCP are mapped onto KnowledgeContext objects."""
        mock_mcp = AsyncMock()
        # With AsyncMock, awaiting call_tool(...) yields call_tool.return_value.
        mock_mcp.call_tool.return_value.data = {
            "results": [
                {
                    "content": "Document content",
                    "source_path": "docs/api.md",
                    "score": 0.9,
                    "chunk_id": "chunk-1",
                },
                {
                    "content": "Another document",
                    "source_path": "docs/auth.md",
                    "score": 0.8,
                },
            ]
        }

        engine = ContextEngine(mcp_manager=mock_mcp)

        result = await engine._fetch_knowledge(
            project_id="proj-123",
            agent_id="agent-456",
            query="authentication",
        )

        # The knowledge must have come from an awaited MCP tool call.
        mock_mcp.call_tool.assert_awaited()
        assert len(result) == 2
        assert all(isinstance(c, KnowledgeContext) for c in result)
        assert result[0].content == "Document content"
        assert result[0].source == "docs/api.md"
        assert result[0].relevance_score == 0.9
        # The second hit has no chunk_id; it must still be mapped the same way.
        assert result[1].content == "Another document"
        assert result[1].source == "docs/auth.md"
        assert result[1].relevance_score == 0.8

    @pytest.mark.asyncio
    async def test_fetch_knowledge_error_handling(self) -> None:
        """An MCP failure is swallowed and yields no knowledge."""
        mock_mcp = AsyncMock()
        mock_mcp.call_tool.side_effect = Exception("MCP error")

        engine = ContextEngine(mcp_manager=mock_mcp)

        # Should not raise, returns empty
        result = await engine._fetch_knowledge(
            project_id="proj-123",
            agent_id="agent-456",
            query="test",
        )

        # Guard against a vacuous pass: the failing call must actually be made.
        mock_mcp.call_tool.assert_awaited()
        assert result == []
|
|
|
|
|
|
class TestContextEngineCaching:
    """Tests for caching behavior."""

    @pytest.mark.asyncio
    async def test_cache_disabled(self) -> None:
        """Test assembly with cache disabled."""
        engine = ContextEngine()

        result = await engine.assemble_context(
            project_id="proj-123",
            agent_id="agent-456",
            query="test",
            model="claude-3-sonnet",
            system_prompt="Test prompt",
            use_cache=False,
        )

        assert not result.cache_hit

    @pytest.mark.asyncio
    async def test_cache_hit(self) -> None:
        """A repeated request is served from the Redis-backed cache."""
        mock_redis = AsyncMock()
        settings = ContextSettings(cache_enabled=True)
        engine = ContextEngine(redis=mock_redis, settings=settings)

        # First call - cache miss: Redis lookup returns nothing.
        mock_redis.get.return_value = None

        result1 = await engine.assemble_context(
            project_id="proj-123",
            agent_id="agent-456",
            query="test",
            model="claude-3-sonnet",
            system_prompt="Test prompt",
        )

        # A freshly assembled result must not claim to be a cache hit.
        assert not result1.cache_hit

        # Second call - mock cache hit by returning the serialized first result.
        mock_redis.get.return_value = result1.to_json()

        result2 = await engine.assemble_context(
            project_id="proj-123",
            agent_id="agent-456",
            query="test",
            model="claude-3-sonnet",
            system_prompt="Test prompt",
        )

        assert result2.cache_hit
        # Content served from the cache should match what was originally assembled.
        assert result2.content == result1.content
|
|
|
|
|
|
class TestContextEngineUtilities:
    """Tests for utility methods."""

    @pytest.mark.asyncio
    async def test_get_budget_for_model(self) -> None:
        """Test getting budget for model."""
        engine = ContextEngine()

        budget = await engine.get_budget_for_model("claude-3-sonnet")

        assert budget.total > 0
        assert budget.system > 0
        assert budget.knowledge > 0

    @pytest.mark.asyncio
    async def test_get_budget_with_max_tokens(self) -> None:
        """Test getting budget with max tokens."""
        engine = ContextEngine()

        budget = await engine.get_budget_for_model("claude-3-sonnet", max_tokens=5000)

        assert budget.total == 5000

    @pytest.mark.asyncio
    async def test_count_tokens(self) -> None:
        """Test token counting."""
        engine = ContextEngine()

        count = await engine.count_tokens("Hello world")

        assert count > 0

    @pytest.mark.asyncio
    async def test_invalidate_cache(self) -> None:
        """Cache invalidation scans Redis keys and reports an integer count."""
        mock_redis = AsyncMock()

        async def mock_scan_iter(match=None, **_kwargs):
            # Fake of redis scan_iter. Extra options such as ``count`` are
            # accepted and ignored, like the real client would accept them.
            for key in ["ctx:1", "ctx:2"]:
                yield key

        mock_redis.scan_iter = mock_scan_iter

        settings = ContextSettings(cache_enabled=True)
        engine = ContextEngine(redis=mock_redis, settings=settings)

        deleted = await engine.invalidate_cache(pattern="*test*")

        # ``deleted >= 0`` alone accepts almost any comparable value; pin the type too.
        assert isinstance(deleted, int)
        assert deleted >= 0

    @pytest.mark.asyncio
    async def test_get_stats(self) -> None:
        """Test getting engine stats."""
        engine = ContextEngine()

        stats = await engine.get_stats()

        assert "cache" in stats
        assert "settings" in stats
        assert "compression_threshold" in stats["settings"]
|
|
|
|
|
|
class TestCreateContextEngine:
    """Tests for the create_context_engine factory."""

    def test_create_context_engine(self) -> None:
        """Calling the factory with no arguments yields a ContextEngine."""
        assert isinstance(create_context_engine(), ContextEngine)

    def test_create_context_engine_with_settings(self) -> None:
        """The factory forwards supplied settings to the engine it builds."""
        built = create_context_engine(
            settings=ContextSettings(cache_enabled=False),
        )

        assert built._settings.cache_enabled is False
|