feat(context): Phase 1 - Foundation types, config and exceptions (#79)
Implements the foundation for Context Management Engine: Types (backend/app/services/context/types/): - BaseContext: Abstract base with ID, content, priority, scoring - SystemContext: System prompts, personas, instructions - KnowledgeContext: RAG results from Knowledge Base MCP - ConversationContext: Chat history with role support - TaskContext: Task/issue context with acceptance criteria - ToolContext: Tool definitions and execution results - AssembledContext: Final assembled context result Configuration (config.py): - Token budget allocation (system 5%, task 10%, knowledge 40%, etc.) - Scoring weights (relevance 50%, recency 30%, priority 20%) - Cache settings (TTL, prefix) - Performance settings (max assembly time, parallel scoring) - Environment variable overrides with CTX_ prefix Exceptions (exceptions.py): - ContextError: Base exception - BudgetExceededError: Token budget violations - TokenCountError: Token counting failures - CompressionError: Compression failures - AssemblyTimeoutError: Assembly timeout - ScoringError, FormattingError, CacheError - ContextNotFoundError, InvalidContextError All 86 tests pass. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
1
backend/tests/services/context/__init__.py
Normal file
1
backend/tests/services/context/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for Context Management Engine."""
|
||||
243
backend/tests/services/context/test_config.py
Normal file
243
backend/tests/services/context/test_config.py
Normal file
@@ -0,0 +1,243 @@
|
||||
"""Tests for context management configuration."""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.context.config import (
|
||||
ContextSettings,
|
||||
get_context_settings,
|
||||
get_default_settings,
|
||||
reset_context_settings,
|
||||
)
|
||||
|
||||
|
||||
class TestContextSettings:
|
||||
"""Tests for ContextSettings."""
|
||||
|
||||
def test_default_values(self) -> None:
|
||||
"""Test default settings values."""
|
||||
settings = ContextSettings()
|
||||
|
||||
# Budget defaults should sum to 1.0
|
||||
total = (
|
||||
settings.budget_system
|
||||
+ settings.budget_task
|
||||
+ settings.budget_knowledge
|
||||
+ settings.budget_conversation
|
||||
+ settings.budget_tools
|
||||
+ settings.budget_response
|
||||
+ settings.budget_buffer
|
||||
)
|
||||
assert abs(total - 1.0) < 0.001
|
||||
|
||||
# Scoring weights should sum to 1.0
|
||||
weights_total = (
|
||||
settings.scoring_relevance_weight
|
||||
+ settings.scoring_recency_weight
|
||||
+ settings.scoring_priority_weight
|
||||
)
|
||||
assert abs(weights_total - 1.0) < 0.001
|
||||
|
||||
def test_budget_allocation_values(self) -> None:
|
||||
"""Test specific budget allocation values."""
|
||||
settings = ContextSettings()
|
||||
|
||||
assert settings.budget_system == 0.05
|
||||
assert settings.budget_task == 0.10
|
||||
assert settings.budget_knowledge == 0.40
|
||||
assert settings.budget_conversation == 0.20
|
||||
assert settings.budget_tools == 0.05
|
||||
assert settings.budget_response == 0.15
|
||||
assert settings.budget_buffer == 0.05
|
||||
|
||||
def test_scoring_weights(self) -> None:
|
||||
"""Test scoring weights."""
|
||||
settings = ContextSettings()
|
||||
|
||||
assert settings.scoring_relevance_weight == 0.5
|
||||
assert settings.scoring_recency_weight == 0.3
|
||||
assert settings.scoring_priority_weight == 0.2
|
||||
|
||||
def test_cache_settings(self) -> None:
|
||||
"""Test cache settings."""
|
||||
settings = ContextSettings()
|
||||
|
||||
assert settings.cache_enabled is True
|
||||
assert settings.cache_ttl_seconds == 3600
|
||||
assert settings.cache_prefix == "ctx"
|
||||
|
||||
def test_performance_settings(self) -> None:
|
||||
"""Test performance settings."""
|
||||
settings = ContextSettings()
|
||||
|
||||
assert settings.max_assembly_time_ms == 100
|
||||
assert settings.parallel_scoring is True
|
||||
assert settings.max_parallel_scores == 10
|
||||
|
||||
def test_get_budget_allocation(self) -> None:
|
||||
"""Test get_budget_allocation method."""
|
||||
settings = ContextSettings()
|
||||
allocation = settings.get_budget_allocation()
|
||||
|
||||
assert isinstance(allocation, dict)
|
||||
assert "system" in allocation
|
||||
assert "knowledge" in allocation
|
||||
assert allocation["system"] == 0.05
|
||||
assert allocation["knowledge"] == 0.40
|
||||
|
||||
def test_get_scoring_weights(self) -> None:
|
||||
"""Test get_scoring_weights method."""
|
||||
settings = ContextSettings()
|
||||
weights = settings.get_scoring_weights()
|
||||
|
||||
assert isinstance(weights, dict)
|
||||
assert "relevance" in weights
|
||||
assert "recency" in weights
|
||||
assert "priority" in weights
|
||||
assert weights["relevance"] == 0.5
|
||||
|
||||
def test_to_dict(self) -> None:
|
||||
"""Test to_dict method."""
|
||||
settings = ContextSettings()
|
||||
result = settings.to_dict()
|
||||
|
||||
assert "budget" in result
|
||||
assert "scoring" in result
|
||||
assert "compression" in result
|
||||
assert "cache" in result
|
||||
assert "performance" in result
|
||||
assert "knowledge" in result
|
||||
assert "conversation" in result
|
||||
|
||||
def test_budget_validation_fails_on_wrong_sum(self) -> None:
|
||||
"""Test that budget validation fails when sum != 1.0."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ContextSettings(
|
||||
budget_system=0.5,
|
||||
budget_task=0.5,
|
||||
# Other budgets default to non-zero, so total > 1.0
|
||||
)
|
||||
|
||||
assert "sum to 1.0" in str(exc_info.value)
|
||||
|
||||
def test_scoring_validation_fails_on_wrong_sum(self) -> None:
|
||||
"""Test that scoring validation fails when sum != 1.0."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ContextSettings(
|
||||
scoring_relevance_weight=0.8,
|
||||
scoring_recency_weight=0.8,
|
||||
scoring_priority_weight=0.8,
|
||||
)
|
||||
|
||||
assert "sum to 1.0" in str(exc_info.value)
|
||||
|
||||
def test_search_type_validation(self) -> None:
|
||||
"""Test search type validation."""
|
||||
# Valid types should work
|
||||
ContextSettings(knowledge_search_type="semantic")
|
||||
ContextSettings(knowledge_search_type="keyword")
|
||||
ContextSettings(knowledge_search_type="hybrid")
|
||||
|
||||
# Invalid type should fail
|
||||
with pytest.raises(ValueError):
|
||||
ContextSettings(knowledge_search_type="invalid")
|
||||
|
||||
def test_custom_budget_allocation(self) -> None:
|
||||
"""Test custom budget allocation that sums to 1.0."""
|
||||
settings = ContextSettings(
|
||||
budget_system=0.10,
|
||||
budget_task=0.10,
|
||||
budget_knowledge=0.30,
|
||||
budget_conversation=0.25,
|
||||
budget_tools=0.05,
|
||||
budget_response=0.15,
|
||||
budget_buffer=0.05,
|
||||
)
|
||||
|
||||
total = (
|
||||
settings.budget_system
|
||||
+ settings.budget_task
|
||||
+ settings.budget_knowledge
|
||||
+ settings.budget_conversation
|
||||
+ settings.budget_tools
|
||||
+ settings.budget_response
|
||||
+ settings.budget_buffer
|
||||
)
|
||||
assert abs(total - 1.0) < 0.001
|
||||
|
||||
|
||||
class TestSettingsSingleton:
|
||||
"""Tests for settings singleton pattern."""
|
||||
|
||||
def setup_method(self) -> None:
|
||||
"""Reset settings before each test."""
|
||||
reset_context_settings()
|
||||
|
||||
def teardown_method(self) -> None:
|
||||
"""Clean up after each test."""
|
||||
reset_context_settings()
|
||||
|
||||
def test_get_context_settings_returns_instance(self) -> None:
|
||||
"""Test that get_context_settings returns a settings instance."""
|
||||
settings = get_context_settings()
|
||||
assert isinstance(settings, ContextSettings)
|
||||
|
||||
def test_get_context_settings_returns_same_instance(self) -> None:
|
||||
"""Test that get_context_settings returns the same instance."""
|
||||
settings1 = get_context_settings()
|
||||
settings2 = get_context_settings()
|
||||
assert settings1 is settings2
|
||||
|
||||
def test_reset_creates_new_instance(self) -> None:
|
||||
"""Test that reset creates a new instance."""
|
||||
settings1 = get_context_settings()
|
||||
reset_context_settings()
|
||||
settings2 = get_context_settings()
|
||||
|
||||
# Should be different instances
|
||||
assert settings1 is not settings2
|
||||
|
||||
def test_get_default_settings_cached(self) -> None:
|
||||
"""Test that get_default_settings is cached."""
|
||||
settings1 = get_default_settings()
|
||||
settings2 = get_default_settings()
|
||||
assert settings1 is settings2
|
||||
|
||||
|
||||
class TestEnvironmentOverrides:
|
||||
"""Tests for environment variable overrides."""
|
||||
|
||||
def setup_method(self) -> None:
|
||||
"""Reset settings before each test."""
|
||||
reset_context_settings()
|
||||
|
||||
def teardown_method(self) -> None:
|
||||
"""Clean up after each test."""
|
||||
reset_context_settings()
|
||||
# Clean up any env vars we set
|
||||
for key in list(os.environ.keys()):
|
||||
if key.startswith("CTX_"):
|
||||
del os.environ[key]
|
||||
|
||||
def test_env_override_cache_enabled(self) -> None:
|
||||
"""Test that CTX_CACHE_ENABLED env var works."""
|
||||
with patch.dict(os.environ, {"CTX_CACHE_ENABLED": "false"}):
|
||||
reset_context_settings()
|
||||
settings = ContextSettings()
|
||||
assert settings.cache_enabled is False
|
||||
|
||||
def test_env_override_cache_ttl(self) -> None:
|
||||
"""Test that CTX_CACHE_TTL_SECONDS env var works."""
|
||||
with patch.dict(os.environ, {"CTX_CACHE_TTL_SECONDS": "7200"}):
|
||||
reset_context_settings()
|
||||
settings = ContextSettings()
|
||||
assert settings.cache_ttl_seconds == 7200
|
||||
|
||||
def test_env_override_max_assembly_time(self) -> None:
|
||||
"""Test that CTX_MAX_ASSEMBLY_TIME_MS env var works."""
|
||||
with patch.dict(os.environ, {"CTX_MAX_ASSEMBLY_TIME_MS": "200"}):
|
||||
reset_context_settings()
|
||||
settings = ContextSettings()
|
||||
assert settings.max_assembly_time_ms == 200
|
||||
252
backend/tests/services/context/test_exceptions.py
Normal file
252
backend/tests/services/context/test_exceptions.py
Normal file
@@ -0,0 +1,252 @@
|
||||
"""Tests for context management exceptions."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.context.exceptions import (
|
||||
AssemblyTimeoutError,
|
||||
BudgetExceededError,
|
||||
CacheError,
|
||||
CompressionError,
|
||||
ContextError,
|
||||
ContextNotFoundError,
|
||||
FormattingError,
|
||||
InvalidContextError,
|
||||
ScoringError,
|
||||
TokenCountError,
|
||||
)
|
||||
|
||||
|
||||
class TestContextError:
|
||||
"""Tests for base ContextError."""
|
||||
|
||||
def test_basic_initialization(self) -> None:
|
||||
"""Test basic error initialization."""
|
||||
error = ContextError("Test error")
|
||||
assert error.message == "Test error"
|
||||
assert error.details == {}
|
||||
assert str(error) == "Test error"
|
||||
|
||||
def test_with_details(self) -> None:
|
||||
"""Test error with details."""
|
||||
error = ContextError("Test error", {"key": "value", "count": 42})
|
||||
assert error.details == {"key": "value", "count": 42}
|
||||
|
||||
def test_to_dict(self) -> None:
|
||||
"""Test conversion to dictionary."""
|
||||
error = ContextError("Test error", {"key": "value"})
|
||||
result = error.to_dict()
|
||||
|
||||
assert result["error_type"] == "ContextError"
|
||||
assert result["message"] == "Test error"
|
||||
assert result["details"] == {"key": "value"}
|
||||
|
||||
def test_inheritance(self) -> None:
|
||||
"""Test that ContextError inherits from Exception."""
|
||||
error = ContextError("Test")
|
||||
assert isinstance(error, Exception)
|
||||
|
||||
|
||||
class TestBudgetExceededError:
|
||||
"""Tests for BudgetExceededError."""
|
||||
|
||||
def test_default_message(self) -> None:
|
||||
"""Test default error message."""
|
||||
error = BudgetExceededError()
|
||||
assert "exceeded" in error.message.lower()
|
||||
|
||||
def test_with_budget_info(self) -> None:
|
||||
"""Test with budget information."""
|
||||
error = BudgetExceededError(
|
||||
allocated=1000,
|
||||
requested=1500,
|
||||
context_type="knowledge",
|
||||
)
|
||||
|
||||
assert error.allocated == 1000
|
||||
assert error.requested == 1500
|
||||
assert error.context_type == "knowledge"
|
||||
assert error.details["overage"] == 500
|
||||
|
||||
def test_to_dict_includes_budget_info(self) -> None:
|
||||
"""Test that to_dict includes budget info."""
|
||||
error = BudgetExceededError(
|
||||
allocated=1000,
|
||||
requested=1500,
|
||||
)
|
||||
result = error.to_dict()
|
||||
|
||||
assert result["details"]["allocated"] == 1000
|
||||
assert result["details"]["requested"] == 1500
|
||||
assert result["details"]["overage"] == 500
|
||||
|
||||
|
||||
class TestTokenCountError:
|
||||
"""Tests for TokenCountError."""
|
||||
|
||||
def test_basic_error(self) -> None:
|
||||
"""Test basic token count error."""
|
||||
error = TokenCountError()
|
||||
assert "token" in error.message.lower()
|
||||
|
||||
def test_with_model_info(self) -> None:
|
||||
"""Test with model information."""
|
||||
error = TokenCountError(
|
||||
message="Failed to count",
|
||||
model="claude-3-sonnet",
|
||||
text_length=5000,
|
||||
)
|
||||
|
||||
assert error.model == "claude-3-sonnet"
|
||||
assert error.text_length == 5000
|
||||
assert error.details["model"] == "claude-3-sonnet"
|
||||
|
||||
|
||||
class TestCompressionError:
|
||||
"""Tests for CompressionError."""
|
||||
|
||||
def test_basic_error(self) -> None:
|
||||
"""Test basic compression error."""
|
||||
error = CompressionError()
|
||||
assert "compress" in error.message.lower()
|
||||
|
||||
def test_with_token_info(self) -> None:
|
||||
"""Test with token information."""
|
||||
error = CompressionError(
|
||||
original_tokens=2000,
|
||||
target_tokens=1000,
|
||||
achieved_tokens=1500,
|
||||
)
|
||||
|
||||
assert error.original_tokens == 2000
|
||||
assert error.target_tokens == 1000
|
||||
assert error.achieved_tokens == 1500
|
||||
|
||||
|
||||
class TestAssemblyTimeoutError:
|
||||
"""Tests for AssemblyTimeoutError."""
|
||||
|
||||
def test_basic_error(self) -> None:
|
||||
"""Test basic timeout error."""
|
||||
error = AssemblyTimeoutError()
|
||||
assert "timed out" in error.message.lower()
|
||||
|
||||
def test_with_timing_info(self) -> None:
|
||||
"""Test with timing information."""
|
||||
error = AssemblyTimeoutError(
|
||||
timeout_ms=100,
|
||||
elapsed_ms=150.5,
|
||||
stage="scoring",
|
||||
)
|
||||
|
||||
assert error.timeout_ms == 100
|
||||
assert error.elapsed_ms == 150.5
|
||||
assert error.stage == "scoring"
|
||||
assert error.details["stage"] == "scoring"
|
||||
|
||||
|
||||
class TestScoringError:
|
||||
"""Tests for ScoringError."""
|
||||
|
||||
def test_basic_error(self) -> None:
|
||||
"""Test basic scoring error."""
|
||||
error = ScoringError()
|
||||
assert "score" in error.message.lower()
|
||||
|
||||
def test_with_scorer_info(self) -> None:
|
||||
"""Test with scorer information."""
|
||||
error = ScoringError(
|
||||
scorer_type="relevance",
|
||||
context_id="ctx-123",
|
||||
)
|
||||
|
||||
assert error.scorer_type == "relevance"
|
||||
assert error.context_id == "ctx-123"
|
||||
|
||||
|
||||
class TestFormattingError:
|
||||
"""Tests for FormattingError."""
|
||||
|
||||
def test_basic_error(self) -> None:
|
||||
"""Test basic formatting error."""
|
||||
error = FormattingError()
|
||||
assert "format" in error.message.lower()
|
||||
|
||||
def test_with_model_info(self) -> None:
|
||||
"""Test with model information."""
|
||||
error = FormattingError(
|
||||
model="claude-3-opus",
|
||||
adapter="ClaudeAdapter",
|
||||
)
|
||||
|
||||
assert error.model == "claude-3-opus"
|
||||
assert error.adapter == "ClaudeAdapter"
|
||||
|
||||
|
||||
class TestCacheError:
|
||||
"""Tests for CacheError."""
|
||||
|
||||
def test_basic_error(self) -> None:
|
||||
"""Test basic cache error."""
|
||||
error = CacheError()
|
||||
assert "cache" in error.message.lower()
|
||||
|
||||
def test_with_operation_info(self) -> None:
|
||||
"""Test with operation information."""
|
||||
error = CacheError(
|
||||
operation="get",
|
||||
cache_key="ctx:abc123",
|
||||
)
|
||||
|
||||
assert error.operation == "get"
|
||||
assert error.cache_key == "ctx:abc123"
|
||||
|
||||
|
||||
class TestContextNotFoundError:
|
||||
"""Tests for ContextNotFoundError."""
|
||||
|
||||
def test_basic_error(self) -> None:
|
||||
"""Test basic not found error."""
|
||||
error = ContextNotFoundError()
|
||||
assert "not found" in error.message.lower()
|
||||
|
||||
def test_with_source_info(self) -> None:
|
||||
"""Test with source information."""
|
||||
error = ContextNotFoundError(
|
||||
source="knowledge-base",
|
||||
query="authentication flow",
|
||||
)
|
||||
|
||||
assert error.source == "knowledge-base"
|
||||
assert error.query == "authentication flow"
|
||||
|
||||
|
||||
class TestInvalidContextError:
|
||||
"""Tests for InvalidContextError."""
|
||||
|
||||
def test_basic_error(self) -> None:
|
||||
"""Test basic invalid error."""
|
||||
error = InvalidContextError()
|
||||
assert "invalid" in error.message.lower()
|
||||
|
||||
def test_with_field_info(self) -> None:
|
||||
"""Test with field information."""
|
||||
error = InvalidContextError(
|
||||
field="content",
|
||||
value="",
|
||||
reason="Content cannot be empty",
|
||||
)
|
||||
|
||||
assert error.field == "content"
|
||||
assert error.value == ""
|
||||
assert error.reason == "Content cannot be empty"
|
||||
|
||||
def test_value_type_only_in_details(self) -> None:
|
||||
"""Test that only value type is included in details (not actual value)."""
|
||||
error = InvalidContextError(
|
||||
field="api_key",
|
||||
value="secret-key-here",
|
||||
)
|
||||
|
||||
# Actual value should not be in details
|
||||
assert "secret-key-here" not in str(error.details)
|
||||
assert error.details["value_type"] == "str"
|
||||
579
backend/tests/services/context/test_types.py
Normal file
579
backend/tests/services/context/test_types.py
Normal file
@@ -0,0 +1,579 @@
|
||||
"""Tests for context types."""
|
||||
|
||||
import json
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.context.types import (
|
||||
AssembledContext,
|
||||
BaseContext,
|
||||
ContextPriority,
|
||||
ContextType,
|
||||
ConversationContext,
|
||||
KnowledgeContext,
|
||||
MessageRole,
|
||||
SystemContext,
|
||||
TaskComplexity,
|
||||
TaskContext,
|
||||
TaskStatus,
|
||||
ToolContext,
|
||||
ToolResultStatus,
|
||||
)
|
||||
|
||||
|
||||
class TestContextType:
|
||||
"""Tests for ContextType enum."""
|
||||
|
||||
def test_all_types_exist(self) -> None:
|
||||
"""Test that all expected context types exist."""
|
||||
assert ContextType.SYSTEM
|
||||
assert ContextType.TASK
|
||||
assert ContextType.KNOWLEDGE
|
||||
assert ContextType.CONVERSATION
|
||||
assert ContextType.TOOL
|
||||
|
||||
def test_from_string_valid(self) -> None:
|
||||
"""Test from_string with valid values."""
|
||||
assert ContextType.from_string("system") == ContextType.SYSTEM
|
||||
assert ContextType.from_string("KNOWLEDGE") == ContextType.KNOWLEDGE
|
||||
assert ContextType.from_string("Task") == ContextType.TASK
|
||||
|
||||
def test_from_string_invalid(self) -> None:
|
||||
"""Test from_string with invalid value."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ContextType.from_string("invalid")
|
||||
assert "Invalid context type" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestContextPriority:
|
||||
"""Tests for ContextPriority enum."""
|
||||
|
||||
def test_priority_ordering(self) -> None:
|
||||
"""Test that priorities are ordered correctly."""
|
||||
assert ContextPriority.LOWEST.value < ContextPriority.LOW.value
|
||||
assert ContextPriority.LOW.value < ContextPriority.NORMAL.value
|
||||
assert ContextPriority.NORMAL.value < ContextPriority.HIGH.value
|
||||
assert ContextPriority.HIGH.value < ContextPriority.HIGHEST.value
|
||||
assert ContextPriority.HIGHEST.value < ContextPriority.CRITICAL.value
|
||||
|
||||
def test_from_int(self) -> None:
|
||||
"""Test from_int conversion."""
|
||||
assert ContextPriority.from_int(0) == ContextPriority.LOWEST
|
||||
assert ContextPriority.from_int(50) == ContextPriority.NORMAL
|
||||
assert ContextPriority.from_int(100) == ContextPriority.HIGHEST
|
||||
assert ContextPriority.from_int(200) == ContextPriority.CRITICAL
|
||||
|
||||
def test_from_int_intermediate(self) -> None:
|
||||
"""Test from_int with intermediate values."""
|
||||
# Should return closest lower priority
|
||||
assert ContextPriority.from_int(30) == ContextPriority.LOW
|
||||
assert ContextPriority.from_int(60) == ContextPriority.NORMAL
|
||||
|
||||
|
||||
class TestSystemContext:
|
||||
"""Tests for SystemContext."""
|
||||
|
||||
def test_creation(self) -> None:
|
||||
"""Test basic creation."""
|
||||
ctx = SystemContext(
|
||||
content="You are a helpful assistant.",
|
||||
source="system_prompt",
|
||||
)
|
||||
|
||||
assert ctx.content == "You are a helpful assistant."
|
||||
assert ctx.source == "system_prompt"
|
||||
assert ctx.get_type() == ContextType.SYSTEM
|
||||
|
||||
def test_default_high_priority(self) -> None:
|
||||
"""Test that system context defaults to high priority."""
|
||||
ctx = SystemContext(content="Test", source="test")
|
||||
assert ctx.priority == ContextPriority.HIGH.value
|
||||
|
||||
def test_create_persona(self) -> None:
|
||||
"""Test create_persona factory method."""
|
||||
ctx = SystemContext.create_persona(
|
||||
name="Code Assistant",
|
||||
description="A helpful coding assistant.",
|
||||
capabilities=["Write code", "Debug"],
|
||||
constraints=["Never expose secrets"],
|
||||
)
|
||||
|
||||
assert "Code Assistant" in ctx.content
|
||||
assert "helpful coding assistant" in ctx.content
|
||||
assert "Write code" in ctx.content
|
||||
assert "Never expose secrets" in ctx.content
|
||||
assert ctx.priority == ContextPriority.HIGHEST.value
|
||||
|
||||
def test_create_instructions(self) -> None:
|
||||
"""Test create_instructions factory method."""
|
||||
ctx = SystemContext.create_instructions(
|
||||
["Always be helpful", "Be concise"],
|
||||
source="rules",
|
||||
)
|
||||
|
||||
assert "Always be helpful" in ctx.content
|
||||
assert "Be concise" in ctx.content
|
||||
|
||||
def test_to_dict(self) -> None:
|
||||
"""Test serialization to dict."""
|
||||
ctx = SystemContext(
|
||||
content="Test",
|
||||
source="test",
|
||||
role="assistant",
|
||||
instructions_type="general",
|
||||
)
|
||||
|
||||
data = ctx.to_dict()
|
||||
assert data["role"] == "assistant"
|
||||
assert data["instructions_type"] == "general"
|
||||
assert data["type"] == "system"
|
||||
|
||||
|
||||
class TestKnowledgeContext:
|
||||
"""Tests for KnowledgeContext."""
|
||||
|
||||
def test_creation(self) -> None:
|
||||
"""Test basic creation."""
|
||||
ctx = KnowledgeContext(
|
||||
content="def authenticate(user): ...",
|
||||
source="/src/auth.py",
|
||||
collection="code",
|
||||
file_type="python",
|
||||
)
|
||||
|
||||
assert ctx.content == "def authenticate(user): ..."
|
||||
assert ctx.source == "/src/auth.py"
|
||||
assert ctx.collection == "code"
|
||||
assert ctx.get_type() == ContextType.KNOWLEDGE
|
||||
|
||||
def test_from_search_result(self) -> None:
|
||||
"""Test from_search_result factory method."""
|
||||
result = {
|
||||
"content": "Test content",
|
||||
"source_path": "/test/file.py",
|
||||
"collection": "code",
|
||||
"file_type": "python",
|
||||
"chunk_index": 2,
|
||||
"score": 0.85,
|
||||
"id": "chunk-123",
|
||||
}
|
||||
|
||||
ctx = KnowledgeContext.from_search_result(result, "test query")
|
||||
|
||||
assert ctx.content == "Test content"
|
||||
assert ctx.source == "/test/file.py"
|
||||
assert ctx.relevance_score == 0.85
|
||||
assert ctx.search_query == "test query"
|
||||
|
||||
def test_from_search_results(self) -> None:
|
||||
"""Test from_search_results factory method."""
|
||||
results = [
|
||||
{"content": "Content 1", "source_path": "/a.py", "score": 0.9},
|
||||
{"content": "Content 2", "source_path": "/b.py", "score": 0.8},
|
||||
]
|
||||
|
||||
contexts = KnowledgeContext.from_search_results(results, "query")
|
||||
|
||||
assert len(contexts) == 2
|
||||
assert contexts[0].relevance_score == 0.9
|
||||
assert contexts[1].source == "/b.py"
|
||||
|
||||
def test_is_code(self) -> None:
|
||||
"""Test is_code method."""
|
||||
code_ctx = KnowledgeContext(
|
||||
content="code", source="test", file_type="python"
|
||||
)
|
||||
doc_ctx = KnowledgeContext(
|
||||
content="docs", source="test", file_type="markdown"
|
||||
)
|
||||
|
||||
assert code_ctx.is_code() is True
|
||||
assert doc_ctx.is_code() is False
|
||||
|
||||
def test_is_documentation(self) -> None:
|
||||
"""Test is_documentation method."""
|
||||
doc_ctx = KnowledgeContext(
|
||||
content="docs", source="test", file_type="markdown"
|
||||
)
|
||||
code_ctx = KnowledgeContext(
|
||||
content="code", source="test", file_type="python"
|
||||
)
|
||||
|
||||
assert doc_ctx.is_documentation() is True
|
||||
assert code_ctx.is_documentation() is False
|
||||
|
||||
|
||||
class TestConversationContext:
|
||||
"""Tests for ConversationContext."""
|
||||
|
||||
def test_creation(self) -> None:
|
||||
"""Test basic creation."""
|
||||
ctx = ConversationContext(
|
||||
content="Hello, how can I help?",
|
||||
source="conversation",
|
||||
role=MessageRole.ASSISTANT,
|
||||
turn_index=1,
|
||||
)
|
||||
|
||||
assert ctx.content == "Hello, how can I help?"
|
||||
assert ctx.role == MessageRole.ASSISTANT
|
||||
assert ctx.get_type() == ContextType.CONVERSATION
|
||||
|
||||
def test_from_message(self) -> None:
|
||||
"""Test from_message factory method."""
|
||||
ctx = ConversationContext.from_message(
|
||||
content="What is Python?",
|
||||
role="user",
|
||||
turn_index=0,
|
||||
)
|
||||
|
||||
assert ctx.content == "What is Python?"
|
||||
assert ctx.role == MessageRole.USER
|
||||
assert ctx.turn_index == 0
|
||||
|
||||
def test_from_history(self) -> None:
|
||||
"""Test from_history factory method."""
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{"role": "user", "content": "Help me"},
|
||||
]
|
||||
|
||||
contexts = ConversationContext.from_history(messages)
|
||||
|
||||
assert len(contexts) == 3
|
||||
assert contexts[0].role == MessageRole.USER
|
||||
assert contexts[1].role == MessageRole.ASSISTANT
|
||||
assert contexts[2].turn_index == 2
|
||||
|
||||
def test_is_user_message(self) -> None:
|
||||
"""Test is_user_message method."""
|
||||
user_ctx = ConversationContext(
|
||||
content="test", source="test", role=MessageRole.USER
|
||||
)
|
||||
assistant_ctx = ConversationContext(
|
||||
content="test", source="test", role=MessageRole.ASSISTANT
|
||||
)
|
||||
|
||||
assert user_ctx.is_user_message() is True
|
||||
assert assistant_ctx.is_user_message() is False
|
||||
|
||||
def test_format_for_prompt(self) -> None:
|
||||
"""Test format_for_prompt method."""
|
||||
ctx = ConversationContext.from_message(
|
||||
content="What is 2+2?",
|
||||
role="user",
|
||||
)
|
||||
|
||||
formatted = ctx.format_for_prompt()
|
||||
assert "User:" in formatted
|
||||
assert "What is 2+2?" in formatted
|
||||
|
||||
|
||||
class TestTaskContext:
|
||||
"""Tests for TaskContext."""
|
||||
|
||||
def test_creation(self) -> None:
|
||||
"""Test basic creation."""
|
||||
ctx = TaskContext(
|
||||
content="Implement login feature",
|
||||
source="task",
|
||||
title="Login Feature",
|
||||
)
|
||||
|
||||
assert ctx.content == "Implement login feature"
|
||||
assert ctx.title == "Login Feature"
|
||||
assert ctx.get_type() == ContextType.TASK
|
||||
|
||||
def test_default_high_priority(self) -> None:
|
||||
"""Test that task context defaults to high priority."""
|
||||
ctx = TaskContext(content="Test", source="test")
|
||||
assert ctx.priority == ContextPriority.HIGH.value
|
||||
|
||||
def test_create_factory(self) -> None:
|
||||
"""Test create factory method."""
|
||||
ctx = TaskContext.create(
|
||||
title="Add Auth",
|
||||
description="Implement authentication",
|
||||
acceptance_criteria=["Tests pass", "Code reviewed"],
|
||||
constraints=["Use JWT"],
|
||||
issue_id="123",
|
||||
)
|
||||
|
||||
assert ctx.title == "Add Auth"
|
||||
assert ctx.content == "Implement authentication"
|
||||
assert len(ctx.acceptance_criteria) == 2
|
||||
assert "Use JWT" in ctx.constraints
|
||||
assert ctx.status == TaskStatus.IN_PROGRESS
|
||||
|
||||
def test_format_for_prompt(self) -> None:
|
||||
"""Test format_for_prompt method."""
|
||||
ctx = TaskContext.create(
|
||||
title="Test Task",
|
||||
description="Do something",
|
||||
acceptance_criteria=["Works correctly"],
|
||||
)
|
||||
|
||||
formatted = ctx.format_for_prompt()
|
||||
assert "Task: Test Task" in formatted
|
||||
assert "Do something" in formatted
|
||||
assert "Works correctly" in formatted
|
||||
|
||||
def test_status_checks(self) -> None:
|
||||
"""Test status check methods."""
|
||||
pending = TaskContext(
|
||||
content="test", source="test", status=TaskStatus.PENDING
|
||||
)
|
||||
completed = TaskContext(
|
||||
content="test", source="test", status=TaskStatus.COMPLETED
|
||||
)
|
||||
blocked = TaskContext(
|
||||
content="test", source="test", status=TaskStatus.BLOCKED
|
||||
)
|
||||
|
||||
assert pending.is_active() is True
|
||||
assert completed.is_complete() is True
|
||||
assert blocked.is_blocked() is True
|
||||
|
||||
|
||||
class TestToolContext:
|
||||
"""Tests for ToolContext."""
|
||||
|
||||
def test_creation(self) -> None:
|
||||
"""Test basic creation."""
|
||||
ctx = ToolContext(
|
||||
content="Tool result here",
|
||||
source="tool:search",
|
||||
tool_name="search",
|
||||
)
|
||||
|
||||
assert ctx.tool_name == "search"
|
||||
assert ctx.get_type() == ContextType.TOOL
|
||||
|
||||
def test_from_tool_definition(self) -> None:
|
||||
"""Test from_tool_definition factory method."""
|
||||
ctx = ToolContext.from_tool_definition(
|
||||
name="search_knowledge",
|
||||
description="Search the knowledge base",
|
||||
parameters={
|
||||
"query": {"type": "string", "required": True},
|
||||
"limit": {"type": "integer", "required": False},
|
||||
},
|
||||
server_name="knowledge-base",
|
||||
)
|
||||
|
||||
assert ctx.tool_name == "search_knowledge"
|
||||
assert "Search the knowledge base" in ctx.content
|
||||
assert ctx.is_result is False
|
||||
assert ctx.server_name == "knowledge-base"
|
||||
|
||||
def test_from_tool_result(self) -> None:
|
||||
"""Test from_tool_result factory method."""
|
||||
ctx = ToolContext.from_tool_result(
|
||||
tool_name="search",
|
||||
result={"found": 5, "items": ["a", "b"]},
|
||||
status=ToolResultStatus.SUCCESS,
|
||||
execution_time_ms=150.5,
|
||||
)
|
||||
|
||||
assert ctx.tool_name == "search"
|
||||
assert ctx.is_result is True
|
||||
assert ctx.result_status == ToolResultStatus.SUCCESS
|
||||
assert "found" in ctx.content
|
||||
|
||||
def test_is_successful(self) -> None:
|
||||
"""Test is_successful method."""
|
||||
success = ToolContext.from_tool_result(
|
||||
"test", "ok", ToolResultStatus.SUCCESS
|
||||
)
|
||||
error = ToolContext.from_tool_result(
|
||||
"test", "error", ToolResultStatus.ERROR
|
||||
)
|
||||
|
||||
assert success.is_successful() is True
|
||||
assert error.is_successful() is False
|
||||
|
||||
def test_format_for_prompt(self) -> None:
|
||||
"""Test format_for_prompt method."""
|
||||
ctx = ToolContext.from_tool_result(
|
||||
"search",
|
||||
"Found 3 results",
|
||||
ToolResultStatus.SUCCESS,
|
||||
)
|
||||
|
||||
formatted = ctx.format_for_prompt()
|
||||
assert "Tool Result" in formatted
|
||||
assert "search" in formatted
|
||||
assert "success" in formatted
|
||||
|
||||
|
||||
class TestAssembledContext:
|
||||
"""Tests for AssembledContext."""
|
||||
|
||||
def test_creation(self) -> None:
|
||||
"""Test basic creation."""
|
||||
ctx = AssembledContext(
|
||||
content="Assembled content here",
|
||||
token_count=500,
|
||||
contexts_included=5,
|
||||
)
|
||||
|
||||
assert ctx.content == "Assembled content here"
|
||||
assert ctx.token_count == 500
|
||||
assert ctx.contexts_included == 5
|
||||
|
||||
def test_budget_utilization(self) -> None:
|
||||
"""Test budget_utilization property."""
|
||||
ctx = AssembledContext(
|
||||
content="test",
|
||||
token_count=800,
|
||||
contexts_included=5,
|
||||
budget_total=1000,
|
||||
budget_used=800,
|
||||
)
|
||||
|
||||
assert ctx.budget_utilization == 0.8
|
||||
|
||||
def test_budget_utilization_zero_budget(self) -> None:
|
||||
"""Test budget_utilization with zero budget."""
|
||||
ctx = AssembledContext(
|
||||
content="test",
|
||||
token_count=0,
|
||||
contexts_included=0,
|
||||
budget_total=0,
|
||||
budget_used=0,
|
||||
)
|
||||
|
||||
assert ctx.budget_utilization == 0.0
|
||||
|
||||
def test_to_dict(self) -> None:
|
||||
"""Test to_dict method."""
|
||||
ctx = AssembledContext(
|
||||
content="test",
|
||||
token_count=100,
|
||||
contexts_included=2,
|
||||
assembly_time_ms=50.123,
|
||||
)
|
||||
|
||||
data = ctx.to_dict()
|
||||
assert data["content"] == "test"
|
||||
assert data["token_count"] == 100
|
||||
assert data["assembly_time_ms"] == 50.12 # Rounded
|
||||
|
||||
def test_to_json_and_from_json(self) -> None:
|
||||
"""Test JSON serialization round-trip."""
|
||||
original = AssembledContext(
|
||||
content="Test content",
|
||||
token_count=100,
|
||||
contexts_included=3,
|
||||
contexts_excluded=2,
|
||||
assembly_time_ms=45.5,
|
||||
budget_total=1000,
|
||||
budget_used=100,
|
||||
by_type={"system": 20, "knowledge": 80},
|
||||
cache_hit=True,
|
||||
cache_key="abc123",
|
||||
)
|
||||
|
||||
json_str = original.to_json()
|
||||
restored = AssembledContext.from_json(json_str)
|
||||
|
||||
assert restored.content == original.content
|
||||
assert restored.token_count == original.token_count
|
||||
assert restored.contexts_included == original.contexts_included
|
||||
assert restored.cache_hit == original.cache_hit
|
||||
assert restored.cache_key == original.cache_key
|
||||
|
||||
|
||||
class TestBaseContextMethods:
|
||||
"""Tests for BaseContext methods."""
|
||||
|
||||
def test_get_age_seconds(self) -> None:
|
||||
"""Test get_age_seconds method."""
|
||||
old_time = datetime.now(UTC) - timedelta(hours=2)
|
||||
ctx = SystemContext(
|
||||
content="test", source="test", timestamp=old_time
|
||||
)
|
||||
|
||||
age = ctx.get_age_seconds()
|
||||
# Should be approximately 2 hours in seconds
|
||||
assert 7100 < age < 7300
|
||||
|
||||
def test_get_age_hours(self) -> None:
|
||||
"""Test get_age_hours method."""
|
||||
old_time = datetime.now(UTC) - timedelta(hours=5)
|
||||
ctx = SystemContext(
|
||||
content="test", source="test", timestamp=old_time
|
||||
)
|
||||
|
||||
age = ctx.get_age_hours()
|
||||
assert 4.9 < age < 5.1
|
||||
|
||||
def test_is_stale(self) -> None:
|
||||
"""Test is_stale method."""
|
||||
old_time = datetime.now(UTC) - timedelta(days=10)
|
||||
new_time = datetime.now(UTC) - timedelta(hours=1)
|
||||
|
||||
old_ctx = SystemContext(
|
||||
content="test", source="test", timestamp=old_time
|
||||
)
|
||||
new_ctx = SystemContext(
|
||||
content="test", source="test", timestamp=new_time
|
||||
)
|
||||
|
||||
# Default max_age is 168 hours (7 days)
|
||||
assert old_ctx.is_stale() is True
|
||||
assert new_ctx.is_stale() is False
|
||||
|
||||
def test_token_count_property(self) -> None:
|
||||
"""Test token_count property."""
|
||||
ctx = SystemContext(content="test", source="test")
|
||||
|
||||
# Initially None
|
||||
assert ctx.token_count is None
|
||||
|
||||
# Can be set
|
||||
ctx.token_count = 100
|
||||
assert ctx.token_count == 100
|
||||
|
||||
def test_score_property_clamping(self) -> None:
|
||||
"""Test that score is clamped to 0.0-1.0."""
|
||||
ctx = SystemContext(content="test", source="test")
|
||||
|
||||
ctx.score = 1.5
|
||||
assert ctx.score == 1.0
|
||||
|
||||
ctx.score = -0.5
|
||||
assert ctx.score == 0.0
|
||||
|
||||
ctx.score = 0.75
|
||||
assert ctx.score == 0.75
|
||||
|
||||
def test_hash_and_equality(self) -> None:
|
||||
"""Test hash and equality based on ID."""
|
||||
ctx1 = SystemContext(content="test", source="test")
|
||||
ctx2 = SystemContext(content="test", source="test")
|
||||
ctx3 = SystemContext(content="test", source="test")
|
||||
ctx3.id = ctx1.id # Same ID as ctx1
|
||||
|
||||
# Different IDs = not equal
|
||||
assert ctx1 != ctx2
|
||||
|
||||
# Same ID = equal
|
||||
assert ctx1 == ctx3
|
||||
|
||||
# Can be used in sets
|
||||
context_set = {ctx1, ctx2, ctx3}
|
||||
assert len(context_set) == 2 # ctx1 and ctx3 are same
|
||||
|
||||
def test_truncate(self) -> None:
|
||||
"""Test truncate method."""
|
||||
long_content = "word " * 1000 # Long content
|
||||
ctx = SystemContext(content=long_content, source="test")
|
||||
ctx.token_count = 1000
|
||||
|
||||
truncated = ctx.truncate(100)
|
||||
|
||||
assert len(truncated) < len(long_content)
|
||||
assert "[truncated]" in truncated
|
||||
Reference in New Issue
Block a user