10 Commits

Author SHA1 Message Date
Felipe Cardoso
2bea057fb1 chore(context): refactor for consistency, optimize formatting, and simplify logic
- Cleaned up unnecessary comments in `__all__` definitions for better readability.
- Adjusted indentation and formatting across modules for improved clarity (e.g., long lines, logical grouping).
- Simplified conditional expressions and inline comments for context scoring and ranking.
- Replaced some hard-coded values with type-safe annotations (e.g., `ClassVar`).
- Removed unused imports and ensured consistent usage across test files.
- Updated `test_score_not_cached_on_context` to clarify caching behavior.
- Improved truncation strategy logic and marker handling.
2026-01-04 15:23:14 +01:00
Felipe Cardoso
9e54f16e56 test(context): add edge case tests for truncation and scoring concurrency
- Add tests for truncation edge cases, including zero tokens, short content, and marker handling.
- Add concurrency tests for scoring to verify per-context locking and handling of multiple contexts.
2026-01-04 12:38:04 +01:00
Felipe Cardoso
96e6400bd8 feat(context): enhance performance, caching, and settings management
- Replace hard-coded limits with configurable settings (e.g., cache memory size, truncation strategy, relevance settings).
- Optimize parallel execution in token counting, scoring, and reranking for source diversity.
- Improve caching logic:
  - Add per-context locks for safe parallel scoring.
  - Reuse precomputed fingerprints for cache efficiency.
- Make truncation, scoring, and ranker behaviors fully configurable via settings.
- Add support for middle truncation, context hash-based hashing, and dynamic token limiting.
- Refactor methods for scalability and better error handling.

Tests: Updated all affected components with additional test cases.
2026-01-04 12:37:58 +01:00
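The "per-context locks" mentioned in this commit follow a common asyncio pattern: one lock per context key, so concurrent scoring of the same context is serialised while different contexts still score in parallel. A generic sketch of that pattern (names are illustrative, not the code added in this commit):

```python
# Sketch only: a generic per-key locking pattern, not this PR's actual implementation.
import asyncio
from collections import defaultdict
from collections.abc import Awaitable, Callable

_locks: defaultdict[str, asyncio.Lock] = defaultdict(asyncio.Lock)


async def score_locked(context_id: str, compute: Callable[[str], Awaitable[float]]) -> float:
    # Same context_id -> callers queue on one lock; different ids run concurrently.
    async with _locks[context_id]:
        return await compute(context_id)
```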
Felipe Cardoso
6c7b72f130 chore(context): apply linter fixes and sort imports (#86)
Phase 8 of Context Management Engine - Final Cleanup:

- Sort __all__ exports alphabetically
- Sort imports per isort conventions
- Fix minor linting issues

Final test results:
- 311 context management tests passing
- 2507 total backend tests passing
- 85% code coverage

Context Management Engine is complete with all 8 phases:
1. Foundation: Types, Config, Exceptions
2. Token Budget Management
3. Context Scoring & Ranking
4. Context Assembly Pipeline
5. Model Adapters (Claude, OpenAI)
6. Caching Layer (Redis + in-memory)
7. Main Engine & Integration
8. Testing & Documentation

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-04 02:46:56 +01:00
Felipe Cardoso
027ebfc332 feat(context): implement main ContextEngine with full integration (#85)
Phase 7 of Context Management Engine - Main Engine:

- Add ContextEngine as main orchestration class
- Integrate all components: calculator, scorer, ranker, compressor, cache
- Add high-level assemble_context() API with:
  - System prompt support
  - Task description support
  - Knowledge Base integration via MCP
  - Conversation history conversion
  - Tool results conversion
  - Custom contexts support
- Add helper methods:
  - get_budget_for_model()
  - count_tokens() with caching
  - invalidate_cache()
  - get_stats()
- Add create_context_engine() factory function

Tests: 26 new tests, 311 total context tests passing

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-04 02:44:40 +01:00
Felipe Cardoso
c2466ab401 feat(context): implement Redis-based caching layer (#84)
Phase 6 of Context Management Engine - Caching Layer:

- Add ContextCache with Redis integration
- Support fingerprint-based assembled context caching
- Support token count caching (model-specific)
- Support score caching (scorer + context + query)
- Add in-memory fallback with LRU eviction
- Add cache invalidation with pattern matching
- Add cache statistics reporting

Key features:
- Hierarchical cache key structure (ctx:type:hash)
- Automatic TTL expiration
- Memory cache for fast repeated access
- Graceful degradation when Redis unavailable

Tests: 29 new tests, 285 total context tests passing

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-04 02:41:21 +01:00
Felipe Cardoso
7828d35e06 feat(context): implement model adapters for Claude and OpenAI (#83)
Phase 5 of Context Management Engine - Model Adapters:

- Add ModelAdapter abstract base class with model matching
- Add DefaultAdapter for unknown models (plain text)
- Add ClaudeAdapter with XML-based formatting:
  - <system_instructions> for system context
  - <reference_documents>/<document> for knowledge
  - <conversation_history>/<message> for chat
  - <tool_results>/<tool_result> for tool outputs
  - XML escaping for special characters
- Add OpenAIAdapter with markdown formatting:
  - ## headers for sections
  - ### Source headers for documents
  - **ROLE** bold labels for conversation
  - Code blocks for tool outputs
- Add get_adapter() factory function for model selection

Tests: 33 new tests, 256 total context tests passing

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-04 02:36:32 +01:00
Felipe Cardoso
6b07e62f00 feat(context): implement assembly pipeline and compression (#82)
Phase 4 of Context Management Engine - Assembly Pipeline:

- Add TruncationStrategy with end/middle/sentence-aware truncation
- Add TruncationResult dataclass for tracking compression metrics
- Add ContextCompressor for type-specific compression
- Add ContextPipeline orchestrating full assembly workflow:
  - Token counting for all contexts
  - Scoring and ranking via ContextRanker
  - Optional compression when budget threshold exceeded
  - Model-specific formatting (XML for Claude, markdown for OpenAI)
- Add PipelineMetrics for performance tracking
- Update AssembledContext with new fields (model, contexts, metadata)
- Add backward compatibility aliases for renamed fields

Tests: 34 new tests, 223 total context tests passing

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-04 02:32:25 +01:00
Felipe Cardoso
0d2005ddcb feat(context): implement context scoring and ranking (Phase 3)
Add comprehensive scoring system with three strategies:
- RelevanceScorer: Semantic similarity with keyword fallback
- RecencyScorer: Exponential decay with type-specific half-lives
- PriorityScorer: Priority-based scoring with type bonuses

Implement CompositeScorer combining all strategies with configurable
weights (default: 50% relevance, 30% recency, 20% priority).

Add ContextRanker for budget-aware context selection with:
- Greedy selection algorithm respecting token budgets
- CRITICAL priority contexts always included
- Diversity reranking to prevent source dominance
- Comprehensive selection statistics

68 tests covering all scoring and ranking functionality.

Part of #61 - Context Management Engine

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-04 02:24:06 +01:00
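With the default weights above (50% relevance, 30% recency, 20% priority), the composite score is a plain weighted sum; a quick worked example:

```python
# Sketch only: illustrates the default CompositeScorer weighting described above.
weights = {"relevance": 0.5, "recency": 0.3, "priority": 0.2}
scores = {"relevance": 0.9, "recency": 0.4, "priority": 1.0}

composite = sum(weights[k] * scores[k] for k in weights)
print(round(composite, 2))  # 0.5*0.9 + 0.3*0.4 + 0.2*1.0 = 0.77
```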
Felipe Cardoso
dfa75e682e feat(context): implement token budget management (Phase 2)
Add TokenCalculator with LLM Gateway integration for accurate token
counting with in-memory caching and fallback character-based estimation.
Implement TokenBudget for tracking allocations per context type with
budget enforcement, and BudgetAllocator for creating budgets based on
model context window sizes.

- TokenCalculator: MCP integration, caching, model-specific ratios
- TokenBudget: allocation tracking, can_fit/allocate/deallocate/reset
- BudgetAllocator: model context sizes, budget creation and adjustment
- 35 comprehensive tests covering all budget functionality

Part of #61 - Context Management Engine

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-04 02:13:23 +01:00
40 changed files with 8577 additions and 108 deletions

View File

@@ -14,11 +14,18 @@ Usage:
ConversationContext,
TaskContext,
ToolContext,
TokenBudget,
BudgetAllocator,
TokenCalculator,
)
# Get settings
settings = get_context_settings()
# Create budget for a model
allocator = BudgetAllocator(settings)
budget = allocator.create_budget_for_model("claude-3-sonnet")
# Create context instances
system_ctx = SystemContext.create_persona(
name="Code Assistant",
@@ -27,6 +34,37 @@ Usage:
)
"""
# Adapters
from .adapters import (
ClaudeAdapter,
DefaultAdapter,
ModelAdapter,
OpenAIAdapter,
get_adapter,
)
# Assembly
from .assembly import (
ContextPipeline,
PipelineMetrics,
)
# Budget Management
from .budget import (
BudgetAllocator,
TokenBudget,
TokenCalculator,
)
# Cache
from .cache import ContextCache
# Compression
from .compression import (
ContextCompressor,
TruncationResult,
TruncationStrategy,
)
# Configuration
from .config import (
ContextSettings, ContextSettings,
@@ -35,6 +73,9 @@ from .config import (
reset_context_settings,
)
# Engine
from .engine import ContextEngine, create_context_engine
# Exceptions
from .exceptions import (
AssemblyTimeoutError,
@@ -49,6 +90,22 @@ from .exceptions import (
TokenCountError,
)
# Prioritization
from .prioritization import (
ContextRanker,
RankingResult,
)
# Scoring
from .scoring import (
BaseScorer,
CompositeScorer,
PriorityScorer,
RecencyScorer,
RelevanceScorer,
ScoredContext,
)
# Types
from .types import (
AssembledContext,
@@ -67,39 +124,55 @@ from .types import (
)
__all__ = [
"AssembledContext",
"AssemblyTimeoutError",
"BaseContext",
"BaseScorer",
"BudgetAllocator",
"BudgetExceededError",
"CacheError",
"ClaudeAdapter",
"CompositeScorer",
"CompressionError",
"ContextCache",
"ContextCompressor",
"ContextEngine",
"ContextError",
"ContextNotFoundError",
"ContextPipeline",
"ContextPriority",
"ContextRanker",
"ContextSettings",
"ContextType",
"ConversationContext",
"DefaultAdapter",
"FormattingError",
"InvalidContextError",
"KnowledgeContext",
"MessageRole",
"ModelAdapter",
"OpenAIAdapter",
"PipelineMetrics",
"PriorityScorer",
"RankingResult",
"RecencyScorer",
"RelevanceScorer",
"ScoredContext",
"ScoringError",
"SystemContext",
"TaskComplexity",
"TaskContext",
"TaskStatus",
"TokenBudget",
"TokenCalculator",
"TokenCountError",
"ToolContext",
"ToolResultStatus",
"TruncationResult",
"TruncationStrategy",
"create_context_engine",
"get_adapter",
"get_context_settings",
"get_default_settings",
"reset_context_settings",
]

View File

@@ -1,5 +1,35 @@
""" """
Model Adapters Module. Model Adapters Module.
Provides model-specific context formatting. Provides model-specific context formatting adapters.
""" """
from .base import DefaultAdapter, ModelAdapter
from .claude import ClaudeAdapter
from .openai import OpenAIAdapter
def get_adapter(model: str) -> ModelAdapter:
"""
Get the appropriate adapter for a model.
Args:
model: Model name
Returns:
Adapter instance for the model
"""
if ClaudeAdapter.matches_model(model):
return ClaudeAdapter()
elif OpenAIAdapter.matches_model(model):
return OpenAIAdapter()
return DefaultAdapter()
__all__ = [
"ClaudeAdapter",
"DefaultAdapter",
"ModelAdapter",
"OpenAIAdapter",
"get_adapter",
]
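Because `matches_model()` is a case-insensitive substring check against `MODEL_PATTERNS`, the routing in `get_adapter()` can be sanity-checked without building any contexts. A minimal sketch, assuming the package is importable as `app.services.context` (the actual path is not shown in this diff):

```python
# Sketch only: the import path is an assumption.
from app.services.context.adapters import (
    ClaudeAdapter,
    DefaultAdapter,
    OpenAIAdapter,
    get_adapter,
)

assert isinstance(get_adapter("claude-3-5-sonnet-20241022"), ClaudeAdapter)
assert isinstance(get_adapter("gpt-4o-mini"), OpenAIAdapter)
assert isinstance(get_adapter("mistral-large"), DefaultAdapter)  # fallback adapter
```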

View File

@@ -0,0 +1,178 @@
"""
Base Model Adapter.
Abstract base class for model-specific context formatting.
"""
from abc import ABC, abstractmethod
from typing import Any, ClassVar
from ..types import BaseContext, ContextType
class ModelAdapter(ABC):
"""
Abstract base adapter for model-specific context formatting.
Each adapter knows how to format contexts for optimal
understanding by a specific LLM family (Claude, OpenAI, etc.).
"""
# Model name patterns this adapter handles
MODEL_PATTERNS: ClassVar[list[str]] = []
@classmethod
def matches_model(cls, model: str) -> bool:
"""
Check if this adapter handles the given model.
Args:
model: Model name to check
Returns:
True if this adapter handles the model
"""
model_lower = model.lower()
return any(pattern in model_lower for pattern in cls.MODEL_PATTERNS)
@abstractmethod
def format(
self,
contexts: list[BaseContext],
**kwargs: Any,
) -> str:
"""
Format contexts for the target model.
Args:
contexts: List of contexts to format
**kwargs: Additional formatting options
Returns:
Formatted context string
"""
...
@abstractmethod
def format_type(
self,
contexts: list[BaseContext],
context_type: ContextType,
**kwargs: Any,
) -> str:
"""
Format contexts of a specific type.
Args:
contexts: List of contexts of the same type
context_type: The type of contexts
**kwargs: Additional formatting options
Returns:
Formatted string for this context type
"""
...
def get_type_order(self) -> list[ContextType]:
"""
Get the preferred order of context types.
Returns:
List of context types in preferred order
"""
return [
ContextType.SYSTEM,
ContextType.TASK,
ContextType.KNOWLEDGE,
ContextType.CONVERSATION,
ContextType.TOOL,
]
def group_by_type(
self, contexts: list[BaseContext]
) -> dict[ContextType, list[BaseContext]]:
"""
Group contexts by their type.
Args:
contexts: List of contexts to group
Returns:
Dictionary mapping context type to list of contexts
"""
by_type: dict[ContextType, list[BaseContext]] = {}
for context in contexts:
ct = context.get_type()
if ct not in by_type:
by_type[ct] = []
by_type[ct].append(context)
return by_type
def get_separator(self) -> str:
"""
Get the separator between context sections.
Returns:
Separator string
"""
return "\n\n"
class DefaultAdapter(ModelAdapter):
"""
Default adapter for unknown models.
Uses simple plain-text formatting with minimal structure.
"""
MODEL_PATTERNS: ClassVar[list[str]] = [] # Fallback adapter
@classmethod
def matches_model(cls, model: str) -> bool:
"""Always returns True as fallback."""
return True
def format(
self,
contexts: list[BaseContext],
**kwargs: Any,
) -> str:
"""Format contexts as plain text."""
if not contexts:
return ""
by_type = self.group_by_type(contexts)
parts: list[str] = []
for ct in self.get_type_order():
if ct in by_type:
formatted = self.format_type(by_type[ct], ct, **kwargs)
if formatted:
parts.append(formatted)
return self.get_separator().join(parts)
def format_type(
self,
contexts: list[BaseContext],
context_type: ContextType,
**kwargs: Any,
) -> str:
"""Format contexts of a type as plain text."""
if not contexts:
return ""
content = "\n\n".join(c.content for c in contexts)
if context_type == ContextType.SYSTEM:
return content
elif context_type == ContextType.TASK:
return f"Task:\n{content}"
elif context_type == ContextType.KNOWLEDGE:
return f"Reference Information:\n{content}"
elif context_type == ContextType.CONVERSATION:
return f"Previous Conversation:\n{content}"
elif context_type == ContextType.TOOL:
return f"Tool Results:\n{content}"
return content
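Adding another model family means subclassing `ModelAdapter`, declaring its `MODEL_PATTERNS`, and implementing the two abstract methods; the grouping and ordering helpers come for free. A hypothetical sketch (the `GeminiAdapter` class and the import paths are illustrative, not part of this PR):

```python
# Sketch only: GeminiAdapter and the import paths are assumptions for illustration.
from typing import Any, ClassVar

from app.services.context.adapters.base import ModelAdapter
from app.services.context.types import BaseContext, ContextType


class GeminiAdapter(ModelAdapter):
    """Illustrative adapter for another model family, reusing the base helpers."""

    MODEL_PATTERNS: ClassVar[list[str]] = ["gemini"]

    def format(self, contexts: list[BaseContext], **kwargs: Any) -> str:
        if not contexts:
            return ""
        by_type = self.group_by_type(contexts)
        parts = [
            self.format_type(by_type[ct], ct, **kwargs)
            for ct in self.get_type_order()
            if ct in by_type
        ]
        return self.get_separator().join(p for p in parts if p)

    def format_type(
        self,
        contexts: list[BaseContext],
        context_type: ContextType,
        **kwargs: Any,
    ) -> str:
        if not contexts:
            return ""
        content = "\n\n".join(c.content for c in contexts)
        # Simple headed sections; real formatting rules would be model-specific.
        return f"[{context_type.value}]\n{content}"
```

Note that `get_adapter()` dispatches through a hard-coded if/elif chain, so a new adapter would also need a branch added there before it is selected automatically.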

View File

@@ -0,0 +1,178 @@
"""
Claude Model Adapter.
Provides Claude-specific context formatting using XML tags
which Claude models understand natively.
"""
from typing import Any, ClassVar
from ..types import BaseContext, ContextType
from .base import ModelAdapter
class ClaudeAdapter(ModelAdapter):
"""
Claude-specific context formatting adapter.
Claude models have native understanding of XML structure,
so we use XML tags for clear delineation of context types.
Features:
- XML tags for each context type
- Document structure for knowledge contexts
- Role-based message formatting for conversations
- Tool result wrapping with tool names
"""
MODEL_PATTERNS: ClassVar[list[str]] = ["claude", "anthropic"]
def format(
self,
contexts: list[BaseContext],
**kwargs: Any,
) -> str:
"""
Format contexts for Claude models.
Uses XML tags for structured content that Claude
understands natively.
Args:
contexts: List of contexts to format
**kwargs: Additional formatting options
Returns:
XML-structured context string
"""
if not contexts:
return ""
by_type = self.group_by_type(contexts)
parts: list[str] = []
for ct in self.get_type_order():
if ct in by_type:
formatted = self.format_type(by_type[ct], ct, **kwargs)
if formatted:
parts.append(formatted)
return self.get_separator().join(parts)
def format_type(
self,
contexts: list[BaseContext],
context_type: ContextType,
**kwargs: Any,
) -> str:
"""
Format contexts of a specific type for Claude.
Args:
contexts: List of contexts of the same type
context_type: The type of contexts
**kwargs: Additional formatting options
Returns:
XML-formatted string for this context type
"""
if not contexts:
return ""
if context_type == ContextType.SYSTEM:
return self._format_system(contexts)
elif context_type == ContextType.TASK:
return self._format_task(contexts)
elif context_type == ContextType.KNOWLEDGE:
return self._format_knowledge(contexts)
elif context_type == ContextType.CONVERSATION:
return self._format_conversation(contexts)
elif context_type == ContextType.TOOL:
return self._format_tool(contexts)
return "\n".join(c.content for c in contexts)
def _format_system(self, contexts: list[BaseContext]) -> str:
"""Format system contexts."""
content = "\n\n".join(c.content for c in contexts)
return f"<system_instructions>\n{content}\n</system_instructions>"
def _format_task(self, contexts: list[BaseContext]) -> str:
"""Format task contexts."""
content = "\n\n".join(c.content for c in contexts)
return f"<current_task>\n{content}\n</current_task>"
def _format_knowledge(self, contexts: list[BaseContext]) -> str:
"""
Format knowledge contexts as structured documents.
Each knowledge context becomes a document with source attribution.
"""
parts = ["<reference_documents>"]
for ctx in contexts:
source = self._escape_xml(ctx.source)
content = ctx.content
score = ctx.metadata.get("score", ctx.metadata.get("relevance_score", ""))
if score:
parts.append(f'<document source="{source}" relevance="{score}">')
else:
parts.append(f'<document source="{source}">')
parts.append(content)
parts.append("</document>")
parts.append("</reference_documents>")
return "\n".join(parts)
def _format_conversation(self, contexts: list[BaseContext]) -> str:
"""
Format conversation contexts as message history.
Uses role-based message tags for clear turn delineation.
"""
parts = ["<conversation_history>"]
for ctx in contexts:
role = ctx.metadata.get("role", "user")
parts.append(f'<message role="{role}">')
parts.append(ctx.content)
parts.append("</message>")
parts.append("</conversation_history>")
return "\n".join(parts)
def _format_tool(self, contexts: list[BaseContext]) -> str:
"""
Format tool contexts as tool results.
Each tool result is wrapped with the tool name.
"""
parts = ["<tool_results>"]
for ctx in contexts:
tool_name = ctx.metadata.get("tool_name", "unknown")
status = ctx.metadata.get("status", "")
if status:
parts.append(f'<tool_result name="{tool_name}" status="{status}">')
else:
parts.append(f'<tool_result name="{tool_name}">')
parts.append(ctx.content)
parts.append("</tool_result>")
parts.append("</tool_results>")
return "\n".join(parts)
@staticmethod
def _escape_xml(text: str) -> str:
"""Escape XML special characters in attribute values."""
return (
text.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
.replace('"', "&quot;")
.replace("'", "&apos;")
)
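`_escape_xml()` is a plain static helper, so its behaviour (and the resulting attribute escaping) is easy to verify in isolation; import path assumed:

```python
# Sketch only: import path assumed.
from app.services.context.adapters.claude import ClaudeAdapter

print(ClaudeAdapter._escape_xml('docs/auth & "sso" <v2>.md'))
# -> docs/auth &amp; &quot;sso&quot; &lt;v2&gt;.md

# Only attribute values (e.g. the document source) are escaped; the body text is
# placed verbatim between the XML tags, so an unscored knowledge context renders as:
#   <document source="docs/auth &amp; &quot;sso&quot; &lt;v2&gt;.md">
#   ...content...
#   </document>
```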

View File

@@ -0,0 +1,160 @@
"""
OpenAI Model Adapter.
Provides OpenAI-specific context formatting using markdown
which GPT models understand well.
"""
from typing import Any, ClassVar
from ..types import BaseContext, ContextType
from .base import ModelAdapter
class OpenAIAdapter(ModelAdapter):
"""
OpenAI-specific context formatting adapter.
GPT models work well with markdown formatting,
so we use headers and structured markdown for clarity.
Features:
- Markdown headers for each context type
- Bulleted lists for document sources
- Bold role labels for conversations
- Code blocks for tool outputs
"""
MODEL_PATTERNS: ClassVar[list[str]] = ["gpt", "openai", "o1", "o3"]
def format(
self,
contexts: list[BaseContext],
**kwargs: Any,
) -> str:
"""
Format contexts for OpenAI models.
Uses markdown formatting for structured content.
Args:
contexts: List of contexts to format
**kwargs: Additional formatting options
Returns:
Markdown-structured context string
"""
if not contexts:
return ""
by_type = self.group_by_type(contexts)
parts: list[str] = []
for ct in self.get_type_order():
if ct in by_type:
formatted = self.format_type(by_type[ct], ct, **kwargs)
if formatted:
parts.append(formatted)
return self.get_separator().join(parts)
def format_type(
self,
contexts: list[BaseContext],
context_type: ContextType,
**kwargs: Any,
) -> str:
"""
Format contexts of a specific type for OpenAI.
Args:
contexts: List of contexts of the same type
context_type: The type of contexts
**kwargs: Additional formatting options
Returns:
Markdown-formatted string for this context type
"""
if not contexts:
return ""
if context_type == ContextType.SYSTEM:
return self._format_system(contexts)
elif context_type == ContextType.TASK:
return self._format_task(contexts)
elif context_type == ContextType.KNOWLEDGE:
return self._format_knowledge(contexts)
elif context_type == ContextType.CONVERSATION:
return self._format_conversation(contexts)
elif context_type == ContextType.TOOL:
return self._format_tool(contexts)
return "\n".join(c.content for c in contexts)
def _format_system(self, contexts: list[BaseContext]) -> str:
"""Format system contexts."""
content = "\n\n".join(c.content for c in contexts)
return content
def _format_task(self, contexts: list[BaseContext]) -> str:
"""Format task contexts."""
content = "\n\n".join(c.content for c in contexts)
return f"## Current Task\n\n{content}"
def _format_knowledge(self, contexts: list[BaseContext]) -> str:
"""
Format knowledge contexts as structured documents.
Each knowledge context becomes a section with source attribution.
"""
parts = ["## Reference Documents\n"]
for ctx in contexts:
source = ctx.source
score = ctx.metadata.get("score", ctx.metadata.get("relevance_score", ""))
if score:
parts.append(f"### Source: {source} (relevance: {score})\n")
else:
parts.append(f"### Source: {source}\n")
parts.append(ctx.content)
parts.append("")
return "\n".join(parts)
def _format_conversation(self, contexts: list[BaseContext]) -> str:
"""
Format conversation contexts as message history.
Uses bold role labels for clear turn delineation.
"""
parts = []
for ctx in contexts:
role = ctx.metadata.get("role", "user").upper()
parts.append(f"**{role}**: {ctx.content}")
return "\n\n".join(parts)
def _format_tool(self, contexts: list[BaseContext]) -> str:
"""
Format tool contexts as tool results.
Each tool result is in a code block with the tool name.
"""
parts = ["## Recent Tool Results\n"]
for ctx in contexts:
tool_name = ctx.metadata.get("tool_name", "unknown")
status = ctx.metadata.get("status", "")
if status:
parts.append(f"### Tool: {tool_name} ({status})\n")
else:
parts.append(f"### Tool: {tool_name}\n")
parts.append(f"```\n{ctx.content}\n```")
parts.append("")
return "\n".join(parts)

View File

@@ -3,3 +3,10 @@ Context Assembly Module.
Provides the assembly pipeline and formatting.
"""
from .pipeline import ContextPipeline, PipelineMetrics
__all__ = [
"ContextPipeline",
"PipelineMetrics",
]

View File

@@ -0,0 +1,420 @@
"""
Context Assembly Pipeline.
Orchestrates the full context assembly workflow:
Gather → Count → Score → Rank → Compress → Format
"""
import asyncio
import logging
import time
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any
from ..budget import BudgetAllocator, TokenBudget, TokenCalculator
from ..compression.truncation import ContextCompressor
from ..config import ContextSettings, get_context_settings
from ..exceptions import AssemblyTimeoutError
from ..prioritization import ContextRanker
from ..scoring import CompositeScorer
from ..types import AssembledContext, BaseContext, ContextType
if TYPE_CHECKING:
from app.services.mcp.client_manager import MCPClientManager
logger = logging.getLogger(__name__)
@dataclass
class PipelineMetrics:
"""Metrics from pipeline execution."""
start_time: datetime = field(default_factory=lambda: datetime.now(UTC))
end_time: datetime | None = None
total_contexts: int = 0
selected_contexts: int = 0
excluded_contexts: int = 0
compressed_contexts: int = 0
total_tokens: int = 0
assembly_time_ms: float = 0.0
scoring_time_ms: float = 0.0
compression_time_ms: float = 0.0
formatting_time_ms: float = 0.0
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"start_time": self.start_time.isoformat(),
"end_time": self.end_time.isoformat() if self.end_time else None,
"total_contexts": self.total_contexts,
"selected_contexts": self.selected_contexts,
"excluded_contexts": self.excluded_contexts,
"compressed_contexts": self.compressed_contexts,
"total_tokens": self.total_tokens,
"assembly_time_ms": round(self.assembly_time_ms, 2),
"scoring_time_ms": round(self.scoring_time_ms, 2),
"compression_time_ms": round(self.compression_time_ms, 2),
"formatting_time_ms": round(self.formatting_time_ms, 2),
}
class ContextPipeline:
"""
Context assembly pipeline.
Orchestrates the full workflow of context assembly:
1. Validate and count tokens for all contexts
2. Score contexts based on relevance, recency, and priority
3. Rank and select contexts within budget
4. Compress if needed to fit remaining budget
5. Format for the target model
"""
def __init__(
self,
mcp_manager: "MCPClientManager | None" = None,
settings: ContextSettings | None = None,
calculator: TokenCalculator | None = None,
scorer: CompositeScorer | None = None,
ranker: ContextRanker | None = None,
compressor: ContextCompressor | None = None,
) -> None:
"""
Initialize the context pipeline.
Args:
mcp_manager: MCP client manager for LLM Gateway integration
settings: Context settings
calculator: Token calculator
scorer: Context scorer
ranker: Context ranker
compressor: Context compressor
"""
self._settings = settings or get_context_settings()
self._mcp = mcp_manager
# Initialize components
self._calculator = calculator or TokenCalculator(mcp_manager=mcp_manager)
self._scorer = scorer or CompositeScorer(
mcp_manager=mcp_manager, settings=self._settings
)
self._ranker = ranker or ContextRanker(
scorer=self._scorer, calculator=self._calculator
)
self._compressor = compressor or ContextCompressor(calculator=self._calculator)
self._allocator = BudgetAllocator(self._settings)
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
"""Set MCP manager for all components."""
self._mcp = mcp_manager
self._calculator.set_mcp_manager(mcp_manager)
self._scorer.set_mcp_manager(mcp_manager)
async def assemble(
self,
contexts: list[BaseContext],
query: str,
model: str,
max_tokens: int | None = None,
custom_budget: TokenBudget | None = None,
compress: bool = True,
format_output: bool = True,
timeout_ms: int | None = None,
) -> AssembledContext:
"""
Assemble context for an LLM request.
This is the main entry point for context assembly.
Args:
contexts: List of contexts to assemble
query: Query to optimize for
model: Target model name
max_tokens: Maximum total tokens (uses model default if None)
custom_budget: Optional pre-configured budget
compress: Whether to compress oversized contexts
format_output: Whether to format the final output
timeout_ms: Maximum assembly time in milliseconds
Returns:
AssembledContext with optimized content
Raises:
AssemblyTimeoutError: If assembly exceeds timeout
"""
timeout = timeout_ms or self._settings.max_assembly_time_ms
start = time.perf_counter()
metrics = PipelineMetrics(total_contexts=len(contexts))
try:
# Create or use budget
if custom_budget:
budget = custom_budget
elif max_tokens:
budget = self._allocator.create_budget(max_tokens)
else:
budget = self._allocator.create_budget_for_model(model)
# 1. Count tokens for all contexts
await self._ensure_token_counts(contexts, model)
# Check timeout
self._check_timeout(start, timeout, "token counting")
# 2. Score and rank contexts
scoring_start = time.perf_counter()
ranking_result = await self._ranker.rank(
contexts=contexts,
query=query,
budget=budget,
model=model,
)
metrics.scoring_time_ms = (time.perf_counter() - scoring_start) * 1000
selected_contexts = ranking_result.selected_contexts
metrics.selected_contexts = len(selected_contexts)
metrics.excluded_contexts = len(ranking_result.excluded)
# Check timeout
self._check_timeout(start, timeout, "scoring")
# 3. Compress if needed and enabled
if compress and self._needs_compression(selected_contexts, budget):
compression_start = time.perf_counter()
selected_contexts = await self._compressor.compress_contexts(
selected_contexts, budget, model
)
metrics.compression_time_ms = (
time.perf_counter() - compression_start
) * 1000
metrics.compressed_contexts = sum(
1 for c in selected_contexts if c.metadata.get("truncated", False)
)
# Check timeout
self._check_timeout(start, timeout, "compression")
# 4. Format output
formatting_start = time.perf_counter()
if format_output:
formatted_content = self._format_contexts(selected_contexts, model)
else:
formatted_content = "\n\n".join(c.content for c in selected_contexts)
metrics.formatting_time_ms = (time.perf_counter() - formatting_start) * 1000
# Calculate final metrics
total_tokens = sum(c.token_count or 0 for c in selected_contexts)
metrics.total_tokens = total_tokens
metrics.assembly_time_ms = (time.perf_counter() - start) * 1000
metrics.end_time = datetime.now(UTC)
return AssembledContext(
content=formatted_content,
total_tokens=total_tokens,
context_count=len(selected_contexts),
assembly_time_ms=metrics.assembly_time_ms,
model=model,
contexts=selected_contexts,
excluded_count=metrics.excluded_contexts,
metadata={
"metrics": metrics.to_dict(),
"query": query,
"budget": budget.to_dict(),
},
)
except AssemblyTimeoutError:
raise
except Exception as e:
logger.error(f"Context assembly failed: {e}", exc_info=True)
raise
async def _ensure_token_counts(
self,
contexts: list[BaseContext],
model: str | None = None,
) -> None:
"""Ensure all contexts have token counts."""
tasks = []
for context in contexts:
if context.token_count is None:
tasks.append(self._count_and_set(context, model))
if tasks:
await asyncio.gather(*tasks)
async def _count_and_set(
self,
context: BaseContext,
model: str | None = None,
) -> None:
"""Count tokens and set on context."""
count = await self._calculator.count_tokens(context.content, model)
context.token_count = count
def _needs_compression(
self,
contexts: list[BaseContext],
budget: TokenBudget,
) -> bool:
"""Check if any contexts exceed their type budget."""
# Group by type and check totals
by_type: dict[ContextType, int] = {}
for context in contexts:
ct = context.get_type()
by_type[ct] = by_type.get(ct, 0) + (context.token_count or 0)
for ct, total in by_type.items():
if total > budget.get_allocation(ct):
return True
# Also check if utilization exceeds threshold
return budget.utilization() > self._settings.compression_threshold
def _format_contexts(
self,
contexts: list[BaseContext],
model: str,
) -> str:
"""
Format contexts for the target model.
Groups contexts by type and applies model-specific formatting.
"""
# Group by type
by_type: dict[ContextType, list[BaseContext]] = {}
for context in contexts:
ct = context.get_type()
if ct not in by_type:
by_type[ct] = []
by_type[ct].append(context)
# Order types: System -> Task -> Knowledge -> Conversation -> Tool
type_order = [
ContextType.SYSTEM,
ContextType.TASK,
ContextType.KNOWLEDGE,
ContextType.CONVERSATION,
ContextType.TOOL,
]
parts: list[str] = []
for ct in type_order:
if ct in by_type:
formatted = self._format_type(by_type[ct], ct, model)
if formatted:
parts.append(formatted)
return "\n\n".join(parts)
def _format_type(
self,
contexts: list[BaseContext],
context_type: ContextType,
model: str,
) -> str:
"""Format contexts of a specific type."""
if not contexts:
return ""
# Check if model prefers XML tags (Claude)
use_xml = "claude" in model.lower()
if context_type == ContextType.SYSTEM:
return self._format_system(contexts, use_xml)
elif context_type == ContextType.TASK:
return self._format_task(contexts, use_xml)
elif context_type == ContextType.KNOWLEDGE:
return self._format_knowledge(contexts, use_xml)
elif context_type == ContextType.CONVERSATION:
return self._format_conversation(contexts, use_xml)
elif context_type == ContextType.TOOL:
return self._format_tool(contexts, use_xml)
return "\n".join(c.content for c in contexts)
def _format_system(self, contexts: list[BaseContext], use_xml: bool) -> str:
"""Format system contexts."""
content = "\n\n".join(c.content for c in contexts)
if use_xml:
return f"<system_instructions>\n{content}\n</system_instructions>"
return content
def _format_task(self, contexts: list[BaseContext], use_xml: bool) -> str:
"""Format task contexts."""
content = "\n\n".join(c.content for c in contexts)
if use_xml:
return f"<current_task>\n{content}\n</current_task>"
return f"## Current Task\n\n{content}"
def _format_knowledge(self, contexts: list[BaseContext], use_xml: bool) -> str:
"""Format knowledge contexts."""
if use_xml:
parts = ["<reference_documents>"]
for ctx in contexts:
parts.append(f'<document source="{ctx.source}">')
parts.append(ctx.content)
parts.append("</document>")
parts.append("</reference_documents>")
return "\n".join(parts)
else:
parts = ["## Reference Documents\n"]
for ctx in contexts:
parts.append(f"### Source: {ctx.source}\n")
parts.append(ctx.content)
parts.append("")
return "\n".join(parts)
def _format_conversation(self, contexts: list[BaseContext], use_xml: bool) -> str:
"""Format conversation contexts."""
if use_xml:
parts = ["<conversation_history>"]
for ctx in contexts:
role = ctx.metadata.get("role", "user")
parts.append(f'<message role="{role}">')
parts.append(ctx.content)
parts.append("</message>")
parts.append("</conversation_history>")
return "\n".join(parts)
else:
parts = []
for ctx in contexts:
role = ctx.metadata.get("role", "user")
parts.append(f"**{role.upper()}**: {ctx.content}")
return "\n\n".join(parts)
def _format_tool(self, contexts: list[BaseContext], use_xml: bool) -> str:
"""Format tool contexts."""
if use_xml:
parts = ["<tool_results>"]
for ctx in contexts:
tool_name = ctx.metadata.get("tool_name", "unknown")
parts.append(f'<tool_result name="{tool_name}">')
parts.append(ctx.content)
parts.append("</tool_result>")
parts.append("</tool_results>")
return "\n".join(parts)
else:
parts = ["## Recent Tool Results\n"]
for ctx in contexts:
tool_name = ctx.metadata.get("tool_name", "unknown")
parts.append(f"### Tool: {tool_name}\n")
parts.append(f"```\n{ctx.content}\n```")
parts.append("")
return "\n".join(parts)
def _check_timeout(
self,
start: float,
timeout_ms: int,
phase: str,
) -> None:
"""Check if timeout exceeded and raise if so."""
elapsed_ms = (time.perf_counter() - start) * 1000
if elapsed_ms > timeout_ms:
raise AssemblyTimeoutError(
message=f"Context assembly timed out during {phase}",
elapsed_ms=elapsed_ms,
timeout_ms=timeout_ms,
)
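End to end, the pipeline is used roughly as below; a minimal sketch assuming the package import path, with no MCP manager wired in, so token counting falls back to the calculator's character-based estimation:

```python
# Sketch only: import path assumed; contexts are built elsewhere in the application.
import asyncio

from app.services.context.assembly import ContextPipeline


async def build_prompt(contexts, query: str) -> str:
    pipeline = ContextPipeline()  # default settings, scorer, ranker, compressor
    assembled = await pipeline.assemble(
        contexts=contexts,
        query=query,
        model="claude-3-5-sonnet",
        compress=True,
        format_output=True,
    )
    print(assembled.metadata["metrics"])  # per-phase timings from PipelineMetrics
    return assembled.content


# asyncio.run(build_prompt(my_contexts, "Summarise the authentication flow"))
```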

View File

@@ -3,3 +3,12 @@ Token Budget Management Module.
Provides token counting and budget allocation.
"""
from .allocator import BudgetAllocator, TokenBudget
from .calculator import TokenCalculator
__all__ = [
"BudgetAllocator",
"TokenBudget",
"TokenCalculator",
]

View File

@@ -0,0 +1,429 @@
"""
Token Budget Allocator for Context Management.
Manages token budget allocation across context types.
"""
from dataclasses import dataclass, field
from typing import Any
from ..config import ContextSettings, get_context_settings
from ..exceptions import BudgetExceededError
from ..types import ContextType
@dataclass
class TokenBudget:
"""
Token budget allocation and tracking.
Tracks allocated tokens per context type and
monitors usage to prevent overflows.
"""
# Total budget
total: int
# Allocated per type
system: int = 0
task: int = 0
knowledge: int = 0
conversation: int = 0
tools: int = 0
response_reserve: int = 0
buffer: int = 0
# Usage tracking
used: dict[str, int] = field(default_factory=dict)
def __post_init__(self) -> None:
"""Initialize usage tracking."""
if not self.used:
self.used = {ct.value: 0 for ct in ContextType}
def get_allocation(self, context_type: ContextType | str) -> int:
"""
Get allocated tokens for a context type.
Args:
context_type: Context type to get allocation for
Returns:
Allocated token count
"""
if isinstance(context_type, ContextType):
context_type = context_type.value
allocation_map = {
"system": self.system,
"task": self.task,
"knowledge": self.knowledge,
"conversation": self.conversation,
"tool": self.tools,
}
return allocation_map.get(context_type, 0)
def get_used(self, context_type: ContextType | str) -> int:
"""
Get used tokens for a context type.
Args:
context_type: Context type to check
Returns:
Used token count
"""
if isinstance(context_type, ContextType):
context_type = context_type.value
return self.used.get(context_type, 0)
def remaining(self, context_type: ContextType | str) -> int:
"""
Get remaining tokens for a context type.
Args:
context_type: Context type to check
Returns:
Remaining token count
"""
allocated = self.get_allocation(context_type)
used = self.get_used(context_type)
return max(0, allocated - used)
def total_remaining(self) -> int:
"""
Get total remaining tokens across all types.
Returns:
Total remaining tokens
"""
total_used = sum(self.used.values())
usable = self.total - self.response_reserve - self.buffer
return max(0, usable - total_used)
def total_used(self) -> int:
"""
Get total used tokens.
Returns:
Total used tokens
"""
return sum(self.used.values())
def can_fit(self, context_type: ContextType | str, tokens: int) -> bool:
"""
Check if tokens fit within budget for a type.
Args:
context_type: Context type to check
tokens: Number of tokens to fit
Returns:
True if tokens fit within remaining budget
"""
return tokens <= self.remaining(context_type)
def allocate(
self,
context_type: ContextType | str,
tokens: int,
force: bool = False,
) -> bool:
"""
Allocate (use) tokens from a context type's budget.
Args:
context_type: Context type to allocate from
tokens: Number of tokens to allocate
force: If True, allow exceeding budget
Returns:
True if allocation succeeded
Raises:
BudgetExceededError: If tokens exceed budget and force=False
"""
if isinstance(context_type, ContextType):
context_type = context_type.value
if not force and not self.can_fit(context_type, tokens):
raise BudgetExceededError(
message=f"Token budget exceeded for {context_type}",
allocated=self.get_allocation(context_type),
requested=self.get_used(context_type) + tokens,
context_type=context_type,
)
self.used[context_type] = self.used.get(context_type, 0) + tokens
return True
def deallocate(
self,
context_type: ContextType | str,
tokens: int,
) -> None:
"""
Deallocate (return) tokens to a context type's budget.
Args:
context_type: Context type to return to
tokens: Number of tokens to return
"""
if isinstance(context_type, ContextType):
context_type = context_type.value
current = self.used.get(context_type, 0)
self.used[context_type] = max(0, current - tokens)
def reset(self) -> None:
"""Reset all usage tracking."""
self.used = {ct.value: 0 for ct in ContextType}
def utilization(self, context_type: ContextType | str | None = None) -> float:
"""
Get budget utilization percentage.
Args:
context_type: Specific type or None for total
Returns:
Utilization as a fraction (0.0 to 1.0+)
"""
if context_type is None:
usable = self.total - self.response_reserve - self.buffer
if usable <= 0:
return 0.0
return self.total_used() / usable
allocated = self.get_allocation(context_type)
if allocated <= 0:
return 0.0
return self.get_used(context_type) / allocated
def to_dict(self) -> dict[str, Any]:
"""Convert budget to dictionary."""
return {
"total": self.total,
"allocations": {
"system": self.system,
"task": self.task,
"knowledge": self.knowledge,
"conversation": self.conversation,
"tools": self.tools,
"response_reserve": self.response_reserve,
"buffer": self.buffer,
},
"used": dict(self.used),
"remaining": {ct.value: self.remaining(ct) for ct in ContextType},
"total_used": self.total_used(),
"total_remaining": self.total_remaining(),
"utilization": round(self.utilization(), 3),
}
class BudgetAllocator:
"""
Budget allocator for context management.
Creates token budgets based on configuration and
model context window sizes.
"""
def __init__(self, settings: ContextSettings | None = None) -> None:
"""
Initialize budget allocator.
Args:
settings: Context settings (uses default if None)
"""
self._settings = settings or get_context_settings()
def create_budget(
self,
total_tokens: int,
custom_allocations: dict[str, float] | None = None,
) -> TokenBudget:
"""
Create a token budget with allocations.
Args:
total_tokens: Total available tokens
custom_allocations: Optional custom allocation percentages
Returns:
TokenBudget with allocations set
"""
# Use custom or default allocations
if custom_allocations:
alloc = custom_allocations
else:
alloc = self._settings.get_budget_allocation()
return TokenBudget(
total=total_tokens,
system=int(total_tokens * alloc.get("system", 0.05)),
task=int(total_tokens * alloc.get("task", 0.10)),
knowledge=int(total_tokens * alloc.get("knowledge", 0.40)),
conversation=int(total_tokens * alloc.get("conversation", 0.20)),
tools=int(total_tokens * alloc.get("tools", 0.05)),
response_reserve=int(total_tokens * alloc.get("response", 0.15)),
buffer=int(total_tokens * alloc.get("buffer", 0.05)),
)
def adjust_budget(
self,
budget: TokenBudget,
context_type: ContextType | str,
adjustment: int,
) -> TokenBudget:
"""
Adjust a specific allocation in a budget.
Takes tokens from buffer and adds to specified type.
Args:
budget: Budget to adjust
context_type: Type to adjust
adjustment: Positive to increase, negative to decrease
Returns:
Adjusted budget
"""
if isinstance(context_type, ContextType):
context_type = context_type.value
# Calculate adjustment (limited by buffer)
if adjustment > 0:
# Taking from buffer
actual_adjustment = min(adjustment, budget.buffer)
budget.buffer -= actual_adjustment
else:
# Returning to buffer
actual_adjustment = adjustment
# Apply to target type
if context_type == "system":
budget.system = max(0, budget.system + actual_adjustment)
elif context_type == "task":
budget.task = max(0, budget.task + actual_adjustment)
elif context_type == "knowledge":
budget.knowledge = max(0, budget.knowledge + actual_adjustment)
elif context_type == "conversation":
budget.conversation = max(0, budget.conversation + actual_adjustment)
elif context_type == "tool":
budget.tools = max(0, budget.tools + actual_adjustment)
return budget
def rebalance_budget(
self,
budget: TokenBudget,
prioritize: list[ContextType] | None = None,
) -> TokenBudget:
"""
Rebalance budget based on actual usage.
Moves unused allocations to prioritized types.
Args:
budget: Budget to rebalance
prioritize: Types to prioritize (in order)
Returns:
Rebalanced budget
"""
if prioritize is None:
prioritize = [ContextType.KNOWLEDGE, ContextType.TASK, ContextType.SYSTEM]
# Calculate unused tokens per type
unused: dict[str, int] = {}
for ct in ContextType:
remaining = budget.remaining(ct)
if remaining > 0:
unused[ct.value] = remaining
# Calculate total reclaimable (excluding prioritized types)
prioritize_values = {ct.value for ct in prioritize}
reclaimable = sum(
tokens for ct, tokens in unused.items() if ct not in prioritize_values
)
# Redistribute to prioritized types that are near capacity
for ct in prioritize:
utilization = budget.utilization(ct)
if utilization > 0.8: # Near capacity
# Give more tokens from reclaimable pool
bonus = min(reclaimable, budget.get_allocation(ct) // 2)
self.adjust_budget(budget, ct, bonus)
reclaimable -= bonus
if reclaimable <= 0:
break
return budget
def get_model_context_size(self, model: str) -> int:
"""
Get context window size for a model.
Args:
model: Model name
Returns:
Context window size in tokens
"""
# Common model context sizes
context_sizes = {
"claude-3-opus": 200000,
"claude-3-sonnet": 200000,
"claude-3-haiku": 200000,
"claude-3-5-sonnet": 200000,
"claude-3-5-haiku": 200000,
"claude-opus-4": 200000,
"gpt-4-turbo": 128000,
"gpt-4": 8192,
"gpt-4-32k": 32768,
"gpt-4o": 128000,
"gpt-4o-mini": 128000,
"gpt-3.5-turbo": 16385,
"gemini-1.5-pro": 2000000,
"gemini-1.5-flash": 1000000,
"gemini-2.0-flash": 1000000,
"qwen-plus": 32000,
"qwen-turbo": 8000,
"deepseek-chat": 64000,
"deepseek-reasoner": 64000,
}
# Check exact match first
model_lower = model.lower()
if model_lower in context_sizes:
return context_sizes[model_lower]
# Check prefix match
for model_name, size in context_sizes.items():
if model_lower.startswith(model_name):
return size
# Default fallback
return 8192
def create_budget_for_model(
self,
model: str,
custom_allocations: dict[str, float] | None = None,
) -> TokenBudget:
"""
Create a budget based on model's context window.
Args:
model: Model name
custom_allocations: Optional custom allocation percentages
Returns:
TokenBudget sized for the model
"""
context_size = self.get_model_context_size(model)
return self.create_budget(context_size, custom_allocations)
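The allocation arithmetic is easiest to see with explicit percentages. A worked example assuming the import paths and that `ContextType.KNOWLEDGE.value` is `"knowledge"` (consistent with how usage is tracked above):

```python
# Sketch only: import paths assumed.
from app.services.context.budget import BudgetAllocator
from app.services.context.types import ContextType

allocations = {
    "system": 0.05, "task": 0.10, "knowledge": 0.40, "conversation": 0.20,
    "tools": 0.05, "response": 0.15, "buffer": 0.05,
}
allocator = BudgetAllocator()
budget = allocator.create_budget_for_model("gpt-4o", custom_allocations=allocations)

print(budget.total)                             # 128000 (gpt-4o context window)
print(budget.knowledge)                         # int(128000 * 0.40) = 51200
budget.allocate(ContextType.KNOWLEDGE, 40_000)
print(budget.remaining(ContextType.KNOWLEDGE))  # 51200 - 40000 = 11200
print(round(budget.utilization(), 3))           # 40000 / (128000 - 19200 - 6400) = 0.391
```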

View File

@@ -0,0 +1,285 @@
"""
Token Calculator for Context Management.
Provides token counting with caching and fallback estimation.
Integrates with LLM Gateway for accurate counts.
"""
import hashlib
import logging
from typing import TYPE_CHECKING, Any, ClassVar, Protocol
if TYPE_CHECKING:
from app.services.mcp.client_manager import MCPClientManager
logger = logging.getLogger(__name__)
class TokenCounterProtocol(Protocol):
"""Protocol for token counting implementations."""
async def count_tokens(
self,
text: str,
model: str | None = None,
) -> int:
"""Count tokens in text."""
...
class TokenCalculator:
"""
Token calculator with LLM Gateway integration.
Features:
- In-memory caching for repeated text
- Fallback to character-based estimation
- Model-specific counting when possible
The calculator uses the LLM Gateway's count_tokens tool
for accurate counting, with a local cache to avoid
repeated calls for the same content.
"""
# Default characters per token ratio for estimation
DEFAULT_CHARS_PER_TOKEN: ClassVar[float] = 4.0
# Model-specific ratios (more accurate estimation)
MODEL_CHAR_RATIOS: ClassVar[dict[str, float]] = {
"claude": 3.5,
"gpt-4": 4.0,
"gpt-3.5": 4.0,
"gemini": 4.0,
}
def __init__(
self,
mcp_manager: "MCPClientManager | None" = None,
project_id: str = "system",
agent_id: str = "context-engine",
cache_enabled: bool = True,
cache_max_size: int = 10000,
) -> None:
"""
Initialize token calculator.
Args:
mcp_manager: MCP client manager for LLM Gateway calls
project_id: Project ID for LLM Gateway calls
agent_id: Agent ID for LLM Gateway calls
cache_enabled: Whether to enable in-memory caching
cache_max_size: Maximum cache entries
"""
self._mcp = mcp_manager
self._project_id = project_id
self._agent_id = agent_id
self._cache_enabled = cache_enabled
self._cache_max_size = cache_max_size
# In-memory cache: hash(model:text) -> token_count
self._cache: dict[str, int] = {}
self._cache_hits = 0
self._cache_misses = 0
def _get_cache_key(self, text: str, model: str | None) -> str:
"""Generate cache key from text and model."""
# Use hash for efficient storage
content = f"{model or 'default'}:{text}"
return hashlib.sha256(content.encode()).hexdigest()[:32]
def _check_cache(self, cache_key: str) -> int | None:
"""Check cache for existing count."""
if not self._cache_enabled:
return None
if cache_key in self._cache:
self._cache_hits += 1
return self._cache[cache_key]
self._cache_misses += 1
return None
def _store_cache(self, cache_key: str, count: int) -> None:
"""Store count in cache."""
if not self._cache_enabled:
return
# Simple LRU-like eviction: remove oldest entries when full
if len(self._cache) >= self._cache_max_size:
# Remove first 10% of entries
entries_to_remove = self._cache_max_size // 10
keys_to_remove = list(self._cache.keys())[:entries_to_remove]
for key in keys_to_remove:
del self._cache[key]
self._cache[cache_key] = count
def estimate_tokens(self, text: str, model: str | None = None) -> int:
"""
Estimate token count based on character count.
This is a fast fallback when LLM Gateway is unavailable.
Args:
text: Text to count
model: Optional model for more accurate ratio
Returns:
Estimated token count
"""
if not text:
return 0
# Get model-specific ratio
ratio = self.DEFAULT_CHARS_PER_TOKEN
if model:
model_lower = model.lower()
for model_prefix, model_ratio in self.MODEL_CHAR_RATIOS.items():
if model_prefix in model_lower:
ratio = model_ratio
break
return max(1, int(len(text) / ratio))
async def count_tokens(
self,
text: str,
model: str | None = None,
) -> int:
"""
Count tokens in text.
Uses LLM Gateway for accurate counts with fallback to estimation.
Args:
text: Text to count
model: Optional model for accurate counting
Returns:
Token count
"""
if not text:
return 0
# Check cache first
cache_key = self._get_cache_key(text, model)
cached = self._check_cache(cache_key)
if cached is not None:
return cached
# Try LLM Gateway
if self._mcp is not None:
try:
result = await self._mcp.call_tool(
server="llm-gateway",
tool="count_tokens",
args={
"project_id": self._project_id,
"agent_id": self._agent_id,
"text": text,
"model": model,
},
)
# Parse result
if result.success and result.data:
count = self._parse_token_count(result.data)
if count is not None:
self._store_cache(cache_key, count)
return count
except Exception as e:
logger.warning(f"LLM Gateway token count failed, using estimation: {e}")
# Fallback to estimation
count = self.estimate_tokens(text, model)
self._store_cache(cache_key, count)
return count
def _parse_token_count(self, data: Any) -> int | None:
"""Parse token count from LLM Gateway response."""
if isinstance(data, dict):
if "token_count" in data:
return int(data["token_count"])
if "tokens" in data:
return int(data["tokens"])
if "count" in data:
return int(data["count"])
if isinstance(data, int):
return data
if isinstance(data, str):
# Try to parse from text content
try:
# Handle {"token_count": 123} or just "123"
import json
parsed = json.loads(data)
if isinstance(parsed, dict) and "token_count" in parsed:
return int(parsed["token_count"])
if isinstance(parsed, int):
return parsed
except (json.JSONDecodeError, ValueError):
# Try direct int conversion
try:
return int(data)
except ValueError:
pass
return None
async def count_tokens_batch(
self,
texts: list[str],
model: str | None = None,
) -> list[int]:
"""
Count tokens for multiple texts.
Efficient batch counting with caching and parallel execution.
Args:
texts: List of texts to count
model: Optional model for accurate counting
Returns:
List of token counts (same order as input)
"""
import asyncio
if not texts:
return []
# Execute all token counts in parallel for better performance
tasks = [self.count_tokens(text, model) for text in texts]
return await asyncio.gather(*tasks)
def clear_cache(self) -> None:
"""Clear the token count cache."""
self._cache.clear()
self._cache_hits = 0
self._cache_misses = 0
def get_cache_stats(self) -> dict[str, Any]:
"""Get cache statistics."""
total = self._cache_hits + self._cache_misses
hit_rate = self._cache_hits / total if total > 0 else 0.0
return {
"enabled": self._cache_enabled,
"size": len(self._cache),
"max_size": self._cache_max_size,
"hits": self._cache_hits,
"misses": self._cache_misses,
"hit_rate": round(hit_rate, 3),
}
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
"""
Set the MCP manager (for lazy initialization).
Args:
mcp_manager: MCP client manager instance
"""
self._mcp = mcp_manager
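With no MCP manager configured, `count_tokens()` falls back to the character-ratio estimate, which `estimate_tokens()` exposes synchronously; a minimal sketch (import path assumed):

```python
# Sketch only: import path assumed.
from app.services.context.budget import TokenCalculator

calc = TokenCalculator()  # no MCP manager -> estimation-only fallback
text = "x" * 700

print(calc.estimate_tokens(text))                             # int(700 / 4.0) = 175
print(calc.estimate_tokens(text, model="claude-3-5-sonnet"))  # int(700 / 3.5) = 200
print(calc.get_cache_stats())  # estimate_tokens() alone does not touch the cache
```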

View File

@@ -3,3 +3,9 @@ Context Cache Module.
Provides Redis-based caching for assembled contexts.
"""
from .context_cache import ContextCache
__all__ = [
"ContextCache",
]

View File

@@ -0,0 +1,422 @@
"""
Context Cache Implementation.
Provides Redis-based caching for context operations including
assembled contexts, token counts, and scoring results.
"""
import hashlib
import json
import logging
from typing import TYPE_CHECKING, Any
from ..config import ContextSettings, get_context_settings
from ..exceptions import CacheError
from ..types import AssembledContext, BaseContext
if TYPE_CHECKING:
from redis.asyncio import Redis
logger = logging.getLogger(__name__)
class ContextCache:
"""
Redis-based caching for context operations.
Provides caching for:
- Assembled contexts (fingerprint-based)
- Token counts (content hash-based)
- Scoring results (context + query hash-based)
Cache keys use a hierarchical structure:
- ctx:assembled:{fingerprint}
- ctx:tokens:{model}:{content_hash}
- ctx:score:{scorer}:{context_hash}:{query_hash}
"""
def __init__(
self,
redis: "Redis | None" = None,
settings: ContextSettings | None = None,
) -> None:
"""
Initialize the context cache.
Args:
redis: Redis connection (optional for testing)
settings: Cache settings
"""
self._redis = redis
self._settings = settings or get_context_settings()
self._prefix = self._settings.cache_prefix
self._ttl = self._settings.cache_ttl_seconds
# In-memory fallback cache when Redis unavailable
self._memory_cache: dict[str, tuple[str, float]] = {}
self._max_memory_items = self._settings.cache_memory_max_items
def set_redis(self, redis: "Redis") -> None:
"""Set Redis connection."""
self._redis = redis
@property
def is_enabled(self) -> bool:
"""Check if caching is enabled and available."""
return self._settings.cache_enabled and self._redis is not None
def _cache_key(self, *parts: str) -> str:
"""
Build a cache key from parts.
Args:
*parts: Key components
Returns:
Colon-separated cache key
"""
return f"{self._prefix}:{':'.join(parts)}"
@staticmethod
def _hash_content(content: str) -> str:
"""
Compute hash of content for cache key.
Args:
content: Content to hash
Returns:
32-character hex hash
"""
return hashlib.sha256(content.encode()).hexdigest()[:32]
def compute_fingerprint(
self,
contexts: list[BaseContext],
query: str,
model: str,
) -> str:
"""
Compute a fingerprint for a context assembly request.
The fingerprint is based on:
- Context content hash and metadata (not full content for performance)
- Query string
- Target model
Args:
contexts: List of contexts
query: Query string
model: Model name
Returns:
32-character hex fingerprint
"""
# Build a deterministic representation using content hashes for performance
# This avoids JSON serializing potentially large content strings
context_data = []
for ctx in contexts:
context_data.append(
{
"type": ctx.get_type().value,
"content_hash": self._hash_content(
ctx.content
), # Hash instead of full content
"source": ctx.source,
"priority": ctx.priority, # Already an int
}
)
data = {
"contexts": context_data,
"query": query,
"model": model,
}
content = json.dumps(data, sort_keys=True)
return self._hash_content(content)
async def get_assembled(
self,
fingerprint: str,
) -> AssembledContext | None:
"""
Get cached assembled context by fingerprint.
Args:
fingerprint: Assembly fingerprint
Returns:
Cached AssembledContext or None if not found
"""
if not self.is_enabled:
return None
key = self._cache_key("assembled", fingerprint)
try:
data = await self._redis.get(key) # type: ignore
if data:
logger.debug(f"Cache hit for assembled context: {fingerprint}")
result = AssembledContext.from_json(data)
result.cache_hit = True
result.cache_key = fingerprint
return result
except Exception as e:
logger.warning(f"Cache get error: {e}")
raise CacheError(f"Failed to get assembled context: {e}") from e
return None
async def set_assembled(
self,
fingerprint: str,
context: AssembledContext,
ttl: int | None = None,
) -> None:
"""
Cache an assembled context.
Args:
fingerprint: Assembly fingerprint
context: Assembled context to cache
ttl: Optional TTL override in seconds
"""
if not self.is_enabled:
return
key = self._cache_key("assembled", fingerprint)
expire = ttl or self._ttl
try:
await self._redis.setex(key, expire, context.to_json()) # type: ignore
logger.debug(f"Cached assembled context: {fingerprint}")
except Exception as e:
logger.warning(f"Cache set error: {e}")
raise CacheError(f"Failed to cache assembled context: {e}") from e
async def get_token_count(
self,
content: str,
model: str | None = None,
) -> int | None:
"""
Get cached token count.
Args:
content: Content to look up
model: Model name for model-specific tokenization
Returns:
Cached token count or None if not found
"""
model_key = model or "default"
content_hash = self._hash_content(content)
key = self._cache_key("tokens", model_key, content_hash)
# Try in-memory first
if key in self._memory_cache:
return int(self._memory_cache[key][0])
if not self.is_enabled:
return None
try:
data = await self._redis.get(key) # type: ignore
if data:
count = int(data)
# Store in memory for faster subsequent access
self._set_memory(key, str(count))
return count
except Exception as e:
logger.warning(f"Cache get error for tokens: {e}")
return None
async def set_token_count(
self,
content: str,
count: int,
model: str | None = None,
ttl: int | None = None,
) -> None:
"""
Cache a token count.
Args:
content: Content that was counted
count: Token count
model: Model name
ttl: Optional TTL override in seconds
"""
model_key = model or "default"
content_hash = self._hash_content(content)
key = self._cache_key("tokens", model_key, content_hash)
expire = ttl or self._ttl
# Always store in memory
self._set_memory(key, str(count))
if not self.is_enabled:
return
try:
await self._redis.setex(key, expire, str(count)) # type: ignore
except Exception as e:
logger.warning(f"Cache set error for tokens: {e}")
async def get_score(
self,
scorer_name: str,
context_id: str,
query: str,
) -> float | None:
"""
Get cached score.
Args:
scorer_name: Name of the scorer
context_id: Context identifier
query: Query string
Returns:
Cached score or None if not found
"""
query_hash = self._hash_content(query)[:16]
key = self._cache_key("score", scorer_name, context_id, query_hash)
# Try in-memory first
if key in self._memory_cache:
return float(self._memory_cache[key][0])
if not self.is_enabled:
return None
try:
data = await self._redis.get(key) # type: ignore
if data:
score = float(data)
self._set_memory(key, str(score))
return score
except Exception as e:
logger.warning(f"Cache get error for score: {e}")
return None
async def set_score(
self,
scorer_name: str,
context_id: str,
query: str,
score: float,
ttl: int | None = None,
) -> None:
"""
Cache a score.
Args:
scorer_name: Name of the scorer
context_id: Context identifier
query: Query string
score: Score value
ttl: Optional TTL override in seconds
"""
query_hash = self._hash_content(query)[:16]
key = self._cache_key("score", scorer_name, context_id, query_hash)
expire = ttl or self._ttl
# Always store in memory
self._set_memory(key, str(score))
if not self.is_enabled:
return
try:
await self._redis.setex(key, expire, str(score)) # type: ignore
except Exception as e:
logger.warning(f"Cache set error for score: {e}")
async def invalidate(self, pattern: str) -> int:
"""
Invalidate cache entries matching a pattern.
Args:
pattern: Key pattern (supports * wildcard)
Returns:
Number of keys deleted
"""
if not self.is_enabled:
return 0
full_pattern = self._cache_key(pattern)
deleted = 0
try:
async for key in self._redis.scan_iter(match=full_pattern): # type: ignore
await self._redis.delete(key) # type: ignore
deleted += 1
logger.info(f"Invalidated {deleted} cache entries matching {pattern}")
except Exception as e:
logger.warning(f"Cache invalidation error: {e}")
raise CacheError(f"Failed to invalidate cache: {e}") from e
return deleted
async def clear_all(self) -> int:
"""
Clear all context cache entries.
Returns:
Number of keys deleted
"""
self._memory_cache.clear()
return await self.invalidate("*")
def _set_memory(self, key: str, value: str) -> None:
"""
Set a value in the memory cache.
Uses LRU-style eviction when max items reached.
Args:
key: Cache key
value: Value to store
"""
import time
if len(self._memory_cache) >= self._max_memory_items:
# Evict oldest entries
sorted_keys = sorted(
self._memory_cache.keys(),
key=lambda k: self._memory_cache[k][1],
)
for k in sorted_keys[: len(sorted_keys) // 2]:
del self._memory_cache[k]
self._memory_cache[key] = (value, time.time())
async def get_stats(self) -> dict[str, Any]:
"""
Get cache statistics.
Returns:
Dictionary with cache stats
"""
stats = {
"enabled": self._settings.cache_enabled,
"redis_available": self._redis is not None,
"memory_items": len(self._memory_cache),
"ttl_seconds": self._ttl,
}
if self.is_enabled:
try:
# Get Redis info
info = await self._redis.info("memory") # type: ignore
stats["redis_memory_used"] = info.get("used_memory_human", "unknown")
except Exception as e:
logger.debug(f"Failed to get Redis stats: {e}")
return stats
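A minimal sketch of the in-memory fallback path described above, with no Redis connection; the `app.services.context.cache` import path is an assumption, not confirmed by this diff.

import asyncio

from app.services.context.cache import ContextCache


async def demo_token_cache() -> None:
    # With redis=None, is_enabled is False and only the memory fallback is used.
    cache = ContextCache(redis=None)
    content = "def add(a, b): return a + b"
    await cache.set_token_count(content, count=12, model="claude-3-sonnet")
    cached = await cache.get_token_count(content, model="claude-3-sonnet")
    assert cached == 12  # served from the memory cache keyed by model + content hash


asyncio.run(demo_token_cache())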

View File

@@ -3,3 +3,11 @@ Context Compression Module.
Provides truncation and compression strategies.
"""
from .truncation import ContextCompressor, TruncationResult, TruncationStrategy
__all__ = [
"ContextCompressor",
"TruncationResult",
"TruncationStrategy",
]

View File

@@ -0,0 +1,418 @@
"""
Smart Truncation for Context Compression.
Provides intelligent truncation strategies to reduce context size
while preserving the most important information.
"""
import logging
import re
from dataclasses import dataclass
from typing import TYPE_CHECKING
from ..config import ContextSettings, get_context_settings
from ..types import BaseContext, ContextType
if TYPE_CHECKING:
from ..budget import TokenBudget, TokenCalculator
logger = logging.getLogger(__name__)
@dataclass
class TruncationResult:
"""Result of truncation operation."""
original_tokens: int
truncated_tokens: int
content: str
truncated: bool
truncation_ratio: float # 0.0 = no truncation, 1.0 = completely removed
@property
def tokens_saved(self) -> int:
"""Calculate tokens saved by truncation."""
return self.original_tokens - self.truncated_tokens
class TruncationStrategy:
"""
Smart truncation strategies for context compression.
Strategies:
1. End truncation: Cut from end (for knowledge/docs)
2. Middle truncation: Keep start and end (for code)
3. Sentence-aware: Truncate at sentence boundaries
4. Semantic chunking: Keep most relevant chunks
"""
def __init__(
self,
calculator: "TokenCalculator | None" = None,
preserve_ratio_start: float | None = None,
min_content_length: int | None = None,
settings: ContextSettings | None = None,
) -> None:
"""
Initialize truncation strategy.
Args:
calculator: Token calculator for accurate counting
preserve_ratio_start: Ratio of content to keep from start (overrides settings)
min_content_length: Minimum characters to preserve (overrides settings)
settings: Context settings (uses global if None)
"""
self._settings = settings or get_context_settings()
self._calculator = calculator
# Use provided values or fall back to settings
self._preserve_ratio_start = (
preserve_ratio_start
if preserve_ratio_start is not None
else self._settings.truncation_preserve_ratio
)
self._min_content_length = (
min_content_length
if min_content_length is not None
else self._settings.truncation_min_content_length
)
@property
def truncation_marker(self) -> str:
"""Get truncation marker from settings."""
return self._settings.truncation_marker
def set_calculator(self, calculator: "TokenCalculator") -> None:
"""Set token calculator."""
self._calculator = calculator
async def truncate_to_tokens(
self,
content: str,
max_tokens: int,
strategy: str = "end",
model: str | None = None,
) -> TruncationResult:
"""
Truncate content to fit within token limit.
Args:
content: Content to truncate
max_tokens: Maximum tokens allowed
strategy: Truncation strategy ('end', 'middle', 'sentence')
model: Model for token counting
Returns:
TruncationResult with truncated content
"""
if not content:
return TruncationResult(
original_tokens=0,
truncated_tokens=0,
content="",
truncated=False,
truncation_ratio=0.0,
)
# Get original token count
original_tokens = await self._count_tokens(content, model)
if original_tokens <= max_tokens:
return TruncationResult(
original_tokens=original_tokens,
truncated_tokens=original_tokens,
content=content,
truncated=False,
truncation_ratio=0.0,
)
# Apply truncation strategy
if strategy == "middle":
truncated = await self._truncate_middle(content, max_tokens, model)
elif strategy == "sentence":
truncated = await self._truncate_sentence(content, max_tokens, model)
else: # "end"
truncated = await self._truncate_end(content, max_tokens, model)
truncated_tokens = await self._count_tokens(truncated, model)
return TruncationResult(
original_tokens=original_tokens,
truncated_tokens=truncated_tokens,
content=truncated,
truncated=True,
truncation_ratio=0.0
if original_tokens == 0
else 1 - (truncated_tokens / original_tokens),
)
async def _truncate_end(
self,
content: str,
max_tokens: int,
model: str | None = None,
) -> str:
"""
Truncate from end of content.
Simple but effective for most content types.
"""
# Binary search for optimal truncation point
marker_tokens = await self._count_tokens(self.truncation_marker, model)
available_tokens = max(0, max_tokens - marker_tokens)
# Edge case: if no tokens available for content, return just the marker
if available_tokens <= 0:
return self.truncation_marker
# Estimate characters per token (guard against division by zero)
content_tokens = await self._count_tokens(content, model)
if content_tokens == 0:
return content + self.truncation_marker
chars_per_token = len(content) / content_tokens
# Start with estimated position
estimated_chars = int(available_tokens * chars_per_token)
truncated = content[:estimated_chars]
# Refine with binary search
low, high = len(truncated) // 2, len(truncated)
best = truncated
for _ in range(5): # Max 5 iterations
mid = (low + high) // 2
candidate = content[:mid]
tokens = await self._count_tokens(candidate, model)
if tokens <= available_tokens:
best = candidate
low = mid + 1
else:
high = mid - 1
return best + self.truncation_marker
async def _truncate_middle(
self,
content: str,
max_tokens: int,
model: str | None = None,
) -> str:
"""
Truncate from middle, keeping start and end.
Good for code or content where context at boundaries matters.
"""
marker_tokens = await self._count_tokens(self.truncation_marker, model)
available_tokens = max_tokens - marker_tokens
# Split between start and end
start_tokens = int(available_tokens * self._preserve_ratio_start)
end_tokens = available_tokens - start_tokens
# Get start portion
start_content = await self._get_content_for_tokens(
content, start_tokens, from_start=True, model=model
)
# Get end portion
end_content = await self._get_content_for_tokens(
content, end_tokens, from_start=False, model=model
)
return start_content + self.truncation_marker + end_content
async def _truncate_sentence(
self,
content: str,
max_tokens: int,
model: str | None = None,
) -> str:
"""
Truncate at sentence boundaries.
Produces cleaner output by not cutting mid-sentence.
"""
# Split into sentences
sentences = re.split(r"(?<=[.!?])\s+", content)
result: list[str] = []
total_tokens = 0
marker_tokens = await self._count_tokens(self.truncation_marker, model)
available = max_tokens - marker_tokens
for sentence in sentences:
sentence_tokens = await self._count_tokens(sentence, model)
if total_tokens + sentence_tokens <= available:
result.append(sentence)
total_tokens += sentence_tokens
else:
break
if len(result) < len(sentences):
return " ".join(result) + self.truncation_marker
return " ".join(result)
async def _get_content_for_tokens(
self,
content: str,
target_tokens: int,
from_start: bool = True,
model: str | None = None,
) -> str:
"""Get portion of content fitting within token limit."""
if target_tokens <= 0:
return ""
current_tokens = await self._count_tokens(content, model)
if current_tokens <= target_tokens:
return content
# Estimate characters (guard against division by zero)
if current_tokens == 0:
return content
chars_per_token = len(content) / current_tokens
estimated_chars = int(target_tokens * chars_per_token)
if from_start:
return content[:estimated_chars]
else:
return content[-estimated_chars:]
async def _count_tokens(self, text: str, model: str | None = None) -> int:
"""Count tokens using calculator or estimation."""
if self._calculator is not None:
return await self._calculator.count_tokens(text, model)
# Fallback estimation
return max(1, len(text) // 4)
class ContextCompressor:
"""
Compresses contexts to fit within budget constraints.
Uses truncation strategies to reduce context size while
preserving the most important information.
"""
def __init__(
self,
truncation: TruncationStrategy | None = None,
calculator: "TokenCalculator | None" = None,
) -> None:
"""
Initialize context compressor.
Args:
truncation: Truncation strategy to use
calculator: Token calculator for counting
"""
self._truncation = truncation or TruncationStrategy(calculator)
self._calculator = calculator
if calculator:
self._truncation.set_calculator(calculator)
def set_calculator(self, calculator: "TokenCalculator") -> None:
"""Set token calculator."""
self._calculator = calculator
self._truncation.set_calculator(calculator)
async def compress_context(
self,
context: BaseContext,
max_tokens: int,
model: str | None = None,
) -> BaseContext:
"""
Compress a single context to fit token limit.
Args:
context: Context to compress
max_tokens: Maximum tokens allowed
model: Model for token counting
Returns:
Compressed context (may be same object if no compression needed)
"""
current_tokens = context.token_count or await self._count_tokens(
context.content, model
)
if current_tokens <= max_tokens:
return context
# Choose strategy based on context type
strategy = self._get_strategy_for_type(context.get_type())
result = await self._truncation.truncate_to_tokens(
content=context.content,
max_tokens=max_tokens,
strategy=strategy,
model=model,
)
# Update context with truncated content
context.content = result.content
context.token_count = result.truncated_tokens
context.metadata["truncated"] = True
context.metadata["original_tokens"] = result.original_tokens
return context
async def compress_contexts(
self,
contexts: list[BaseContext],
budget: "TokenBudget",
model: str | None = None,
) -> list[BaseContext]:
"""
Compress multiple contexts to fit within budget.
Args:
contexts: Contexts to potentially compress
budget: Token budget constraints
model: Model for token counting
Returns:
List of contexts (compressed as needed)
"""
result: list[BaseContext] = []
for context in contexts:
context_type = context.get_type()
remaining = budget.remaining(context_type)
current_tokens = context.token_count or await self._count_tokens(
context.content, model
)
if current_tokens > remaining:
# Need to compress
compressed = await self.compress_context(context, remaining, model)
result.append(compressed)
logger.debug(
f"Compressed {context_type.value} context from "
f"{current_tokens} to {compressed.token_count} tokens"
)
else:
result.append(context)
return result
def _get_strategy_for_type(self, context_type: ContextType) -> str:
"""Get optimal truncation strategy for context type."""
strategies = {
ContextType.SYSTEM: "end", # Keep instructions at start
ContextType.TASK: "end", # Keep task description start
ContextType.KNOWLEDGE: "sentence", # Clean sentence boundaries
ContextType.CONVERSATION: "end", # Keep recent conversation
ContextType.TOOL: "middle", # Keep command and result summary
}
return strategies.get(context_type, "end")
async def _count_tokens(self, text: str, model: str | None = None) -> int:
"""Count tokens using calculator or estimation."""
if self._calculator is not None:
return await self._calculator.count_tokens(text, model)
return max(1, len(text) // 4)
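A hedged usage sketch of the sentence-aware strategy above, relying on the built-in `len(text) // 4` token estimate since no TokenCalculator is wired in; the import path is an assumption.

import asyncio

from app.services.context.compression import TruncationStrategy


async def demo_sentence_truncation() -> None:
    strategy = TruncationStrategy()  # no calculator: falls back to length-based estimates
    text = "First sentence here. Second sentence follows it. Third sentence gets dropped."
    result = await strategy.truncate_to_tokens(text, max_tokens=12, strategy="sentence")
    # Only the sentences that fit under the budget (minus the marker) are kept.
    print(result.truncated, round(result.truncation_ratio, 2))
    print(result.content)


asyncio.run(demo_sentence_truncation())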

View File

@@ -104,9 +104,21 @@ class ContextSettings(BaseSettings):
le=1.0,
description="Compress when budget usage exceeds this percentage",
)
truncation_marker: str = Field(
default="\n\n[...content truncated...]\n\n",
description="Marker text to insert where content was truncated",
)
truncation_preserve_ratio: float = Field(
default=0.7,
ge=0.1,
le=0.9,
description="Ratio of content to preserve from start in middle truncation (0.7 = 70% start, 30% end)",
)
truncation_min_content_length: int = Field(
default=100,
ge=10,
le=1000,
description="Minimum content length in characters before truncation applies",
)
summary_model_group: str = Field(
default="fast",
@@ -128,6 +140,12 @@ class ContextSettings(BaseSettings):
default="ctx",
description="Redis key prefix for context cache",
)
cache_memory_max_items: int = Field(
default=1000,
ge=100,
le=100000,
description="Maximum items in memory fallback cache when Redis unavailable",
)
# Performance settings
max_assembly_time_ms: int = Field(
@@ -165,6 +183,28 @@ class ContextSettings(BaseSettings):
description="Minimum relevance score for knowledge",
)
# Relevance scoring settings
relevance_keyword_fallback_weight: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Maximum score for keyword-based fallback scoring (when semantic unavailable)",
)
relevance_semantic_max_chars: int = Field(
default=2000,
ge=100,
le=10000,
description="Maximum content length in chars for semantic similarity computation",
)
# Diversity/ranking settings
diversity_max_per_source: int = Field(
default=3,
ge=1,
le=20,
description="Maximum contexts from the same source in diversity reranking",
)
# Conversation history settings
conversation_max_turns: int = Field(
default=20,
@@ -253,11 +293,15 @@ class ContextSettings(BaseSettings):
"compression": {
"threshold": self.compression_threshold,
"summary_model_group": self.summary_model_group,
"truncation_marker": self.truncation_marker,
"truncation_preserve_ratio": self.truncation_preserve_ratio,
"truncation_min_content_length": self.truncation_min_content_length,
},
"cache": {
"enabled": self.cache_enabled,
"ttl_seconds": self.cache_ttl_seconds,
"prefix": self.cache_prefix,
"memory_max_items": self.cache_memory_max_items,
},
"performance": {
"max_assembly_time_ms": self.max_assembly_time_ms,
@@ -269,6 +313,13 @@ class ContextSettings(BaseSettings):
"max_results": self.knowledge_max_results,
"min_score": self.knowledge_min_score,
},
"relevance": {
"keyword_fallback_weight": self.relevance_keyword_fallback_weight,
"semantic_max_chars": self.relevance_semantic_max_chars,
},
"diversity": {
"max_per_source": self.diversity_max_per_source,
},
"conversation": {
"max_turns": self.conversation_max_turns,
"recent_priority": self.conversation_recent_priority,

View File

@@ -0,0 +1,482 @@
"""
Context Management Engine.
Main orchestration layer for context assembly and optimization.
Provides a high-level API for assembling optimized context for LLM requests.
"""
import logging
from typing import TYPE_CHECKING, Any
from .assembly import ContextPipeline
from .budget import BudgetAllocator, TokenBudget, TokenCalculator
from .cache import ContextCache
from .compression import ContextCompressor
from .config import ContextSettings, get_context_settings
from .prioritization import ContextRanker
from .scoring import CompositeScorer
from .types import (
AssembledContext,
BaseContext,
ConversationContext,
KnowledgeContext,
MessageRole,
SystemContext,
TaskContext,
ToolContext,
)
if TYPE_CHECKING:
from redis.asyncio import Redis
from app.services.mcp.client_manager import MCPClientManager
logger = logging.getLogger(__name__)
class ContextEngine:
"""
Main context management engine.
Provides high-level API for context assembly and optimization.
Integrates all components: scoring, ranking, compression, formatting, and caching.
Usage:
engine = ContextEngine(mcp_manager=mcp, redis=redis)
# Assemble context for an LLM request
result = await engine.assemble_context(
project_id="proj-123",
agent_id="agent-456",
query="implement user authentication",
model="claude-3-sonnet",
system_prompt="You are an expert developer.",
knowledge_query="authentication best practices",
)
# Use the assembled context
print(result.content)
print(f"Tokens: {result.total_tokens}")
"""
def __init__(
self,
mcp_manager: "MCPClientManager | None" = None,
redis: "Redis | None" = None,
settings: ContextSettings | None = None,
) -> None:
"""
Initialize the context engine.
Args:
mcp_manager: MCP client manager for LLM Gateway/Knowledge Base
redis: Redis connection for caching
settings: Context settings
"""
self._mcp = mcp_manager
self._settings = settings or get_context_settings()
# Initialize components
self._calculator = TokenCalculator(mcp_manager=mcp_manager)
self._scorer = CompositeScorer(mcp_manager=mcp_manager, settings=self._settings)
self._ranker = ContextRanker(scorer=self._scorer, calculator=self._calculator)
self._compressor = ContextCompressor(calculator=self._calculator)
self._allocator = BudgetAllocator(self._settings)
self._cache = ContextCache(redis=redis, settings=self._settings)
# Pipeline for assembly
self._pipeline = ContextPipeline(
mcp_manager=mcp_manager,
settings=self._settings,
calculator=self._calculator,
scorer=self._scorer,
ranker=self._ranker,
compressor=self._compressor,
)
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
"""
Set MCP manager for all components.
Args:
mcp_manager: MCP client manager
"""
self._mcp = mcp_manager
self._calculator.set_mcp_manager(mcp_manager)
self._scorer.set_mcp_manager(mcp_manager)
self._pipeline.set_mcp_manager(mcp_manager)
def set_redis(self, redis: "Redis") -> None:
"""
Set Redis connection for caching.
Args:
redis: Redis connection
"""
self._cache.set_redis(redis)
async def assemble_context(
self,
project_id: str,
agent_id: str,
query: str,
model: str,
max_tokens: int | None = None,
system_prompt: str | None = None,
task_description: str | None = None,
knowledge_query: str | None = None,
knowledge_limit: int = 10,
conversation_history: list[dict[str, str]] | None = None,
tool_results: list[dict[str, Any]] | None = None,
custom_contexts: list[BaseContext] | None = None,
custom_budget: TokenBudget | None = None,
compress: bool = True,
format_output: bool = True,
use_cache: bool = True,
) -> AssembledContext:
"""
Assemble optimized context for an LLM request.
This is the main entry point for context management.
It gathers context from various sources, scores and ranks them,
compresses if needed, and formats for the target model.
Args:
project_id: Project identifier
agent_id: Agent identifier
query: User's query or current request
model: Target model name
max_tokens: Maximum context tokens (uses model default if None)
system_prompt: System prompt/instructions
task_description: Current task description
knowledge_query: Query for knowledge base search
knowledge_limit: Max number of knowledge results
conversation_history: List of {"role": str, "content": str}
tool_results: List of tool results to include
custom_contexts: Additional custom contexts
custom_budget: Custom token budget
compress: Whether to apply compression
format_output: Whether to format for the model
use_cache: Whether to use caching
Returns:
AssembledContext with optimized content
Raises:
AssemblyTimeoutError: If assembly exceeds timeout
BudgetExceededError: If context exceeds budget
"""
# Gather all contexts
contexts: list[BaseContext] = []
# 1. System context
if system_prompt:
contexts.append(
SystemContext(
content=system_prompt,
source="system_prompt",
)
)
# 2. Task context
if task_description:
contexts.append(
TaskContext(
content=task_description,
source=f"task:{project_id}:{agent_id}",
)
)
# 3. Knowledge context from Knowledge Base
if knowledge_query and self._mcp:
knowledge_contexts = await self._fetch_knowledge(
project_id=project_id,
agent_id=agent_id,
query=knowledge_query,
limit=knowledge_limit,
)
contexts.extend(knowledge_contexts)
# 4. Conversation history
if conversation_history:
contexts.extend(self._convert_conversation(conversation_history))
# 5. Tool results
if tool_results:
contexts.extend(self._convert_tool_results(tool_results))
# 6. Custom contexts
if custom_contexts:
contexts.extend(custom_contexts)
# Check cache if enabled
fingerprint: str | None = None
if use_cache and self._cache.is_enabled:
fingerprint = self._cache.compute_fingerprint(contexts, query, model)
cached = await self._cache.get_assembled(fingerprint)
if cached:
logger.debug(f"Cache hit for context assembly: {fingerprint}")
return cached
# Run assembly pipeline
result = await self._pipeline.assemble(
contexts=contexts,
query=query,
model=model,
max_tokens=max_tokens,
custom_budget=custom_budget,
compress=compress,
format_output=format_output,
)
# Cache result if enabled (reuse fingerprint computed above)
if use_cache and self._cache.is_enabled and fingerprint is not None:
await self._cache.set_assembled(fingerprint, result)
return result
async def _fetch_knowledge(
self,
project_id: str,
agent_id: str,
query: str,
limit: int = 10,
) -> list[KnowledgeContext]:
"""
Fetch relevant knowledge from Knowledge Base via MCP.
Args:
project_id: Project identifier
agent_id: Agent identifier
query: Search query
limit: Maximum results
Returns:
List of KnowledgeContext instances
"""
if not self._mcp:
return []
try:
result = await self._mcp.call_tool(
"knowledge-base",
"search_knowledge",
{
"project_id": project_id,
"agent_id": agent_id,
"query": query,
"search_type": "hybrid",
"limit": limit,
},
)
# Check both ToolResult.success AND response success
if not result.success:
logger.warning(f"Knowledge search failed: {result.error}")
return []
if not isinstance(result.data, dict) or not result.data.get(
"success", True
):
logger.warning("Knowledge search returned unsuccessful response")
return []
contexts = []
results = result.data.get("results", [])
for chunk in results:
contexts.append(
KnowledgeContext(
content=chunk.get("content", ""),
source=chunk.get("source_path", "unknown"),
relevance_score=chunk.get("score", 0.0),
metadata={
"chunk_id": chunk.get(
"id"
), # Server returns 'id' not 'chunk_id'
"document_id": chunk.get("document_id"),
},
)
)
logger.debug(f"Fetched {len(contexts)} knowledge chunks for query: {query}")
return contexts
except Exception as e:
logger.warning(f"Failed to fetch knowledge: {e}")
return []
def _convert_conversation(
self,
history: list[dict[str, str]],
) -> list[ConversationContext]:
"""
Convert conversation history to ConversationContext instances.
Args:
history: List of {"role": str, "content": str}
Returns:
List of ConversationContext instances
"""
contexts = []
for i, turn in enumerate(history):
role_str = turn.get("role", "user").lower()
role = (
MessageRole.ASSISTANT if role_str == "assistant" else MessageRole.USER
)
contexts.append(
ConversationContext(
content=turn.get("content", ""),
source=f"conversation:{i}",
role=role,
metadata={"role": role_str, "turn": i},
)
)
return contexts
def _convert_tool_results(
self,
results: list[dict[str, Any]],
) -> list[ToolContext]:
"""
Convert tool results to ToolContext instances.
Args:
results: List of tool result dictionaries
Returns:
List of ToolContext instances
"""
contexts = []
for result in results:
tool_name = result.get("tool_name", "unknown")
content = result.get("content", result.get("result", ""))
# Handle dict content
if isinstance(content, dict):
import json
content = json.dumps(content, indent=2)
contexts.append(
ToolContext(
content=str(content),
source=f"tool:{tool_name}",
metadata={
"tool_name": tool_name,
"status": result.get("status", "success"),
},
)
)
return contexts
async def get_budget_for_model(
self,
model: str,
max_tokens: int | None = None,
) -> TokenBudget:
"""
Get the token budget for a specific model.
Args:
model: Model name
max_tokens: Optional max tokens override
Returns:
TokenBudget instance
"""
if max_tokens:
return self._allocator.create_budget(max_tokens)
return self._allocator.create_budget_for_model(model)
async def count_tokens(
self,
content: str,
model: str | None = None,
) -> int:
"""
Count tokens in content.
Args:
content: Content to count
model: Model for model-specific tokenization
Returns:
Token count
"""
# Check cache first
cached = await self._cache.get_token_count(content, model)
if cached is not None:
return cached
count = await self._calculator.count_tokens(content, model)
# Cache the result
await self._cache.set_token_count(content, count, model)
return count
async def invalidate_cache(
self,
project_id: str | None = None,
pattern: str | None = None,
) -> int:
"""
Invalidate cache entries.
Args:
project_id: Invalidate all cache for a project
pattern: Custom pattern to match
Returns:
Number of entries invalidated
"""
if pattern:
return await self._cache.invalidate(pattern)
elif project_id:
return await self._cache.invalidate(f"*{project_id}*")
else:
return await self._cache.clear_all()
async def get_stats(self) -> dict[str, Any]:
"""
Get engine statistics.
Returns:
Dictionary with engine stats
"""
return {
"cache": await self._cache.get_stats(),
"settings": {
"compression_threshold": self._settings.compression_threshold,
"max_assembly_time_ms": self._settings.max_assembly_time_ms,
"cache_enabled": self._settings.cache_enabled,
},
}
# Convenience factory function
def create_context_engine(
mcp_manager: "MCPClientManager | None" = None,
redis: "Redis | None" = None,
settings: ContextSettings | None = None,
) -> ContextEngine:
"""
Create a context engine instance.
Args:
mcp_manager: MCP client manager
redis: Redis connection
settings: Context settings
Returns:
Configured ContextEngine instance
"""
return ContextEngine(
mcp_manager=mcp_manager,
redis=redis,
settings=settings,
)
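A hedged end-to-end sketch using the factory above; with no MCP manager or Redis, the knowledge fetch and caching steps are skipped and only the supplied contexts flow through the pipeline. The module path (`app.services.context.engine`) and the pipeline's estimation fallbacks are assumptions.

import asyncio

from app.services.context.engine import create_context_engine


async def demo_assemble() -> None:
    engine = create_context_engine()  # no MCP manager / Redis wired in
    result = await engine.assemble_context(
        project_id="proj-123",
        agent_id="agent-456",
        query="add a login endpoint",
        model="claude-3-sonnet",
        system_prompt="You are an expert developer.",
        conversation_history=[
            {"role": "user", "content": "Where is the auth code?"},
            {"role": "assistant", "content": "Under app/services/auth."},
        ],
        tool_results=[{"tool_name": "grep", "content": "auth.py: def login(...)"}],
    )
    print(result.total_tokens)


asyncio.run(demo_assemble())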

View File

@@ -61,7 +61,7 @@ class BudgetExceededError(ContextError):
requested: Tokens requested
context_type: Type of context that exceeded budget
"""
details: dict[str, Any] = {
"allocated": allocated,
"requested": requested,
"overage": requested - allocated,
@@ -170,7 +170,7 @@ class AssemblyTimeoutError(ContextError):
elapsed_ms: Actual elapsed time in milliseconds
stage: Pipeline stage where timeout occurred
"""
details: dict[str, Any] = {
"timeout_ms": timeout_ms,
"elapsed_ms": round(elapsed_ms, 2),
}

View File

@@ -3,3 +3,10 @@ Context Prioritization Module.
Provides context ranking and selection.
"""
from .ranker import ContextRanker, RankingResult
__all__ = [
"ContextRanker",
"RankingResult",
]

View File

@@ -0,0 +1,313 @@
"""
Context Ranker for Context Management.
Ranks and selects contexts based on scores and budget constraints.
"""
import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from ..budget import TokenBudget, TokenCalculator
from ..config import ContextSettings, get_context_settings
from ..scoring.composite import CompositeScorer, ScoredContext
from ..types import BaseContext, ContextPriority
if TYPE_CHECKING:
pass
logger = logging.getLogger(__name__)
@dataclass
class RankingResult:
"""Result of context ranking and selection."""
selected: list[ScoredContext]
excluded: list[ScoredContext]
total_tokens: int
selection_stats: dict[str, Any] = field(default_factory=dict)
@property
def selected_contexts(self) -> list[BaseContext]:
"""Get just the context objects (not scored wrappers)."""
return [s.context for s in self.selected]
class ContextRanker:
"""
Ranks and selects contexts within budget constraints.
Uses greedy selection to maximize total score
while respecting token budgets per context type.
"""
def __init__(
self,
scorer: CompositeScorer | None = None,
calculator: TokenCalculator | None = None,
settings: ContextSettings | None = None,
) -> None:
"""
Initialize context ranker.
Args:
scorer: Composite scorer for scoring contexts
calculator: Token calculator for counting tokens
settings: Context settings (uses global if None)
"""
self._settings = settings or get_context_settings()
self._scorer = scorer or CompositeScorer()
self._calculator = calculator or TokenCalculator()
def set_scorer(self, scorer: CompositeScorer) -> None:
"""Set the scorer."""
self._scorer = scorer
def set_calculator(self, calculator: TokenCalculator) -> None:
"""Set the token calculator."""
self._calculator = calculator
async def rank(
self,
contexts: list[BaseContext],
query: str,
budget: TokenBudget,
model: str | None = None,
ensure_required: bool = True,
**kwargs: Any,
) -> RankingResult:
"""
Rank and select contexts within budget.
Args:
contexts: Contexts to rank
query: Query to rank against
budget: Token budget constraints
model: Model for token counting
ensure_required: If True, always include CRITICAL priority contexts
**kwargs: Additional scoring parameters
Returns:
RankingResult with selected and excluded contexts
"""
if not contexts:
return RankingResult(
selected=[],
excluded=[],
total_tokens=0,
selection_stats={"total_contexts": 0},
)
# 1. Ensure all contexts have token counts
await self._ensure_token_counts(contexts, model)
# 2. Score all contexts
scored_contexts = await self._scorer.score_batch(contexts, query, **kwargs)
# 3. Separate required (CRITICAL priority) from optional
required: list[ScoredContext] = []
optional: list[ScoredContext] = []
if ensure_required:
for sc in scored_contexts:
# Contexts at or above CRITICAL priority are always included
if sc.context.priority >= ContextPriority.CRITICAL.value:
required.append(sc)
else:
optional.append(sc)
else:
optional = list(scored_contexts)
# 4. Sort optional by score (highest first)
optional.sort(reverse=True)
# 5. Greedy selection
selected: list[ScoredContext] = []
excluded: list[ScoredContext] = []
total_tokens = 0
# First, try to fit required contexts
for sc in required:
token_count = sc.context.token_count or 0
context_type = sc.context.get_type()
if budget.can_fit(context_type, token_count):
budget.allocate(context_type, token_count)
selected.append(sc)
total_tokens += token_count
else:
# Force-fit CRITICAL contexts if needed
budget.allocate(context_type, token_count, force=True)
selected.append(sc)
total_tokens += token_count
logger.warning(
f"Force-fitted CRITICAL context: {sc.context.source} "
f"({token_count} tokens)"
)
# Then, greedily add optional contexts
for sc in optional:
token_count = sc.context.token_count or 0
context_type = sc.context.get_type()
if budget.can_fit(context_type, token_count):
budget.allocate(context_type, token_count)
selected.append(sc)
total_tokens += token_count
else:
excluded.append(sc)
# Build stats
stats = {
"total_contexts": len(contexts),
"required_count": len(required),
"selected_count": len(selected),
"excluded_count": len(excluded),
"total_tokens": total_tokens,
"by_type": self._count_by_type(selected),
}
return RankingResult(
selected=selected,
excluded=excluded,
total_tokens=total_tokens,
selection_stats=stats,
)
async def rank_simple(
self,
contexts: list[BaseContext],
query: str,
max_tokens: int,
model: str | None = None,
**kwargs: Any,
) -> list[BaseContext]:
"""
Simple ranking without budget per type.
Selects top contexts by score until max tokens reached.
Args:
contexts: Contexts to rank
query: Query to rank against
max_tokens: Maximum total tokens
model: Model for token counting
**kwargs: Additional scoring parameters
Returns:
Selected contexts (in score order)
"""
if not contexts:
return []
# Ensure token counts
await self._ensure_token_counts(contexts, model)
# Score all contexts
scored_contexts = await self._scorer.score_batch(contexts, query, **kwargs)
# Sort by score (highest first)
scored_contexts.sort(reverse=True)
# Greedy selection
selected: list[BaseContext] = []
total_tokens = 0
for sc in scored_contexts:
token_count = sc.context.token_count or 0
if total_tokens + token_count <= max_tokens:
selected.append(sc.context)
total_tokens += token_count
return selected
async def _ensure_token_counts(
self,
contexts: list[BaseContext],
model: str | None = None,
) -> None:
"""
Ensure all contexts have token counts.
Counts tokens in parallel for contexts that don't have counts.
Args:
contexts: Contexts to check
model: Model for token counting
"""
import asyncio
# Find contexts needing counts
contexts_needing_counts = [ctx for ctx in contexts if ctx.token_count is None]
if not contexts_needing_counts:
return
# Count all in parallel
tasks = [
self._calculator.count_tokens(ctx.content, model)
for ctx in contexts_needing_counts
]
counts = await asyncio.gather(*tasks)
# Assign counts back
for ctx, count in zip(contexts_needing_counts, counts, strict=True):
ctx.token_count = count
def _count_by_type(
self, scored_contexts: list[ScoredContext]
) -> dict[str, dict[str, int]]:
"""Count selected contexts by type."""
by_type: dict[str, dict[str, int]] = {}
for sc in scored_contexts:
type_name = sc.context.get_type().value
if type_name not in by_type:
by_type[type_name] = {"count": 0, "tokens": 0}
by_type[type_name]["count"] += 1
by_type[type_name]["tokens"] += sc.context.token_count or 0
return by_type
async def rerank_for_diversity(
self,
scored_contexts: list[ScoredContext],
max_per_source: int | None = None,
) -> list[ScoredContext]:
"""
Rerank to ensure source diversity.
Prevents too many items from the same source.
Args:
scored_contexts: Already scored contexts
max_per_source: Maximum items per source (uses settings if None)
Returns:
Reranked contexts
"""
# Use provided value or fall back to settings
effective_max = (
max_per_source
if max_per_source is not None
else self._settings.diversity_max_per_source
)
source_counts: dict[str, int] = {}
result: list[ScoredContext] = []
deferred: list[ScoredContext] = []
for sc in scored_contexts:
source = sc.context.source
current_count = source_counts.get(source, 0)
if current_count < effective_max:
result.append(sc)
source_counts[source] = current_count + 1
else:
deferred.append(sc)
# Add deferred items at the end
result.extend(deferred)
return result
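A small sketch of the diversity reranking above: with max_per_source=1, the second chunk from the same source is deferred behind other sources. Import paths and the KnowledgeContext/ScoredContext constructor signatures are inferred from the surrounding files, so treat them as assumptions.

import asyncio

from app.services.context.prioritization import ContextRanker
from app.services.context.scoring import ScoredContext
from app.services.context.types import KnowledgeContext


async def demo_diversity() -> None:
    ranker = ContextRanker()
    scored = [
        ScoredContext(KnowledgeContext(content="a", source="doc1", relevance_score=0.9), composite_score=0.9),
        ScoredContext(KnowledgeContext(content="b", source="doc1", relevance_score=0.8), composite_score=0.8),
        ScoredContext(KnowledgeContext(content="c", source="doc2", relevance_score=0.7), composite_score=0.7),
    ]
    reranked = await ranker.rerank_for_diversity(scored, max_per_source=1)
    print([sc.context.source for sc in reranked])  # doc1, doc2, then the deferred doc1


asyncio.run(demo_diversity())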

View File

@@ -1,5 +1,21 @@
""" """
Context Scoring Module. Context Scoring Module.
Provides relevance, recency, and composite scoring. Provides scoring strategies for context prioritization.
""" """
from .base import BaseScorer, ScorerProtocol
from .composite import CompositeScorer, ScoredContext
from .priority import PriorityScorer
from .recency import RecencyScorer
from .relevance import RelevanceScorer
__all__ = [
"BaseScorer",
"CompositeScorer",
"PriorityScorer",
"RecencyScorer",
"RelevanceScorer",
"ScoredContext",
"ScorerProtocol",
]

View File

@@ -0,0 +1,99 @@
"""
Base Scorer Protocol and Types.
Defines the interface for context scoring implementations.
"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
from ..types import BaseContext
if TYPE_CHECKING:
pass
@runtime_checkable
class ScorerProtocol(Protocol):
"""Protocol for context scorers."""
async def score(
self,
context: BaseContext,
query: str,
**kwargs: Any,
) -> float:
"""
Score a context item.
Args:
context: Context to score
query: Query to score against
**kwargs: Additional scoring parameters
Returns:
Score between 0.0 and 1.0
"""
...
class BaseScorer(ABC):
"""
Abstract base class for context scorers.
Provides common functionality and interface for
different scoring strategies.
"""
def __init__(self, weight: float = 1.0) -> None:
"""
Initialize scorer.
Args:
weight: Weight for this scorer in composite scoring
"""
self._weight = weight
@property
def weight(self) -> float:
"""Get scorer weight."""
return self._weight
@weight.setter
def weight(self, value: float) -> None:
"""Set scorer weight."""
if not 0.0 <= value <= 1.0:
raise ValueError("Weight must be between 0.0 and 1.0")
self._weight = value
@abstractmethod
async def score(
self,
context: BaseContext,
query: str,
**kwargs: Any,
) -> float:
"""
Score a context item.
Args:
context: Context to score
query: Query to score against
**kwargs: Additional scoring parameters
Returns:
Score between 0.0 and 1.0
"""
...
def normalize_score(self, score: float) -> float:
"""
Normalize score to [0.0, 1.0] range.
Args:
score: Raw score
Returns:
Normalized score
"""
return max(0.0, min(1.0, score))
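A minimal custom scorer against the interface above; the length-based heuristic is purely illustrative and not part of the module, and import paths are assumptions.

from typing import Any

from app.services.context.scoring import BaseScorer
from app.services.context.types import BaseContext


class LengthScorer(BaseScorer):
    """Illustrative scorer: shorter contexts score higher."""

    async def score(self, context: BaseContext, query: str, **kwargs: Any) -> float:
        # 0 chars -> 1.0, 4000+ chars -> 0.0, clamped via the base-class helper
        return self.normalize_score(1.0 - len(context.content) / 4000)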

View File

@@ -0,0 +1,314 @@
"""
Composite Scorer for Context Management.
Combines multiple scoring strategies with configurable weights.
"""
import asyncio
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from weakref import WeakValueDictionary
from ..config import ContextSettings, get_context_settings
from ..types import BaseContext
from .priority import PriorityScorer
from .recency import RecencyScorer
from .relevance import RelevanceScorer
if TYPE_CHECKING:
from app.services.mcp.client_manager import MCPClientManager
logger = logging.getLogger(__name__)
@dataclass
class ScoredContext:
"""Context with computed scores."""
context: BaseContext
composite_score: float
relevance_score: float = 0.0
recency_score: float = 0.0
priority_score: float = 0.0
def __lt__(self, other: "ScoredContext") -> bool:
"""Enable sorting by composite score."""
return self.composite_score < other.composite_score
def __gt__(self, other: "ScoredContext") -> bool:
"""Enable sorting by composite score."""
return self.composite_score > other.composite_score
class CompositeScorer:
"""
Combines multiple scoring strategies.
Weights:
- relevance: How well content matches the query
- recency: How recent the content is
- priority: Explicit priority assignments
"""
def __init__(
self,
mcp_manager: "MCPClientManager | None" = None,
settings: ContextSettings | None = None,
relevance_weight: float | None = None,
recency_weight: float | None = None,
priority_weight: float | None = None,
) -> None:
"""
Initialize composite scorer.
Args:
mcp_manager: MCP manager for semantic scoring
settings: Context settings (uses default if None)
relevance_weight: Override relevance weight
recency_weight: Override recency weight
priority_weight: Override priority weight
"""
self._settings = settings or get_context_settings()
weights = self._settings.get_scoring_weights()
self._relevance_weight = (
relevance_weight if relevance_weight is not None else weights["relevance"]
)
self._recency_weight = (
recency_weight if recency_weight is not None else weights["recency"]
)
self._priority_weight = (
priority_weight if priority_weight is not None else weights["priority"]
)
# Initialize scorers
self._relevance_scorer = RelevanceScorer(
mcp_manager=mcp_manager,
weight=self._relevance_weight,
)
self._recency_scorer = RecencyScorer(weight=self._recency_weight)
self._priority_scorer = PriorityScorer(weight=self._priority_weight)
# Per-context locks to prevent race conditions during parallel scoring
# Uses WeakValueDictionary so locks are garbage collected when not in use
self._context_locks: WeakValueDictionary[str, asyncio.Lock] = (
WeakValueDictionary()
)
self._locks_lock = asyncio.Lock() # Lock to protect _context_locks access
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
"""Set MCP manager for semantic scoring."""
self._relevance_scorer.set_mcp_manager(mcp_manager)
@property
def weights(self) -> dict[str, float]:
"""Get current scoring weights."""
return {
"relevance": self._relevance_weight,
"recency": self._recency_weight,
"priority": self._priority_weight,
}
def update_weights(
self,
relevance: float | None = None,
recency: float | None = None,
priority: float | None = None,
) -> None:
"""
Update scoring weights.
Args:
relevance: New relevance weight
recency: New recency weight
priority: New priority weight
"""
if relevance is not None:
self._relevance_weight = max(0.0, min(1.0, relevance))
self._relevance_scorer.weight = self._relevance_weight
if recency is not None:
self._recency_weight = max(0.0, min(1.0, recency))
self._recency_scorer.weight = self._recency_weight
if priority is not None:
self._priority_weight = max(0.0, min(1.0, priority))
self._priority_scorer.weight = self._priority_weight
async def _get_context_lock(self, context_id: str) -> asyncio.Lock:
"""
Get or create a lock for a specific context.
Thread-safe access to per-context locks prevents race conditions
when the same context is scored concurrently.
Args:
context_id: The context ID to get a lock for
Returns:
asyncio.Lock for the context
"""
# Fast path: check if lock exists without acquiring main lock
if context_id in self._context_locks:
lock = self._context_locks.get(context_id)
if lock is not None:
return lock
# Slow path: create lock while holding main lock
async with self._locks_lock:
# Double-check after acquiring lock
if context_id in self._context_locks:
lock = self._context_locks.get(context_id)
if lock is not None:
return lock
# Create new lock
new_lock = asyncio.Lock()
self._context_locks[context_id] = new_lock
return new_lock
async def score(
self,
context: BaseContext,
query: str,
**kwargs: Any,
) -> float:
"""
Compute composite score for a context.
Args:
context: Context to score
query: Query to score against
**kwargs: Additional scoring parameters
Returns:
Composite score between 0.0 and 1.0
"""
scored = await self.score_with_details(context, query, **kwargs)
return scored.composite_score
async def score_with_details(
self,
context: BaseContext,
query: str,
**kwargs: Any,
) -> ScoredContext:
"""
Compute composite score with individual scores.
Uses per-context locking to prevent race conditions when the same
context is scored concurrently in parallel scoring operations.
Args:
context: Context to score
query: Query to score against
**kwargs: Additional scoring parameters
Returns:
ScoredContext with all scores
"""
# Get lock for this specific context to prevent race conditions
# within concurrent scoring operations for the same query
context_lock = await self._get_context_lock(context.id)
async with context_lock:
# Compute individual scores in parallel
# Note: We do NOT cache scores on the context because scores are
# query-dependent. Caching without considering the query would
# return incorrect scores for different queries.
relevance_task = self._relevance_scorer.score(context, query, **kwargs)
recency_task = self._recency_scorer.score(context, query, **kwargs)
priority_task = self._priority_scorer.score(context, query, **kwargs)
relevance_score, recency_score, priority_score = await asyncio.gather(
relevance_task, recency_task, priority_task
)
# Compute weighted composite
total_weight = (
self._relevance_weight + self._recency_weight + self._priority_weight
)
if total_weight > 0:
composite = (
relevance_score * self._relevance_weight
+ recency_score * self._recency_weight
+ priority_score * self._priority_weight
) / total_weight
else:
composite = 0.0
return ScoredContext(
context=context,
composite_score=composite,
relevance_score=relevance_score,
recency_score=recency_score,
priority_score=priority_score,
)
async def score_batch(
self,
contexts: list[BaseContext],
query: str,
parallel: bool = True,
**kwargs: Any,
) -> list[ScoredContext]:
"""
Score multiple contexts.
Args:
contexts: Contexts to score
query: Query to score against
parallel: Whether to score in parallel
**kwargs: Additional scoring parameters
Returns:
List of ScoredContext (same order as input)
"""
if parallel:
tasks = [self.score_with_details(ctx, query, **kwargs) for ctx in contexts]
return await asyncio.gather(*tasks)
else:
results = []
for ctx in contexts:
scored = await self.score_with_details(ctx, query, **kwargs)
results.append(scored)
return results
async def rank(
self,
contexts: list[BaseContext],
query: str,
limit: int | None = None,
min_score: float = 0.0,
**kwargs: Any,
) -> list[ScoredContext]:
"""
Score and rank contexts.
Args:
contexts: Contexts to rank
query: Query to rank against
limit: Maximum number of results
min_score: Minimum score threshold
**kwargs: Additional scoring parameters
Returns:
Sorted list of ScoredContext (highest first)
"""
# Score all contexts
scored = await self.score_batch(contexts, query, **kwargs)
# Filter by minimum score
if min_score > 0:
scored = [s for s in scored if s.composite_score >= min_score]
# Sort by score (highest first)
scored.sort(reverse=True)
# Apply limit
if limit is not None:
scored = scored[:limit]
return scored
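To make the weighted combination concrete, a small worked example; the weights and per-scorer outputs below are illustrative values, not the configured defaults.

relevance, recency, priority = 0.8, 0.4, 0.5
w_rel, w_rec, w_pri = 0.5, 0.2, 0.3

composite = (relevance * w_rel + recency * w_rec + priority * w_pri) / (w_rel + w_rec + w_pri)
print(round(composite, 2))  # 0.63 -> (0.40 + 0.08 + 0.15) / 1.0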

View File

@@ -0,0 +1,135 @@
"""
Priority Scorer for Context Management.
Scores context based on assigned priority levels.
"""
from typing import Any, ClassVar
from ..types import BaseContext, ContextType
from .base import BaseScorer
class PriorityScorer(BaseScorer):
"""
Scores context based on priority levels.
Converts priority enum values to normalized scores.
Also applies type-based priority bonuses.
"""
# Default priority bonuses by context type
DEFAULT_TYPE_BONUSES: ClassVar[dict[ContextType, float]] = {
ContextType.SYSTEM: 0.2, # System prompts get a boost
ContextType.TASK: 0.15, # Current task is important
ContextType.TOOL: 0.1, # Recent tool results matter
ContextType.KNOWLEDGE: 0.0, # Knowledge scored by relevance
ContextType.CONVERSATION: 0.0, # Conversation scored by recency
}
def __init__(
self,
weight: float = 1.0,
type_bonuses: dict[ContextType, float] | None = None,
) -> None:
"""
Initialize priority scorer.
Args:
weight: Scorer weight for composite scoring
type_bonuses: Optional context-type priority bonuses
"""
super().__init__(weight)
self._type_bonuses = type_bonuses or self.DEFAULT_TYPE_BONUSES.copy()
async def score(
self,
context: BaseContext,
query: str,
**kwargs: Any,
) -> float:
"""
Score context based on priority.
Args:
context: Context to score
query: Query (not used for priority, kept for interface)
**kwargs: Additional parameters
Returns:
Priority score between 0.0 and 1.0
"""
# Get base priority score
priority_value = context.priority
base_score = self._priority_to_score(priority_value)
# Apply type bonus
context_type = context.get_type()
bonus = self._type_bonuses.get(context_type, 0.0)
return self.normalize_score(base_score + bonus)
def _priority_to_score(self, priority: int) -> float:
"""
Convert priority value to normalized score.
Priority values (from ContextPriority):
- CRITICAL (100) -> 1.0
- HIGH (80) -> 0.8
- NORMAL (50) -> 0.5
- LOW (20) -> 0.2
- MINIMAL (0) -> 0.0
Args:
priority: Priority value (0-100)
Returns:
Normalized score (0.0-1.0)
"""
# Clamp to valid range
clamped = max(0, min(100, priority))
return clamped / 100.0
def get_type_bonus(self, context_type: ContextType) -> float:
"""
Get priority bonus for a context type.
Args:
context_type: Context type
Returns:
Bonus value
"""
return self._type_bonuses.get(context_type, 0.0)
def set_type_bonus(self, context_type: ContextType, bonus: float) -> None:
"""
Set priority bonus for a context type.
Args:
context_type: Context type
bonus: Bonus value (0.0-1.0)
"""
if not 0.0 <= bonus <= 1.0:
raise ValueError("Bonus must be between 0.0 and 1.0")
self._type_bonuses[context_type] = bonus
async def score_batch(
self,
contexts: list[BaseContext],
query: str,
**kwargs: Any,
) -> list[float]:
"""
Score multiple contexts.
Args:
contexts: Contexts to score
query: Query (not used)
**kwargs: Additional parameters
Returns:
List of scores (same order as input)
"""
# Priority scoring is cheap, so contexts are scored sequentially
return [await self.score(ctx, query, **kwargs) for ctx in contexts]
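A worked illustration of the mapping above, assuming a SYSTEM-type context with priority 80 and the default type bonus of 0.2:

base_score = min(100, max(0, 80)) / 100.0   # 0.8 from the clamped priority value
bonus = 0.2                                 # DEFAULT_TYPE_BONUSES[ContextType.SYSTEM]
print(min(1.0, base_score + bonus))         # 1.0 after clamping to [0.0, 1.0]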

View File

@@ -0,0 +1,141 @@
"""
Recency Scorer for Context Management.
Scores context based on how recent it is.
More recent content gets higher scores.
"""
import math
from datetime import UTC, datetime
from typing import Any
from ..types import BaseContext, ContextType
from .base import BaseScorer
class RecencyScorer(BaseScorer):
"""
Scores context based on recency.
Uses exponential decay to score content based on age.
More recent content scores higher.
"""
def __init__(
self,
weight: float = 1.0,
half_life_hours: float = 24.0,
type_half_lives: dict[ContextType, float] | None = None,
) -> None:
"""
Initialize recency scorer.
Args:
weight: Scorer weight for composite scoring
half_life_hours: Default hours until score decays to 0.5
type_half_lives: Optional context-type-specific half lives
"""
super().__init__(weight)
self._half_life_hours = half_life_hours
self._type_half_lives = type_half_lives or {}
# Set sensible defaults for context types
if ContextType.CONVERSATION not in self._type_half_lives:
self._type_half_lives[ContextType.CONVERSATION] = 1.0 # 1 hour
if ContextType.TOOL not in self._type_half_lives:
self._type_half_lives[ContextType.TOOL] = 0.5 # 30 minutes
if ContextType.KNOWLEDGE not in self._type_half_lives:
self._type_half_lives[ContextType.KNOWLEDGE] = 168.0 # 1 week
if ContextType.SYSTEM not in self._type_half_lives:
self._type_half_lives[ContextType.SYSTEM] = 720.0 # 30 days
if ContextType.TASK not in self._type_half_lives:
self._type_half_lives[ContextType.TASK] = 24.0 # 1 day
async def score(
self,
context: BaseContext,
query: str,
**kwargs: Any,
) -> float:
"""
Score context based on recency.
Args:
context: Context to score
query: Query (not used for recency, kept for interface)
**kwargs: Additional parameters
- reference_time: Time to measure recency from (default: now)
Returns:
Recency score between 0.0 and 1.0
"""
reference_time = kwargs.get("reference_time")
if reference_time is None:
reference_time = datetime.now(UTC)
elif reference_time.tzinfo is None:
reference_time = reference_time.replace(tzinfo=UTC)
# Ensure context timestamp is timezone-aware
context_time = context.timestamp
if context_time.tzinfo is None:
context_time = context_time.replace(tzinfo=UTC)
# Calculate age in hours
age = reference_time - context_time
age_hours = max(0, age.total_seconds() / 3600)
# Get half-life for this context type
context_type = context.get_type()
half_life = self._type_half_lives.get(context_type, self._half_life_hours)
# Exponential decay
decay_factor = math.exp(-math.log(2) * age_hours / half_life)
return self.normalize_score(decay_factor)
def get_half_life(self, context_type: ContextType) -> float:
"""
Get half-life for a context type.
Args:
context_type: Context type to get half-life for
Returns:
Half-life in hours
"""
return self._type_half_lives.get(context_type, self._half_life_hours)
def set_half_life(self, context_type: ContextType, hours: float) -> None:
"""
Set half-life for a context type.
Args:
context_type: Context type to set half-life for
hours: Half-life in hours
"""
if hours <= 0:
raise ValueError("Half-life must be positive")
self._type_half_lives[context_type] = hours
async def score_batch(
self,
contexts: list[BaseContext],
query: str,
**kwargs: Any,
) -> list[float]:
"""
Score multiple contexts.
Args:
contexts: Contexts to score
query: Query (not used)
**kwargs: Additional parameters
Returns:
List of scores (same order as input)
"""
scores = []
for context in contexts:
score = await self.score(context, query, **kwargs)
scores.append(score)
return scores
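The decay above is standard half-life math; for example, a 2-hour-old conversation turn scored with the 1-hour conversation half-life set in __init__:

import math

age_hours, half_life = 2.0, 1.0
decay = math.exp(-math.log(2) * age_hours / half_life)
print(round(decay, 2))  # 0.25 -> two half-lives have elapsed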

View File

@@ -0,0 +1,220 @@
"""
Relevance Scorer for Context Management.
Scores context based on semantic similarity to the query.
Uses Knowledge Base embeddings when available.
"""
import logging
import re
from typing import TYPE_CHECKING, Any
from ..config import ContextSettings, get_context_settings
from ..types import BaseContext, KnowledgeContext
from .base import BaseScorer
if TYPE_CHECKING:
from app.services.mcp.client_manager import MCPClientManager
logger = logging.getLogger(__name__)
class RelevanceScorer(BaseScorer):
"""
Scores context based on relevance to query.
Uses multiple strategies:
1. Pre-computed scores (from RAG results)
2. MCP-based semantic similarity (via Knowledge Base)
3. Keyword matching fallback
"""
def __init__(
self,
mcp_manager: "MCPClientManager | None" = None,
weight: float = 1.0,
keyword_fallback_weight: float | None = None,
semantic_max_chars: int | None = None,
settings: ContextSettings | None = None,
) -> None:
"""
Initialize relevance scorer.
Args:
mcp_manager: MCP manager for Knowledge Base calls
weight: Scorer weight for composite scoring
keyword_fallback_weight: Max score for keyword-based fallback (overrides settings)
semantic_max_chars: Max content length for semantic similarity (overrides settings)
settings: Context settings (uses global if None)
"""
super().__init__(weight)
self._settings = settings or get_context_settings()
self._mcp = mcp_manager
# Use provided values or fall back to settings
self._keyword_fallback_weight = (
keyword_fallback_weight
if keyword_fallback_weight is not None
else self._settings.relevance_keyword_fallback_weight
)
self._semantic_max_chars = (
semantic_max_chars
if semantic_max_chars is not None
else self._settings.relevance_semantic_max_chars
)
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
"""Set MCP manager for semantic scoring."""
self._mcp = mcp_manager
async def score(
self,
context: BaseContext,
query: str,
**kwargs: Any,
) -> float:
"""
Score context relevance to query.
Args:
context: Context to score
query: Query to score against
**kwargs: Additional parameters
Returns:
Relevance score between 0.0 and 1.0
"""
# 1. Check for pre-computed relevance score
if (
isinstance(context, KnowledgeContext)
and context.relevance_score is not None
):
return self.normalize_score(context.relevance_score)
# 2. Check metadata for score
if "relevance_score" in context.metadata:
return self.normalize_score(context.metadata["relevance_score"])
if "score" in context.metadata:
return self.normalize_score(context.metadata["score"])
# 3. Try MCP-based semantic similarity (if compute_similarity tool is available)
# Note: This requires the knowledge-base MCP server to implement compute_similarity
if self._mcp is not None:
try:
score = await self._compute_semantic_similarity(context, query)
if score is not None:
return score
except Exception as e:
# Log at debug level since this is expected if compute_similarity
# tool is not implemented in the Knowledge Base server
logger.debug(
f"Semantic scoring unavailable, using keyword fallback: {e}"
)
# 4. Fall back to keyword matching
return self._compute_keyword_score(context, query)
async def _compute_semantic_similarity(
self,
context: BaseContext,
query: str,
) -> float | None:
"""
Compute semantic similarity using Knowledge Base embeddings.
Args:
context: Context to score
query: Query to compare
Returns:
Similarity score or None if unavailable
"""
if self._mcp is None:
return None
try:
# Use Knowledge Base's search capability to compute similarity
result = await self._mcp.call_tool(
server="knowledge-base",
tool="compute_similarity",
args={
"text1": query,
"text2": context.content[
: self._semantic_max_chars
], # Limit content length
},
)
if result.success and isinstance(result.data, dict):
similarity = result.data.get("similarity")
if similarity is not None:
return self.normalize_score(float(similarity))
except Exception as e:
logger.debug(f"Semantic similarity computation failed: {e}")
return None
def _compute_keyword_score(
self,
context: BaseContext,
query: str,
) -> float:
"""
Compute relevance score based on keyword matching.
Simple but fast fallback when semantic search is unavailable.
Args:
context: Context to score
query: Query to match
Returns:
Keyword-based relevance score
"""
if not query or not context.content:
return 0.0
# Extract keywords from query
query_lower = query.lower()
content_lower = context.content.lower()
# Simple word tokenization
query_words = set(re.findall(r"\b\w{3,}\b", query_lower))
content_words = set(re.findall(r"\b\w{3,}\b", content_lower))
if not query_words:
return 0.0
# Calculate overlap
common_words = query_words & content_words
overlap_ratio = len(common_words) / len(query_words)
# Apply fallback weight ceiling
return self.normalize_score(overlap_ratio * self._keyword_fallback_weight)
async def score_batch(
self,
contexts: list[BaseContext],
query: str,
**kwargs: Any,
) -> list[float]:
"""
Score multiple contexts in parallel.
Args:
contexts: Contexts to score
query: Query to score against
**kwargs: Additional parameters
Returns:
List of scores (same order as input)
"""
import asyncio
if not contexts:
return []
tasks = [self.score(context, query, **kwargs) for context in contexts]
return await asyncio.gather(*tasks)

View File
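When no pre-computed score is available and the knowledge-base server does not expose compute_similarity, RelevanceScorer falls back to plain keyword overlap: words of three or more characters are extracted from the query and the content, and the overlap ratio is scaled by the fallback-weight ceiling. A standalone sketch of that fallback (the 0.7 ceiling is an assumed example value; the real ceiling comes from relevance_keyword_fallback_weight in settings):

import re

def keyword_score(query: str, content: str, fallback_weight: float = 0.7) -> float:
    """Overlap of >=3-char words between query and content, capped by the fallback weight."""
    query_words = set(re.findall(r"\b\w{3,}\b", query.lower()))
    content_words = set(re.findall(r"\b\w{3,}\b", content.lower()))
    if not query_words or not content_words:
        return 0.0
    overlap = len(query_words & content_words) / len(query_words)
    return min(1.0, overlap * fallback_weight)

# "jwt" and "authentication" overlap out of three query words: 2/3 * 0.7 ~= 0.47
print(keyword_score("implement JWT authentication", "JWT tokens provide stateless authentication"))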

@@ -27,23 +27,17 @@ from .tool import (
 )
 __all__ = [
-    # Base types
     "AssembledContext",
     "BaseContext",
     "ContextPriority",
     "ContextType",
-    # Conversation
     "ConversationContext",
-    "MessageRole",
-    # Knowledge
     "KnowledgeContext",
-    # System
+    "MessageRole",
     "SystemContext",
-    # Task
     "TaskComplexity",
     "TaskContext",
     "TaskStatus",
-    # Tool
     "ToolContext",
     "ToolResultStatus",
 ]

View File

@@ -253,12 +253,19 @@ class AssembledContext:
     # Main content
     content: str
-    token_count: int
+    total_tokens: int
     # Assembly metadata
-    contexts_included: int
-    contexts_excluded: int = 0
+    context_count: int
+    excluded_count: int = 0
     assembly_time_ms: float = 0.0
+    model: str = ""
+    # Included contexts (optional - for inspection)
+    contexts: list["BaseContext"] = field(default_factory=list)
+    # Additional metadata from assembly
+    metadata: dict[str, Any] = field(default_factory=dict)
     # Budget tracking
     budget_total: int = 0
@@ -271,6 +278,22 @@ class AssembledContext:
     cache_hit: bool = False
     cache_key: str | None = None
+    # Aliases for backward compatibility
+    @property
+    def token_count(self) -> int:
+        """Alias for total_tokens."""
+        return self.total_tokens
+    @property
+    def contexts_included(self) -> int:
+        """Alias for context_count."""
+        return self.context_count
+    @property
+    def contexts_excluded(self) -> int:
+        """Alias for excluded_count."""
+        return self.excluded_count
     @property
     def budget_utilization(self) -> float:
         """Get budget utilization percentage."""
@@ -282,10 +305,12 @@ class AssembledContext:
         """Convert to dictionary."""
         return {
             "content": self.content,
-            "token_count": self.token_count,
-            "contexts_included": self.contexts_included,
-            "contexts_excluded": self.contexts_excluded,
+            "total_tokens": self.total_tokens,
+            "context_count": self.context_count,
+            "excluded_count": self.excluded_count,
             "assembly_time_ms": round(self.assembly_time_ms, 2),
+            "model": self.model,
+            "metadata": self.metadata,
             "budget_total": self.budget_total,
             "budget_used": self.budget_used,
             "budget_utilization": round(self.budget_utilization, 3),
@@ -308,10 +333,12 @@ class AssembledContext:
         data = json.loads(json_str)
         return cls(
             content=data["content"],
-            token_count=data["token_count"],
-            contexts_included=data["contexts_included"],
-            contexts_excluded=data.get("contexts_excluded", 0),
+            total_tokens=data["total_tokens"],
+            context_count=data["context_count"],
+            excluded_count=data.get("excluded_count", 0),
             assembly_time_ms=data.get("assembly_time_ms", 0.0),
+            model=data.get("model", ""),
+            metadata=data.get("metadata", {}),
             budget_total=data.get("budget_total", 0),
             budget_used=data.get("budget_used", 0),
             by_type=data.get("by_type", {}),

View File
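The renamed fields stay reachable under their old names through the read-only alias properties added above, so existing callers of token_count, contexts_included, and contexts_excluded keep working. A quick illustration, using the same constructor arguments that the cache tests later in this changeset use:

from app.services.context.types import AssembledContext

ctx = AssembledContext(content="Test content", total_tokens=100, context_count=2)
# The old attribute names resolve to the new fields.
assert ctx.token_count == ctx.total_tokens == 100
assert ctx.contexts_included == ctx.context_count == 2
assert ctx.contexts_excluded == ctx.excluded_count == 0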

@@ -120,7 +120,16 @@ class KnowledgeContext(BaseContext):
     def is_code(self) -> bool:
         """Check if this is code content."""
-        code_types = {"python", "javascript", "typescript", "go", "rust", "java", "c", "cpp"}
+        code_types = {
+            "python",
+            "javascript",
+            "typescript",
+            "go",
+            "rust",
+            "java",
+            "c",
+            "cpp",
+        }
         return self.file_type is not None and self.file_type.lower() in code_types
     def is_documentation(self) -> bool:

View File

@@ -55,11 +55,9 @@ class TaskContext(BaseContext):
     constraints: list[str] = field(default_factory=list)
     parent_task_id: str | None = field(default=None)
-    def __post_init__(self) -> None:
-        """Set high priority for task context."""
-        # Task context defaults to high priority
-        if self.priority == ContextPriority.NORMAL.value:
-            self.priority = ContextPriority.HIGH.value
+    # Note: TaskContext should typically have HIGH priority,
+    # but we don't auto-promote to allow explicit priority setting.
+    # Use TaskContext.create() for default HIGH priority behavior.
     def get_type(self) -> ContextType:
         """Return TASK context type."""

View File
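With __post_init__ removed, a TaskContext constructed directly keeps whatever priority it is given instead of being silently promoted; callers that still want the old behavior set the priority explicitly (or use TaskContext.create(), as the comment above suggests). A hedged sketch, assuming priority is a plain BaseContext field as it is for the other context types in this changeset:

from app.services.context.types import ContextPriority, TaskContext

task = TaskContext(
    content="Implement a login feature.",
    source="task",
    priority=ContextPriority.HIGH.value,  # explicit now; no auto-promotion
)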

@@ -56,7 +56,9 @@ class ToolContext(BaseContext):
             "tool_name": self.tool_name,
             "tool_description": self.tool_description,
             "is_result": self.is_result,
-            "result_status": self.result_status.value if self.result_status else None,
+            "result_status": self.result_status.value
+            if self.result_status
+            else None,
             "execution_time_ms": self.execution_time_ms,
             "parameters": self.parameters,
             "server_name": self.server_name,
@@ -174,7 +176,9 @@ class ToolContext(BaseContext):
         return cls(
             content=content,
-            source=f"tool_result:{server_name}:{tool_name}" if server_name else f"tool_result:{tool_name}",
+            source=f"tool_result:{server_name}:{tool_name}"
+            if server_name
+            else f"tool_result:{tool_name}",
             tool_name=tool_name,
             is_result=True,
             result_status=status,

View File

@@ -0,0 +1,518 @@
"""Tests for model adapters."""
from app.services.context.adapters import (
ClaudeAdapter,
DefaultAdapter,
OpenAIAdapter,
get_adapter,
)
from app.services.context.types import (
ContextType,
ConversationContext,
KnowledgeContext,
MessageRole,
SystemContext,
TaskContext,
ToolContext,
)
class TestGetAdapter:
"""Tests for get_adapter function."""
def test_claude_models(self) -> None:
"""Test that Claude models get ClaudeAdapter."""
assert isinstance(get_adapter("claude-3-sonnet"), ClaudeAdapter)
assert isinstance(get_adapter("claude-3-opus"), ClaudeAdapter)
assert isinstance(get_adapter("claude-3-haiku"), ClaudeAdapter)
assert isinstance(get_adapter("claude-2"), ClaudeAdapter)
assert isinstance(get_adapter("anthropic/claude-3-sonnet"), ClaudeAdapter)
def test_openai_models(self) -> None:
"""Test that OpenAI models get OpenAIAdapter."""
assert isinstance(get_adapter("gpt-4"), OpenAIAdapter)
assert isinstance(get_adapter("gpt-4-turbo"), OpenAIAdapter)
assert isinstance(get_adapter("gpt-3.5-turbo"), OpenAIAdapter)
assert isinstance(get_adapter("openai/gpt-4"), OpenAIAdapter)
assert isinstance(get_adapter("o1-mini"), OpenAIAdapter)
assert isinstance(get_adapter("o3-mini"), OpenAIAdapter)
def test_unknown_models(self) -> None:
"""Test that unknown models get DefaultAdapter."""
assert isinstance(get_adapter("llama-2"), DefaultAdapter)
assert isinstance(get_adapter("mistral-7b"), DefaultAdapter)
assert isinstance(get_adapter("custom-model"), DefaultAdapter)
class TestModelAdapterBase:
"""Tests for ModelAdapter base class."""
def test_get_type_order(self) -> None:
"""Test default type ordering."""
adapter = DefaultAdapter()
order = adapter.get_type_order()
assert order == [
ContextType.SYSTEM,
ContextType.TASK,
ContextType.KNOWLEDGE,
ContextType.CONVERSATION,
ContextType.TOOL,
]
def test_group_by_type(self) -> None:
"""Test grouping contexts by type."""
adapter = DefaultAdapter()
contexts = [
SystemContext(content="System", source="system"),
TaskContext(content="Task", source="task"),
KnowledgeContext(content="Knowledge", source="docs"),
SystemContext(content="System 2", source="system"),
]
grouped = adapter.group_by_type(contexts)
assert len(grouped[ContextType.SYSTEM]) == 2
assert len(grouped[ContextType.TASK]) == 1
assert len(grouped[ContextType.KNOWLEDGE]) == 1
assert ContextType.CONVERSATION not in grouped
def test_matches_model_default(self) -> None:
"""Test that DefaultAdapter matches all models."""
assert DefaultAdapter.matches_model("anything")
assert DefaultAdapter.matches_model("claude-3")
assert DefaultAdapter.matches_model("gpt-4")
class TestDefaultAdapter:
"""Tests for DefaultAdapter."""
def test_format_empty(self) -> None:
"""Test formatting empty context list."""
adapter = DefaultAdapter()
result = adapter.format([])
assert result == ""
def test_format_system(self) -> None:
"""Test formatting system context."""
adapter = DefaultAdapter()
contexts = [
SystemContext(content="You are helpful.", source="system"),
]
result = adapter.format(contexts)
assert "You are helpful." in result
def test_format_task(self) -> None:
"""Test formatting task context."""
adapter = DefaultAdapter()
contexts = [
TaskContext(content="Write a function.", source="task"),
]
result = adapter.format(contexts)
assert "Task:" in result
assert "Write a function." in result
def test_format_knowledge(self) -> None:
"""Test formatting knowledge context."""
adapter = DefaultAdapter()
contexts = [
KnowledgeContext(content="Documentation here.", source="docs"),
]
result = adapter.format(contexts)
assert "Reference Information:" in result
assert "Documentation here." in result
def test_format_conversation(self) -> None:
"""Test formatting conversation context."""
adapter = DefaultAdapter()
contexts = [
ConversationContext(
content="Hello!",
source="chat",
role=MessageRole.USER,
),
]
result = adapter.format(contexts)
assert "Previous Conversation:" in result
assert "Hello!" in result
def test_format_tool(self) -> None:
"""Test formatting tool context."""
adapter = DefaultAdapter()
contexts = [
ToolContext(
content="Result: success",
source="tool",
metadata={"tool_name": "search"},
),
]
result = adapter.format(contexts)
assert "Tool Results:" in result
assert "Result: success" in result
class TestClaudeAdapter:
"""Tests for ClaudeAdapter."""
def test_matches_model(self) -> None:
"""Test model matching."""
assert ClaudeAdapter.matches_model("claude-3-sonnet")
assert ClaudeAdapter.matches_model("claude-3-opus")
assert ClaudeAdapter.matches_model("anthropic/claude-3-haiku")
assert not ClaudeAdapter.matches_model("gpt-4")
assert not ClaudeAdapter.matches_model("llama-2")
def test_format_empty(self) -> None:
"""Test formatting empty context list."""
adapter = ClaudeAdapter()
result = adapter.format([])
assert result == ""
def test_format_system_uses_xml(self) -> None:
"""Test that system context uses XML tags."""
adapter = ClaudeAdapter()
contexts = [
SystemContext(content="You are helpful.", source="system"),
]
result = adapter.format(contexts)
assert "<system_instructions>" in result
assert "</system_instructions>" in result
assert "You are helpful." in result
def test_format_task_uses_xml(self) -> None:
"""Test that task context uses XML tags."""
adapter = ClaudeAdapter()
contexts = [
TaskContext(content="Write a function.", source="task"),
]
result = adapter.format(contexts)
assert "<current_task>" in result
assert "</current_task>" in result
assert "Write a function." in result
def test_format_knowledge_uses_document_tags(self) -> None:
"""Test that knowledge uses document XML tags."""
adapter = ClaudeAdapter()
contexts = [
KnowledgeContext(
content="Documentation here.",
source="docs/api.md",
relevance_score=0.9,
),
]
result = adapter.format(contexts)
assert "<reference_documents>" in result
assert "</reference_documents>" in result
assert '<document source="docs/api.md"' in result
assert "</document>" in result
assert "Documentation here." in result
def test_format_knowledge_with_score(self) -> None:
"""Test that knowledge includes relevance score."""
adapter = ClaudeAdapter()
contexts = [
KnowledgeContext(
content="Doc content.",
source="docs/api.md",
metadata={"relevance_score": 0.95},
),
]
result = adapter.format(contexts)
assert 'relevance="0.95"' in result
def test_format_conversation_uses_message_tags(self) -> None:
"""Test that conversation uses message XML tags."""
adapter = ClaudeAdapter()
contexts = [
ConversationContext(
content="Hello!",
source="chat",
role=MessageRole.USER,
metadata={"role": "user"},
),
ConversationContext(
content="Hi there!",
source="chat",
role=MessageRole.ASSISTANT,
metadata={"role": "assistant"},
),
]
result = adapter.format(contexts)
assert "<conversation_history>" in result
assert "</conversation_history>" in result
assert '<message role="user">' in result
assert '<message role="assistant">' in result
assert "Hello!" in result
assert "Hi there!" in result
def test_format_tool_uses_tool_result_tags(self) -> None:
"""Test that tool results use tool_result XML tags."""
adapter = ClaudeAdapter()
contexts = [
ToolContext(
content='{"status": "ok"}',
source="tool",
metadata={"tool_name": "search", "status": "success"},
),
]
result = adapter.format(contexts)
assert "<tool_results>" in result
assert "</tool_results>" in result
assert '<tool_result name="search"' in result
assert 'status="success"' in result
assert "</tool_result>" in result
def test_format_multiple_types_in_order(self) -> None:
"""Test that multiple types are formatted in correct order."""
adapter = ClaudeAdapter()
contexts = [
KnowledgeContext(content="Knowledge", source="docs"),
SystemContext(content="System", source="system"),
TaskContext(content="Task", source="task"),
]
result = adapter.format(contexts)
# Find positions
system_pos = result.find("<system_instructions>")
task_pos = result.find("<current_task>")
knowledge_pos = result.find("<reference_documents>")
# Verify order
assert system_pos < task_pos < knowledge_pos
def test_escape_xml_in_source(self) -> None:
"""Test that XML special chars are escaped in source."""
adapter = ClaudeAdapter()
contexts = [
KnowledgeContext(
content="Doc content.",
source='path/with"quotes&stuff.md',
),
]
result = adapter.format(contexts)
assert "&quot;" in result
assert "&amp;" in result
class TestOpenAIAdapter:
"""Tests for OpenAIAdapter."""
def test_matches_model(self) -> None:
"""Test model matching."""
assert OpenAIAdapter.matches_model("gpt-4")
assert OpenAIAdapter.matches_model("gpt-4-turbo")
assert OpenAIAdapter.matches_model("gpt-3.5-turbo")
assert OpenAIAdapter.matches_model("openai/gpt-4")
assert OpenAIAdapter.matches_model("o1-preview")
assert OpenAIAdapter.matches_model("o3-mini")
assert not OpenAIAdapter.matches_model("claude-3")
assert not OpenAIAdapter.matches_model("llama-2")
def test_format_empty(self) -> None:
"""Test formatting empty context list."""
adapter = OpenAIAdapter()
result = adapter.format([])
assert result == ""
def test_format_system_plain(self) -> None:
"""Test that system content is plain."""
adapter = OpenAIAdapter()
contexts = [
SystemContext(content="You are helpful.", source="system"),
]
result = adapter.format(contexts)
# System content should be plain without headers
assert "You are helpful." in result
assert "##" not in result # No markdown headers for system
def test_format_task_uses_markdown(self) -> None:
"""Test that task uses markdown headers."""
adapter = OpenAIAdapter()
contexts = [
TaskContext(content="Write a function.", source="task"),
]
result = adapter.format(contexts)
assert "## Current Task" in result
assert "Write a function." in result
def test_format_knowledge_uses_markdown(self) -> None:
"""Test that knowledge uses markdown with source headers."""
adapter = OpenAIAdapter()
contexts = [
KnowledgeContext(
content="Documentation here.",
source="docs/api.md",
relevance_score=0.9,
),
]
result = adapter.format(contexts)
assert "## Reference Documents" in result
assert "### Source: docs/api.md" in result
assert "Documentation here." in result
def test_format_knowledge_with_score(self) -> None:
"""Test that knowledge includes relevance score."""
adapter = OpenAIAdapter()
contexts = [
KnowledgeContext(
content="Doc content.",
source="docs/api.md",
metadata={"relevance_score": 0.95},
),
]
result = adapter.format(contexts)
assert "(relevance: 0.95)" in result
def test_format_conversation_uses_bold_roles(self) -> None:
"""Test that conversation uses bold role labels."""
adapter = OpenAIAdapter()
contexts = [
ConversationContext(
content="Hello!",
source="chat",
role=MessageRole.USER,
metadata={"role": "user"},
),
ConversationContext(
content="Hi there!",
source="chat",
role=MessageRole.ASSISTANT,
metadata={"role": "assistant"},
),
]
result = adapter.format(contexts)
assert "**USER**:" in result
assert "**ASSISTANT**:" in result
assert "Hello!" in result
assert "Hi there!" in result
def test_format_tool_uses_code_blocks(self) -> None:
"""Test that tool results use code blocks."""
adapter = OpenAIAdapter()
contexts = [
ToolContext(
content='{"status": "ok"}',
source="tool",
metadata={"tool_name": "search", "status": "success"},
),
]
result = adapter.format(contexts)
assert "## Recent Tool Results" in result
assert "### Tool: search (success)" in result
assert "```" in result # Code block
assert '{"status": "ok"}' in result
def test_format_multiple_types_in_order(self) -> None:
"""Test that multiple types are formatted in correct order."""
adapter = OpenAIAdapter()
contexts = [
KnowledgeContext(content="Knowledge", source="docs"),
SystemContext(content="System", source="system"),
TaskContext(content="Task", source="task"),
]
result = adapter.format(contexts)
# System comes first (no header), then task, then knowledge
system_pos = result.find("System")
task_pos = result.find("## Current Task")
knowledge_pos = result.find("## Reference Documents")
assert system_pos < task_pos < knowledge_pos
class TestAdapterIntegration:
"""Integration tests for adapters."""
def test_full_context_formatting_claude(self) -> None:
"""Test formatting a full set of contexts for Claude."""
adapter = ClaudeAdapter()
contexts = [
SystemContext(
content="You are an expert Python developer.",
source="system",
),
TaskContext(
content="Implement user authentication.",
source="task:AUTH-123",
),
KnowledgeContext(
content="JWT tokens provide stateless authentication...",
source="docs/auth/jwt.md",
relevance_score=0.9,
),
ConversationContext(
content="Can you help me implement JWT auth?",
source="chat",
role=MessageRole.USER,
metadata={"role": "user"},
),
ToolContext(
content='{"file": "auth.py", "status": "created"}',
source="tool",
metadata={"tool_name": "file_create"},
),
]
result = adapter.format(contexts)
# Verify all sections present
assert "<system_instructions>" in result
assert "<current_task>" in result
assert "<reference_documents>" in result
assert "<conversation_history>" in result
assert "<tool_results>" in result
# Verify content
assert "expert Python developer" in result
assert "user authentication" in result
assert "JWT tokens" in result
assert "help me implement" in result
assert "file_create" in result
def test_full_context_formatting_openai(self) -> None:
"""Test formatting a full set of contexts for OpenAI."""
adapter = OpenAIAdapter()
contexts = [
SystemContext(
content="You are an expert Python developer.",
source="system",
),
TaskContext(
content="Implement user authentication.",
source="task:AUTH-123",
),
KnowledgeContext(
content="JWT tokens provide stateless authentication...",
source="docs/auth/jwt.md",
relevance_score=0.9,
),
ConversationContext(
content="Can you help me implement JWT auth?",
source="chat",
role=MessageRole.USER,
metadata={"role": "user"},
),
ToolContext(
content='{"file": "auth.py", "status": "created"}',
source="tool",
metadata={"tool_name": "file_create"},
),
]
result = adapter.format(contexts)
# Verify all sections present
assert "## Current Task" in result
assert "## Reference Documents" in result
assert "## Recent Tool Results" in result
assert "**USER**:" in result
# Verify content
assert "expert Python developer" in result
assert "user authentication" in result
assert "JWT tokens" in result
assert "help me implement" in result
assert "file_create" in result

View File

@@ -0,0 +1,508 @@
"""Tests for context assembly pipeline."""
from datetime import UTC, datetime
import pytest
from app.services.context.assembly import ContextPipeline, PipelineMetrics
from app.services.context.budget import TokenBudget
from app.services.context.types import (
AssembledContext,
ConversationContext,
KnowledgeContext,
MessageRole,
SystemContext,
TaskContext,
ToolContext,
)
class TestPipelineMetrics:
"""Tests for PipelineMetrics dataclass."""
def test_creation(self) -> None:
"""Test metrics creation."""
metrics = PipelineMetrics()
assert metrics.total_contexts == 0
assert metrics.selected_contexts == 0
assert metrics.assembly_time_ms == 0.0
def test_to_dict(self) -> None:
"""Test conversion to dictionary."""
metrics = PipelineMetrics(
total_contexts=10,
selected_contexts=8,
excluded_contexts=2,
total_tokens=500,
assembly_time_ms=25.5,
)
metrics.end_time = datetime.now(UTC)
data = metrics.to_dict()
assert data["total_contexts"] == 10
assert data["selected_contexts"] == 8
assert data["excluded_contexts"] == 2
assert data["total_tokens"] == 500
assert data["assembly_time_ms"] == 25.5
assert "start_time" in data
assert "end_time" in data
class TestContextPipeline:
"""Tests for ContextPipeline."""
def test_creation(self) -> None:
"""Test pipeline creation."""
pipeline = ContextPipeline()
assert pipeline._calculator is not None
assert pipeline._scorer is not None
assert pipeline._ranker is not None
assert pipeline._compressor is not None
assert pipeline._allocator is not None
@pytest.mark.asyncio
async def test_assemble_empty_contexts(self) -> None:
"""Test assembling empty context list."""
pipeline = ContextPipeline()
result = await pipeline.assemble(
contexts=[],
query="test query",
model="claude-3-sonnet",
)
assert isinstance(result, AssembledContext)
assert result.context_count == 0
assert result.total_tokens == 0
@pytest.mark.asyncio
async def test_assemble_single_context(self) -> None:
"""Test assembling single context."""
pipeline = ContextPipeline()
contexts = [
SystemContext(
content="You are a helpful assistant.",
source="system",
)
]
result = await pipeline.assemble(
contexts=contexts,
query="help me",
model="claude-3-sonnet",
)
assert result.context_count == 1
assert result.total_tokens > 0
assert "helpful assistant" in result.content
@pytest.mark.asyncio
async def test_assemble_multiple_types(self) -> None:
"""Test assembling multiple context types."""
pipeline = ContextPipeline()
contexts = [
SystemContext(
content="You are a coding assistant.",
source="system",
),
TaskContext(
content="Implement a login feature.",
source="task",
),
KnowledgeContext(
content="Authentication best practices include...",
source="docs/auth.md",
relevance_score=0.8,
),
]
result = await pipeline.assemble(
contexts=contexts,
query="implement login",
model="claude-3-sonnet",
)
assert result.context_count >= 1
assert result.total_tokens > 0
@pytest.mark.asyncio
async def test_assemble_with_custom_budget(self) -> None:
"""Test assembling with custom budget."""
pipeline = ContextPipeline()
budget = TokenBudget(
total=1000,
system=200,
task=200,
knowledge=400,
conversation=100,
tools=50,
response_reserve=50,
)
contexts = [
SystemContext(content="System prompt", source="system"),
TaskContext(content="Task description", source="task"),
]
result = await pipeline.assemble(
contexts=contexts,
query="test",
model="gpt-4",
custom_budget=budget,
)
assert result.context_count >= 1
@pytest.mark.asyncio
async def test_assemble_with_max_tokens(self) -> None:
"""Test assembling with max_tokens limit."""
pipeline = ContextPipeline()
contexts = [
SystemContext(content="System prompt", source="system"),
]
result = await pipeline.assemble(
contexts=contexts,
query="test",
model="gpt-4",
max_tokens=5000,
)
assert "budget" in result.metadata
assert result.metadata["budget"]["total"] == 5000
@pytest.mark.asyncio
async def test_assemble_format_output(self) -> None:
"""Test formatted vs unformatted output."""
pipeline = ContextPipeline()
contexts = [
SystemContext(content="System prompt", source="system"),
]
# Formatted (default)
result_formatted = await pipeline.assemble(
contexts=contexts,
query="test",
model="claude-3-sonnet",
format_output=True,
)
# Unformatted
result_raw = await pipeline.assemble(
contexts=contexts,
query="test",
model="claude-3-sonnet",
format_output=False,
)
# Formatted should have XML tags for Claude
assert "<system_instructions>" in result_formatted.content
# Raw should not
assert "<system_instructions>" not in result_raw.content
@pytest.mark.asyncio
async def test_assemble_metrics(self) -> None:
"""Test that metrics are populated."""
pipeline = ContextPipeline()
contexts = [
SystemContext(content="System", source="system"),
TaskContext(content="Task", source="task"),
KnowledgeContext(
content="Knowledge",
source="docs",
relevance_score=0.9,
),
]
result = await pipeline.assemble(
contexts=contexts,
query="test",
model="claude-3-sonnet",
)
assert "metrics" in result.metadata
metrics = result.metadata["metrics"]
assert metrics["total_contexts"] == 3
assert metrics["assembly_time_ms"] > 0
assert "scoring_time_ms" in metrics
assert "formatting_time_ms" in metrics
@pytest.mark.asyncio
async def test_assemble_with_compression_disabled(self) -> None:
"""Test assembling with compression disabled."""
pipeline = ContextPipeline()
contexts = [
KnowledgeContext(content="A" * 1000, source="docs"),
]
result = await pipeline.assemble(
contexts=contexts,
query="test",
model="gpt-4",
compress=False,
)
# Should still work, just no compression applied
assert result.context_count >= 0
class TestContextPipelineFormatting:
"""Tests for context formatting."""
@pytest.mark.asyncio
async def test_format_claude_uses_xml(self) -> None:
"""Test that Claude models use XML formatting."""
pipeline = ContextPipeline()
contexts = [
SystemContext(content="System prompt", source="system"),
TaskContext(content="Task", source="task"),
KnowledgeContext(
content="Knowledge",
source="docs",
relevance_score=0.9,
),
]
result = await pipeline.assemble(
contexts=contexts,
query="test",
model="claude-3-sonnet",
)
# Claude should have XML tags
assert "<system_instructions>" in result.content or result.context_count == 0
@pytest.mark.asyncio
async def test_format_openai_uses_markdown(self) -> None:
"""Test that OpenAI models use markdown formatting."""
pipeline = ContextPipeline()
contexts = [
TaskContext(content="Task description", source="task"),
]
result = await pipeline.assemble(
contexts=contexts,
query="test",
model="gpt-4",
)
# OpenAI should have markdown headers
if result.context_count > 0 and "Task" in result.content:
assert "## Current Task" in result.content
@pytest.mark.asyncio
async def test_format_knowledge_claude(self) -> None:
"""Test knowledge formatting for Claude."""
pipeline = ContextPipeline()
contexts = [
KnowledgeContext(
content="Document content here",
source="docs/file.md",
relevance_score=0.9,
),
]
result = await pipeline.assemble(
contexts=contexts,
query="test",
model="claude-3-sonnet",
)
if result.context_count > 0:
assert "<reference_documents>" in result.content
assert "<document" in result.content
@pytest.mark.asyncio
async def test_format_conversation(self) -> None:
"""Test conversation formatting."""
pipeline = ContextPipeline()
contexts = [
ConversationContext(
content="Hello, how are you?",
source="chat",
role=MessageRole.USER,
metadata={"role": "user"},
),
ConversationContext(
content="I'm doing great!",
source="chat",
role=MessageRole.ASSISTANT,
metadata={"role": "assistant"},
),
]
result = await pipeline.assemble(
contexts=contexts,
query="test",
model="claude-3-sonnet",
)
if result.context_count > 0:
assert "<conversation_history>" in result.content
assert (
'<message role="user">' in result.content
or 'role="user"' in result.content
)
@pytest.mark.asyncio
async def test_format_tool_results(self) -> None:
"""Test tool result formatting."""
pipeline = ContextPipeline()
contexts = [
ToolContext(
content="Tool output here",
source="tool",
metadata={"tool_name": "search"},
),
]
result = await pipeline.assemble(
contexts=contexts,
query="test",
model="claude-3-sonnet",
)
if result.context_count > 0:
assert "<tool_results>" in result.content
class TestContextPipelineIntegration:
"""Integration tests for full pipeline."""
@pytest.mark.asyncio
async def test_full_pipeline_workflow(self) -> None:
"""Test complete pipeline workflow."""
pipeline = ContextPipeline()
# Create realistic context mix
contexts = [
SystemContext(
content="You are an expert Python developer.",
source="system",
),
TaskContext(
content="Implement a user authentication system.",
source="task:AUTH-123",
),
KnowledgeContext(
content="JWT tokens provide stateless authentication...",
source="docs/auth/jwt.md",
relevance_score=0.9,
),
KnowledgeContext(
content="OAuth 2.0 is an authorization framework...",
source="docs/auth/oauth.md",
relevance_score=0.7,
),
ConversationContext(
content="Can you help me implement JWT auth?",
source="chat",
role=MessageRole.USER,
metadata={"role": "user"},
),
]
result = await pipeline.assemble(
contexts=contexts,
query="implement JWT authentication",
model="claude-3-sonnet",
)
# Verify result
assert isinstance(result, AssembledContext)
assert result.context_count > 0
assert result.total_tokens > 0
assert result.assembly_time_ms > 0
assert result.model == "claude-3-sonnet"
assert len(result.content) > 0
# Verify metrics
assert "metrics" in result.metadata
assert "query" in result.metadata
assert "budget" in result.metadata
@pytest.mark.asyncio
async def test_context_type_ordering(self) -> None:
"""Test that contexts are ordered by type correctly."""
pipeline = ContextPipeline()
# Add in random order
contexts = [
KnowledgeContext(content="Knowledge", source="docs", relevance_score=0.9),
ToolContext(content="Tool", source="tool", metadata={"tool_name": "test"}),
SystemContext(content="System", source="system"),
ConversationContext(
content="Chat",
source="chat",
role=MessageRole.USER,
metadata={"role": "user"},
),
TaskContext(content="Task", source="task"),
]
result = await pipeline.assemble(
contexts=contexts,
query="test",
model="claude-3-sonnet",
)
# For Claude, verify order: System -> Task -> Knowledge -> Conversation -> Tool
content = result.content
if result.context_count > 0:
# Find positions (if they exist)
system_pos = content.find("system_instructions")
task_pos = content.find("current_task")
knowledge_pos = content.find("reference_documents")
conversation_pos = content.find("conversation_history")
tool_pos = content.find("tool_results")
# Verify ordering (only check if both exist)
if system_pos >= 0 and task_pos >= 0:
assert system_pos < task_pos
if task_pos >= 0 and knowledge_pos >= 0:
assert task_pos < knowledge_pos
if knowledge_pos >= 0 and conversation_pos >= 0:
assert knowledge_pos < conversation_pos
if conversation_pos >= 0 and tool_pos >= 0:
assert conversation_pos < tool_pos
@pytest.mark.asyncio
async def test_excluded_contexts_tracked(self) -> None:
"""Test that excluded contexts are tracked in result."""
pipeline = ContextPipeline()
# Create many contexts to force some exclusions
contexts = [
KnowledgeContext(
content="A" * 500, # Large content
source=f"docs/{i}",
relevance_score=0.1 + (i * 0.05),
)
for i in range(10)
]
result = await pipeline.assemble(
contexts=contexts,
query="test",
model="gpt-4", # Smaller context window
max_tokens=1000, # Limited budget
)
        # With a tight budget some contexts may be excluded; counts must stay consistent
        assert result.excluded_count >= 0
        assert result.context_count + result.excluded_count <= len(contexts)

View File

@@ -0,0 +1,533 @@
"""Tests for token budget management."""
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.services.context.budget import (
BudgetAllocator,
TokenBudget,
TokenCalculator,
)
from app.services.context.config import ContextSettings
from app.services.context.exceptions import BudgetExceededError
from app.services.context.types import ContextType
class TestTokenBudget:
"""Tests for TokenBudget dataclass."""
def test_creation(self) -> None:
"""Test basic budget creation."""
budget = TokenBudget(total=10000)
assert budget.total == 10000
assert budget.system == 0
assert budget.total_used() == 0
def test_creation_with_allocations(self) -> None:
"""Test budget creation with allocations."""
budget = TokenBudget(
total=10000,
system=500,
task=1000,
knowledge=4000,
conversation=2000,
tools=500,
response_reserve=1500,
buffer=500,
)
assert budget.system == 500
assert budget.knowledge == 4000
assert budget.response_reserve == 1500
def test_get_allocation(self) -> None:
"""Test getting allocation for a type."""
budget = TokenBudget(
total=10000,
system=500,
knowledge=4000,
)
assert budget.get_allocation(ContextType.SYSTEM) == 500
assert budget.get_allocation(ContextType.KNOWLEDGE) == 4000
assert budget.get_allocation("system") == 500
def test_remaining(self) -> None:
"""Test remaining budget calculation."""
budget = TokenBudget(
total=10000,
system=500,
knowledge=4000,
)
# Initially full
assert budget.remaining(ContextType.SYSTEM) == 500
assert budget.remaining(ContextType.KNOWLEDGE) == 4000
# After allocation
budget.allocate(ContextType.SYSTEM, 200)
assert budget.remaining(ContextType.SYSTEM) == 300
def test_can_fit(self) -> None:
"""Test can_fit check."""
budget = TokenBudget(
total=10000,
system=500,
knowledge=4000,
)
assert budget.can_fit(ContextType.SYSTEM, 500) is True
assert budget.can_fit(ContextType.SYSTEM, 501) is False
assert budget.can_fit(ContextType.KNOWLEDGE, 4000) is True
def test_allocate_success(self) -> None:
"""Test successful allocation."""
budget = TokenBudget(
total=10000,
system=500,
)
result = budget.allocate(ContextType.SYSTEM, 200)
assert result is True
assert budget.get_used(ContextType.SYSTEM) == 200
assert budget.remaining(ContextType.SYSTEM) == 300
def test_allocate_exceeds_budget(self) -> None:
"""Test allocation exceeding budget."""
budget = TokenBudget(
total=10000,
system=500,
)
with pytest.raises(BudgetExceededError) as exc_info:
budget.allocate(ContextType.SYSTEM, 600)
assert exc_info.value.allocated == 500
assert exc_info.value.requested == 600
def test_allocate_force(self) -> None:
"""Test forced allocation exceeding budget."""
budget = TokenBudget(
total=10000,
system=500,
)
# Force should allow exceeding
result = budget.allocate(ContextType.SYSTEM, 600, force=True)
assert result is True
assert budget.get_used(ContextType.SYSTEM) == 600
def test_deallocate(self) -> None:
"""Test deallocation."""
budget = TokenBudget(
total=10000,
system=500,
)
budget.allocate(ContextType.SYSTEM, 300)
assert budget.get_used(ContextType.SYSTEM) == 300
budget.deallocate(ContextType.SYSTEM, 100)
assert budget.get_used(ContextType.SYSTEM) == 200
def test_deallocate_below_zero(self) -> None:
"""Test deallocation doesn't go below zero."""
budget = TokenBudget(
total=10000,
system=500,
)
budget.allocate(ContextType.SYSTEM, 100)
budget.deallocate(ContextType.SYSTEM, 200)
assert budget.get_used(ContextType.SYSTEM) == 0
def test_total_remaining(self) -> None:
"""Test total remaining calculation."""
budget = TokenBudget(
total=10000,
system=500,
knowledge=4000,
response_reserve=1500,
buffer=500,
)
# Usable = total - response_reserve - buffer = 10000 - 1500 - 500 = 8000
assert budget.total_remaining() == 8000
# After allocation
budget.allocate(ContextType.SYSTEM, 200)
assert budget.total_remaining() == 7800
def test_utilization(self) -> None:
"""Test utilization calculation."""
budget = TokenBudget(
total=10000,
system=500,
response_reserve=1500,
buffer=500,
)
# No usage = 0%
assert budget.utilization(ContextType.SYSTEM) == 0.0
# Half used = 50%
budget.allocate(ContextType.SYSTEM, 250)
assert budget.utilization(ContextType.SYSTEM) == 0.5
# Total utilization
assert budget.utilization() == 250 / 8000 # 250 / (10000 - 1500 - 500)
def test_reset(self) -> None:
"""Test reset clears usage."""
budget = TokenBudget(
total=10000,
system=500,
)
budget.allocate(ContextType.SYSTEM, 300)
assert budget.get_used(ContextType.SYSTEM) == 300
budget.reset()
assert budget.get_used(ContextType.SYSTEM) == 0
assert budget.total_used() == 0
def test_to_dict(self) -> None:
"""Test to_dict conversion."""
budget = TokenBudget(
total=10000,
system=500,
task=1000,
knowledge=4000,
)
budget.allocate(ContextType.SYSTEM, 200)
data = budget.to_dict()
assert data["total"] == 10000
assert data["allocations"]["system"] == 500
assert data["used"]["system"] == 200
assert data["remaining"]["system"] == 300
class TestBudgetAllocator:
"""Tests for BudgetAllocator."""
def test_create_budget(self) -> None:
"""Test budget creation with default allocations."""
allocator = BudgetAllocator()
budget = allocator.create_budget(100000)
assert budget.total == 100000
assert budget.system == 5000 # 5%
assert budget.task == 10000 # 10%
assert budget.knowledge == 40000 # 40%
assert budget.conversation == 20000 # 20%
assert budget.tools == 5000 # 5%
assert budget.response_reserve == 15000 # 15%
assert budget.buffer == 5000 # 5%
def test_create_budget_custom_allocations(self) -> None:
"""Test budget creation with custom allocations."""
allocator = BudgetAllocator()
budget = allocator.create_budget(
100000,
custom_allocations={
"system": 0.10,
"task": 0.10,
"knowledge": 0.30,
"conversation": 0.25,
"tools": 0.05,
"response": 0.15,
"buffer": 0.05,
},
)
assert budget.system == 10000 # 10%
assert budget.knowledge == 30000 # 30%
def test_create_budget_for_model(self) -> None:
"""Test budget creation for specific model."""
allocator = BudgetAllocator()
# Claude models have 200k context
budget = allocator.create_budget_for_model("claude-3-sonnet")
assert budget.total == 200000
# GPT-4 has 8k context
budget = allocator.create_budget_for_model("gpt-4")
assert budget.total == 8192
# GPT-4-turbo has 128k context
budget = allocator.create_budget_for_model("gpt-4-turbo")
assert budget.total == 128000
def test_get_model_context_size(self) -> None:
"""Test model context size lookup."""
allocator = BudgetAllocator()
# Known models
assert allocator.get_model_context_size("claude-3-opus") == 200000
assert allocator.get_model_context_size("gpt-4") == 8192
assert allocator.get_model_context_size("gemini-1.5-pro") == 2000000
# Unknown model gets default
assert allocator.get_model_context_size("unknown-model") == 8192
def test_adjust_budget(self) -> None:
"""Test budget adjustment."""
allocator = BudgetAllocator()
budget = allocator.create_budget(10000)
original_system = budget.system
original_buffer = budget.buffer
# Increase system by taking from buffer
budget = allocator.adjust_budget(budget, ContextType.SYSTEM, 200)
assert budget.system == original_system + 200
assert budget.buffer == original_buffer - 200
    def test_adjust_budget_limited_by_buffer(self) -> None:
        """Test that adjustment is limited by buffer size."""
        allocator = BudgetAllocator()
        budget = allocator.create_budget(10000)
        original_system = budget.system
        original_buffer = budget.buffer
        # Try to increase more than buffer allows
        budget = allocator.adjust_budget(budget, ContextType.SYSTEM, 10000)
        # Should only increase by at most the buffer amount
        assert budget.buffer == 0
        assert budget.system - original_system <= original_buffer
def test_rebalance_budget(self) -> None:
"""Test budget rebalancing."""
allocator = BudgetAllocator()
budget = allocator.create_budget(10000)
# Use most of knowledge budget
budget.allocate(ContextType.KNOWLEDGE, 3500)
# Rebalance prioritizing knowledge
budget = allocator.rebalance_budget(
budget,
prioritize=[ContextType.KNOWLEDGE],
)
# Knowledge should have gotten more tokens
# (This is a fuzzy test - just check it runs)
assert budget is not None
class TestTokenCalculator:
"""Tests for TokenCalculator."""
def test_estimate_tokens(self) -> None:
"""Test token estimation."""
calc = TokenCalculator()
# Empty string
assert calc.estimate_tokens("") == 0
# Short text (~4 chars per token)
text = "This is a test message"
estimate = calc.estimate_tokens(text)
assert 4 <= estimate <= 8
def test_estimate_tokens_model_specific(self) -> None:
"""Test model-specific estimation ratios."""
calc = TokenCalculator()
text = "a" * 100
# Claude uses 3.5 chars per token
claude_estimate = calc.estimate_tokens(text, "claude-3-sonnet")
# GPT uses 4.0 chars per token
gpt_estimate = calc.estimate_tokens(text, "gpt-4")
# Claude should estimate more tokens (smaller ratio)
assert claude_estimate >= gpt_estimate
@pytest.mark.asyncio
async def test_count_tokens_no_mcp(self) -> None:
"""Test token counting without MCP (fallback to estimation)."""
calc = TokenCalculator()
text = "This is a test"
count = await calc.count_tokens(text)
# Should use estimation
assert count > 0
@pytest.mark.asyncio
async def test_count_tokens_with_mcp_success(self) -> None:
"""Test token counting with MCP integration."""
# Mock MCP manager
mock_mcp = MagicMock()
mock_result = MagicMock()
mock_result.success = True
mock_result.data = {"token_count": 42}
mock_mcp.call_tool = AsyncMock(return_value=mock_result)
calc = TokenCalculator(mcp_manager=mock_mcp)
count = await calc.count_tokens("test text")
assert count == 42
mock_mcp.call_tool.assert_called_once()
@pytest.mark.asyncio
async def test_count_tokens_with_mcp_failure(self) -> None:
"""Test fallback when MCP fails."""
# Mock MCP manager that fails
mock_mcp = MagicMock()
mock_mcp.call_tool = AsyncMock(side_effect=Exception("Connection failed"))
calc = TokenCalculator(mcp_manager=mock_mcp)
count = await calc.count_tokens("test text")
# Should fall back to estimation
assert count > 0
@pytest.mark.asyncio
async def test_count_tokens_caching(self) -> None:
"""Test that token counts are cached."""
mock_mcp = MagicMock()
mock_result = MagicMock()
mock_result.success = True
mock_result.data = {"token_count": 42}
mock_mcp.call_tool = AsyncMock(return_value=mock_result)
calc = TokenCalculator(mcp_manager=mock_mcp)
# First call
count1 = await calc.count_tokens("test text")
# Second call (should use cache)
count2 = await calc.count_tokens("test text")
assert count1 == count2 == 42
# MCP should only be called once
assert mock_mcp.call_tool.call_count == 1
@pytest.mark.asyncio
async def test_count_tokens_batch(self) -> None:
"""Test batch token counting."""
calc = TokenCalculator()
texts = ["Hello", "World", "Test message here"]
counts = await calc.count_tokens_batch(texts)
assert len(counts) == 3
assert all(c > 0 for c in counts)
def test_cache_stats(self) -> None:
"""Test cache statistics."""
calc = TokenCalculator()
stats = calc.get_cache_stats()
assert stats["enabled"] is True
assert stats["size"] == 0
assert stats["hits"] == 0
assert stats["misses"] == 0
@pytest.mark.asyncio
async def test_cache_hit_rate(self) -> None:
"""Test cache hit rate tracking."""
calc = TokenCalculator()
# Make some calls
await calc.count_tokens("text1")
await calc.count_tokens("text2")
await calc.count_tokens("text1") # Cache hit
stats = calc.get_cache_stats()
assert stats["hits"] == 1
assert stats["misses"] == 2
def test_clear_cache(self) -> None:
"""Test cache clearing."""
calc = TokenCalculator()
calc._cache["test"] = 100
calc._cache_hits = 5
calc.clear_cache()
assert len(calc._cache) == 0
assert calc._cache_hits == 0
def test_set_mcp_manager(self) -> None:
"""Test setting MCP manager after initialization."""
calc = TokenCalculator()
assert calc._mcp is None
mock_mcp = MagicMock()
calc.set_mcp_manager(mock_mcp)
assert calc._mcp is mock_mcp
@pytest.mark.asyncio
async def test_parse_token_count_formats(self) -> None:
"""Test parsing different token count response formats."""
calc = TokenCalculator()
# Dict with token_count
assert calc._parse_token_count({"token_count": 42}) == 42
# Dict with tokens
assert calc._parse_token_count({"tokens": 42}) == 42
# Dict with count
assert calc._parse_token_count({"count": 42}) == 42
# Direct int
assert calc._parse_token_count(42) == 42
# JSON string
assert calc._parse_token_count('{"token_count": 42}') == 42
# Invalid
assert calc._parse_token_count("invalid") is None
class TestBudgetIntegration:
"""Integration tests for budget management."""
@pytest.mark.asyncio
async def test_full_budget_workflow(self) -> None:
"""Test complete budget allocation workflow."""
# Create settings and allocator
settings = ContextSettings()
allocator = BudgetAllocator(settings)
# Create budget for Claude
budget = allocator.create_budget_for_model("claude-3-sonnet")
assert budget.total == 200000
# Create calculator (without MCP for test)
calc = TokenCalculator()
# Simulate context allocation
system_text = "You are a helpful assistant." * 10
system_tokens = await calc.count_tokens(system_text)
# Allocate
assert budget.can_fit(ContextType.SYSTEM, system_tokens)
budget.allocate(ContextType.SYSTEM, system_tokens)
# Check state
assert budget.get_used(ContextType.SYSTEM) == system_tokens
assert budget.remaining(ContextType.SYSTEM) == budget.system - system_tokens
@pytest.mark.asyncio
async def test_budget_overflow_handling(self) -> None:
"""Test handling budget overflow."""
allocator = BudgetAllocator()
budget = allocator.create_budget(1000) # Small budget
# Try to allocate more than available
with pytest.raises(BudgetExceededError):
budget.allocate(ContextType.KNOWLEDGE, 500)
# Force allocation should work
budget.allocate(ContextType.KNOWLEDGE, 500, force=True)
assert budget.get_used(ContextType.KNOWLEDGE) == 500

View File
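The estimation path in TokenCalculator above is a character-ratio heuristic: the tests pin Claude at roughly 3.5 characters per token and GPT at 4.0, which is why the same 100-character string yields a larger estimate for Claude. A tiny sketch of that arithmetic (the rounding shown here is an assumption; the actual implementation may truncate or clamp differently):

def estimate_tokens(text: str, chars_per_token: float) -> int:
    """Character-ratio token estimate, mirroring the ratios asserted in the tests."""
    return round(len(text) / chars_per_token) if text else 0

print(estimate_tokens("a" * 100, 3.5))  # ~29 tokens (Claude-style ratio)
print(estimate_tokens("a" * 100, 4.0))  # 25 tokens (GPT-style ratio)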

@@ -0,0 +1,479 @@
"""Tests for context cache module."""
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.services.context.cache import ContextCache
from app.services.context.config import ContextSettings
from app.services.context.exceptions import CacheError
from app.services.context.types import (
AssembledContext,
ContextPriority,
KnowledgeContext,
SystemContext,
TaskContext,
)
class TestContextCacheBasics:
"""Basic tests for ContextCache."""
def test_creation(self) -> None:
"""Test cache creation without Redis."""
cache = ContextCache()
assert cache._redis is None
assert not cache.is_enabled
def test_creation_with_settings(self) -> None:
"""Test cache creation with custom settings."""
settings = ContextSettings(
cache_prefix="test",
cache_ttl_seconds=60,
)
cache = ContextCache(settings=settings)
assert cache._prefix == "test"
assert cache._ttl == 60
def test_set_redis(self) -> None:
"""Test setting Redis connection."""
cache = ContextCache()
mock_redis = MagicMock()
cache.set_redis(mock_redis)
assert cache._redis is mock_redis
def test_is_enabled(self) -> None:
"""Test is_enabled property."""
settings = ContextSettings(cache_enabled=True)
cache = ContextCache(settings=settings)
assert not cache.is_enabled # No Redis
cache.set_redis(MagicMock())
assert cache.is_enabled
# Disabled in settings
settings2 = ContextSettings(cache_enabled=False)
cache2 = ContextCache(redis=MagicMock(), settings=settings2)
assert not cache2.is_enabled
def test_cache_key(self) -> None:
"""Test cache key generation."""
cache = ContextCache()
key = cache._cache_key("assembled", "abc123")
assert key == "ctx:assembled:abc123"
def test_hash_content(self) -> None:
"""Test content hashing."""
hash1 = ContextCache._hash_content("hello world")
hash2 = ContextCache._hash_content("hello world")
hash3 = ContextCache._hash_content("different")
assert hash1 == hash2
assert hash1 != hash3
assert len(hash1) == 32
class TestFingerprintComputation:
"""Tests for fingerprint computation."""
def test_compute_fingerprint(self) -> None:
"""Test fingerprint computation."""
cache = ContextCache()
contexts = [
SystemContext(content="System", source="system"),
TaskContext(content="Task", source="task"),
]
fp1 = cache.compute_fingerprint(contexts, "query", "claude-3")
fp2 = cache.compute_fingerprint(contexts, "query", "claude-3")
fp3 = cache.compute_fingerprint(contexts, "different", "claude-3")
assert fp1 == fp2 # Same inputs = same fingerprint
assert fp1 != fp3 # Different query = different fingerprint
assert len(fp1) == 32
def test_fingerprint_includes_priority(self) -> None:
"""Test that fingerprint changes with priority."""
cache = ContextCache()
# Use KnowledgeContext since SystemContext has __post_init__ that may override
ctx1 = [
KnowledgeContext(
content="Knowledge",
source="docs",
priority=ContextPriority.NORMAL.value,
)
]
ctx2 = [
KnowledgeContext(
content="Knowledge",
source="docs",
priority=ContextPriority.HIGH.value,
)
]
fp1 = cache.compute_fingerprint(ctx1, "query", "claude-3")
fp2 = cache.compute_fingerprint(ctx2, "query", "claude-3")
assert fp1 != fp2
def test_fingerprint_includes_model(self) -> None:
"""Test that fingerprint changes with model."""
cache = ContextCache()
contexts = [SystemContext(content="System", source="system")]
fp1 = cache.compute_fingerprint(contexts, "query", "claude-3")
fp2 = cache.compute_fingerprint(contexts, "query", "gpt-4")
assert fp1 != fp2
class TestMemoryCache:
"""Tests for in-memory caching."""
def test_memory_cache_fallback(self) -> None:
"""Test memory cache when Redis unavailable."""
cache = ContextCache()
# Should use memory cache
cache._set_memory("test-key", "42")
assert "test-key" in cache._memory_cache
assert cache._memory_cache["test-key"][0] == "42"
def test_memory_cache_eviction(self) -> None:
"""Test memory cache eviction."""
cache = ContextCache()
cache._max_memory_items = 10
# Fill cache
for i in range(15):
cache._set_memory(f"key-{i}", f"value-{i}")
# Should have evicted some items
assert len(cache._memory_cache) < 15
class TestAssembledContextCache:
"""Tests for assembled context caching."""
@pytest.mark.asyncio
async def test_get_assembled_no_redis(self) -> None:
"""Test get_assembled without Redis returns None."""
cache = ContextCache()
result = await cache.get_assembled("fingerprint")
assert result is None
@pytest.mark.asyncio
async def test_get_assembled_not_found(self) -> None:
"""Test get_assembled when key not found."""
mock_redis = AsyncMock()
mock_redis.get.return_value = None
settings = ContextSettings(cache_enabled=True)
cache = ContextCache(redis=mock_redis, settings=settings)
result = await cache.get_assembled("fingerprint")
assert result is None
@pytest.mark.asyncio
async def test_get_assembled_found(self) -> None:
"""Test get_assembled when key found."""
# Create a context
ctx = AssembledContext(
content="Test content",
total_tokens=100,
context_count=2,
)
mock_redis = AsyncMock()
mock_redis.get.return_value = ctx.to_json()
settings = ContextSettings(cache_enabled=True)
cache = ContextCache(redis=mock_redis, settings=settings)
result = await cache.get_assembled("fingerprint")
assert result is not None
assert result.content == "Test content"
assert result.total_tokens == 100
assert result.cache_hit is True
assert result.cache_key == "fingerprint"
@pytest.mark.asyncio
async def test_set_assembled(self) -> None:
"""Test set_assembled."""
mock_redis = AsyncMock()
settings = ContextSettings(cache_enabled=True, cache_ttl_seconds=60)
cache = ContextCache(redis=mock_redis, settings=settings)
ctx = AssembledContext(
content="Test content",
total_tokens=100,
context_count=2,
)
await cache.set_assembled("fingerprint", ctx)
mock_redis.setex.assert_called_once()
call_args = mock_redis.setex.call_args
assert call_args[0][0] == "ctx:assembled:fingerprint"
assert call_args[0][1] == 60 # TTL
@pytest.mark.asyncio
async def test_set_assembled_custom_ttl(self) -> None:
"""Test set_assembled with custom TTL."""
mock_redis = AsyncMock()
settings = ContextSettings(cache_enabled=True)
cache = ContextCache(redis=mock_redis, settings=settings)
ctx = AssembledContext(
content="Test",
total_tokens=10,
context_count=1,
)
await cache.set_assembled("fp", ctx, ttl=120)
call_args = mock_redis.setex.call_args
assert call_args[0][1] == 120
@pytest.mark.asyncio
async def test_cache_error_on_get(self) -> None:
"""Test CacheError raised on Redis error."""
mock_redis = AsyncMock()
mock_redis.get.side_effect = Exception("Redis error")
settings = ContextSettings(cache_enabled=True)
cache = ContextCache(redis=mock_redis, settings=settings)
with pytest.raises(CacheError):
await cache.get_assembled("fingerprint")
@pytest.mark.asyncio
async def test_cache_error_on_set(self) -> None:
"""Test CacheError raised on Redis error."""
mock_redis = AsyncMock()
mock_redis.setex.side_effect = Exception("Redis error")
settings = ContextSettings(cache_enabled=True)
cache = ContextCache(redis=mock_redis, settings=settings)
ctx = AssembledContext(
content="Test",
total_tokens=10,
context_count=1,
)
with pytest.raises(CacheError):
await cache.set_assembled("fp", ctx)
class TestTokenCountCache:
"""Tests for token count caching."""
@pytest.mark.asyncio
async def test_get_token_count_memory_fallback(self) -> None:
"""Test get_token_count uses memory cache."""
cache = ContextCache()
# Set in memory
key = cache._cache_key("tokens", "default", cache._hash_content("hello"))
cache._set_memory(key, "42")
result = await cache.get_token_count("hello")
assert result == 42
@pytest.mark.asyncio
async def test_set_token_count_memory(self) -> None:
"""Test set_token_count stores in memory."""
cache = ContextCache()
await cache.set_token_count("hello", 42)
result = await cache.get_token_count("hello")
assert result == 42
@pytest.mark.asyncio
async def test_set_token_count_with_model(self) -> None:
"""Test set_token_count with model-specific tokenization."""
mock_redis = AsyncMock()
settings = ContextSettings(cache_enabled=True)
cache = ContextCache(redis=mock_redis, settings=settings)
await cache.set_token_count("hello", 42, model="claude-3")
await cache.set_token_count("hello", 50, model="gpt-4")
# Different models should have different keys
assert mock_redis.setex.call_count == 2
calls = mock_redis.setex.call_args_list
key1 = calls[0][0][0]
key2 = calls[1][0][0]
assert "claude-3" in key1
assert "gpt-4" in key2
class TestScoreCache:
"""Tests for score caching."""
@pytest.mark.asyncio
async def test_get_score_memory_fallback(self) -> None:
"""Test get_score uses memory cache."""
cache = ContextCache()
# Set in memory
query_hash = cache._hash_content("query")[:16]
key = cache._cache_key("score", "relevance", "ctx-123", query_hash)
cache._set_memory(key, "0.85")
result = await cache.get_score("relevance", "ctx-123", "query")
assert result == 0.85
@pytest.mark.asyncio
async def test_set_score_memory(self) -> None:
"""Test set_score stores in memory."""
cache = ContextCache()
await cache.set_score("relevance", "ctx-123", "query", 0.85)
result = await cache.get_score("relevance", "ctx-123", "query")
assert result == 0.85
@pytest.mark.asyncio
async def test_set_score_with_redis(self) -> None:
"""Test set_score with Redis."""
mock_redis = AsyncMock()
settings = ContextSettings(cache_enabled=True)
cache = ContextCache(redis=mock_redis, settings=settings)
await cache.set_score("relevance", "ctx-123", "query", 0.85)
mock_redis.setex.assert_called_once()
class TestCacheInvalidation:
"""Tests for cache invalidation."""
@pytest.mark.asyncio
async def test_invalidate_pattern(self) -> None:
"""Test invalidate with pattern."""
mock_redis = AsyncMock()
# Set up scan_iter to return matching keys
async def mock_scan_iter(match=None):
for key in ["ctx:assembled:1", "ctx:assembled:2"]:
yield key
mock_redis.scan_iter = mock_scan_iter
settings = ContextSettings(cache_enabled=True)
cache = ContextCache(redis=mock_redis, settings=settings)
deleted = await cache.invalidate("assembled:*")
assert deleted == 2
assert mock_redis.delete.call_count == 2
@pytest.mark.asyncio
async def test_clear_all(self) -> None:
"""Test clear_all."""
mock_redis = AsyncMock()
async def mock_scan_iter(match=None):
for key in ["ctx:1", "ctx:2", "ctx:3"]:
yield key
mock_redis.scan_iter = mock_scan_iter
settings = ContextSettings(cache_enabled=True)
cache = ContextCache(redis=mock_redis, settings=settings)
# Add to memory cache
cache._set_memory("test", "value")
assert len(cache._memory_cache) > 0
deleted = await cache.clear_all()
assert deleted == 3
assert len(cache._memory_cache) == 0
class TestCacheStats:
"""Tests for cache statistics."""
@pytest.mark.asyncio
async def test_get_stats_no_redis(self) -> None:
"""Test get_stats without Redis."""
cache = ContextCache()
cache._set_memory("key", "value")
stats = await cache.get_stats()
assert stats["enabled"] is True
assert stats["redis_available"] is False
assert stats["memory_items"] == 1
@pytest.mark.asyncio
async def test_get_stats_with_redis(self) -> None:
"""Test get_stats with Redis."""
mock_redis = AsyncMock()
mock_redis.info.return_value = {"used_memory_human": "1.5M"}
settings = ContextSettings(cache_enabled=True, cache_ttl_seconds=300)
cache = ContextCache(redis=mock_redis, settings=settings)
stats = await cache.get_stats()
assert stats["enabled"] is True
assert stats["redis_available"] is True
assert stats["ttl_seconds"] == 300
assert stats["redis_memory_used"] == "1.5M"
class TestCacheIntegration:
"""Integration tests for cache."""
@pytest.mark.asyncio
async def test_full_workflow(self) -> None:
"""Test complete cache workflow."""
mock_redis = AsyncMock()
mock_redis.get.return_value = None
settings = ContextSettings(cache_enabled=True)
cache = ContextCache(redis=mock_redis, settings=settings)
contexts = [
SystemContext(content="System", source="system"),
KnowledgeContext(content="Knowledge", source="docs"),
]
# Compute fingerprint
fp = cache.compute_fingerprint(contexts, "query", "claude-3")
assert len(fp) == 32
# Check cache (miss)
result = await cache.get_assembled(fp)
assert result is None
# Create and cache assembled context
assembled = AssembledContext(
content="Assembled content",
total_tokens=100,
context_count=2,
model="claude-3",
)
await cache.set_assembled(fp, assembled)
# Verify setex was called
mock_redis.setex.assert_called_once()
# Mock cache hit
mock_redis.get.return_value = assembled.to_json()
result = await cache.get_assembled(fp)
assert result is not None
assert result.cache_hit is True
assert result.content == "Assembled content"
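
The integration test above walks the intended read-through pattern: fingerprint, look up, assemble on a miss, store the result. A minimal sketch of that flow, using only the ContextCache calls these tests exercise (`do_assemble` is a hypothetical stand-in for the real assembly pipeline):

    # Sketch only: read-through caching around assembly, based on the API
    # surface the tests above exercise. `do_assemble` is a placeholder callable.
    async def assemble_with_cache(cache, contexts, query, model, do_assemble):
        fp = cache.compute_fingerprint(contexts, query, model)  # 32-char fingerprint
        cached = await cache.get_assembled(fp)
        if cached is not None:
            return cached  # get_assembled sets cache_hit and cache_key
        assembled = await do_assemble(contexts, query, model)
        await cache.set_assembled(fp, assembled)  # stored under ctx:assembled:<fp>
        return assembled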


@@ -0,0 +1,294 @@
"""Tests for context compression module."""
import pytest
from app.services.context.budget import BudgetAllocator
from app.services.context.compression import (
ContextCompressor,
TruncationResult,
TruncationStrategy,
)
from app.services.context.types import (
ContextType,
KnowledgeContext,
TaskContext,
)
class TestTruncationResult:
"""Tests for TruncationResult dataclass."""
def test_creation(self) -> None:
"""Test basic creation."""
result = TruncationResult(
original_tokens=100,
truncated_tokens=50,
content="Truncated content",
truncated=True,
truncation_ratio=0.5,
)
assert result.original_tokens == 100
assert result.truncated_tokens == 50
assert result.truncated is True
assert result.truncation_ratio == 0.5
def test_tokens_saved(self) -> None:
"""Test tokens_saved property."""
result = TruncationResult(
original_tokens=100,
truncated_tokens=40,
content="Test",
truncated=True,
truncation_ratio=0.6,
)
assert result.tokens_saved == 60
def test_no_truncation(self) -> None:
"""Test when no truncation needed."""
result = TruncationResult(
original_tokens=50,
truncated_tokens=50,
content="Full content",
truncated=False,
truncation_ratio=0.0,
)
assert result.tokens_saved == 0
assert result.truncated is False
class TestTruncationStrategy:
"""Tests for TruncationStrategy."""
def test_creation(self) -> None:
"""Test strategy creation."""
strategy = TruncationStrategy()
assert strategy._preserve_ratio_start == 0.7
assert strategy._min_content_length == 100
def test_creation_with_params(self) -> None:
"""Test strategy creation with custom params."""
strategy = TruncationStrategy(
preserve_ratio_start=0.5,
min_content_length=50,
)
assert strategy._preserve_ratio_start == 0.5
assert strategy._min_content_length == 50
@pytest.mark.asyncio
async def test_truncate_empty_content(self) -> None:
"""Test truncating empty content."""
strategy = TruncationStrategy()
result = await strategy.truncate_to_tokens("", max_tokens=100)
assert result.original_tokens == 0
assert result.truncated_tokens == 0
assert result.content == ""
assert result.truncated is False
@pytest.mark.asyncio
async def test_truncate_content_within_limit(self) -> None:
"""Test content that fits within limit."""
strategy = TruncationStrategy()
content = "Short content"
result = await strategy.truncate_to_tokens(content, max_tokens=100)
assert result.content == content
assert result.truncated is False
@pytest.mark.asyncio
async def test_truncate_end_strategy(self) -> None:
"""Test end truncation strategy."""
strategy = TruncationStrategy()
content = "A" * 1000 # Long content
result = await strategy.truncate_to_tokens(
content, max_tokens=50, strategy="end"
)
assert result.truncated is True
assert len(result.content) < len(content)
assert strategy.truncation_marker in result.content
@pytest.mark.asyncio
async def test_truncate_middle_strategy(self) -> None:
"""Test middle truncation strategy."""
strategy = TruncationStrategy(preserve_ratio_start=0.6)
content = "START " + "A" * 500 + " END"
result = await strategy.truncate_to_tokens(
content, max_tokens=50, strategy="middle"
)
assert result.truncated is True
assert strategy.truncation_marker in result.content
@pytest.mark.asyncio
async def test_truncate_sentence_strategy(self) -> None:
"""Test sentence-aware truncation strategy."""
strategy = TruncationStrategy()
content = "First sentence. Second sentence. Third sentence. Fourth sentence."
result = await strategy.truncate_to_tokens(
content, max_tokens=10, strategy="sentence"
)
assert result.truncated is True
# Should cut at sentence boundary
assert (
result.content.endswith(".") or strategy.truncation_marker in result.content
)
class TestContextCompressor:
"""Tests for ContextCompressor."""
def test_creation(self) -> None:
"""Test compressor creation."""
compressor = ContextCompressor()
assert compressor._truncation is not None
@pytest.mark.asyncio
async def test_compress_context_within_limit(self) -> None:
"""Test compressing context that already fits."""
compressor = ContextCompressor()
context = KnowledgeContext(
content="Short content",
source="docs",
)
context.token_count = 5
result = await compressor.compress_context(context, max_tokens=100)
# Should return same context unmodified
assert result.content == "Short content"
assert result.metadata.get("truncated") is not True
@pytest.mark.asyncio
async def test_compress_context_exceeds_limit(self) -> None:
"""Test compressing context that exceeds limit."""
compressor = ContextCompressor()
context = KnowledgeContext(
content="A" * 500,
source="docs",
)
context.token_count = 125 # Approximately 500/4
result = await compressor.compress_context(context, max_tokens=20)
assert result.metadata.get("truncated") is True
assert result.metadata.get("original_tokens") == 125
assert len(result.content) < 500
@pytest.mark.asyncio
async def test_compress_contexts_batch(self) -> None:
"""Test compressing multiple contexts."""
compressor = ContextCompressor()
allocator = BudgetAllocator()
budget = allocator.create_budget(1000)
contexts = [
KnowledgeContext(content="A" * 200, source="docs"),
KnowledgeContext(content="B" * 200, source="docs"),
TaskContext(content="C" * 200, source="task"),
]
result = await compressor.compress_contexts(contexts, budget)
assert len(result) == 3
@pytest.mark.asyncio
async def test_strategy_selection_by_type(self) -> None:
"""Test that correct strategy is selected for each type."""
compressor = ContextCompressor()
assert compressor._get_strategy_for_type(ContextType.SYSTEM) == "end"
assert compressor._get_strategy_for_type(ContextType.TASK) == "end"
assert compressor._get_strategy_for_type(ContextType.KNOWLEDGE) == "sentence"
assert compressor._get_strategy_for_type(ContextType.CONVERSATION) == "end"
assert compressor._get_strategy_for_type(ContextType.TOOL) == "middle"
class TestTruncationEdgeCases:
"""Tests for edge cases in truncation to prevent regressions."""
@pytest.mark.asyncio
async def test_truncation_ratio_with_zero_original_tokens(self) -> None:
"""Test that truncation ratio handles zero original tokens without division by zero."""
strategy = TruncationStrategy()
# Empty content should not raise ZeroDivisionError
result = await strategy.truncate_to_tokens("", max_tokens=100)
assert result.truncation_ratio == 0.0
assert result.original_tokens == 0
@pytest.mark.asyncio
async def test_truncate_end_with_zero_available_tokens(self) -> None:
"""Test truncation when marker tokens exceed max_tokens."""
strategy = TruncationStrategy()
content = "Some content to truncate"
# max_tokens less than marker tokens should return just marker
result = await strategy.truncate_to_tokens(
content, max_tokens=1, strategy="end"
)
# Should handle gracefully without crashing
assert strategy.truncation_marker in result.content or result.content == content
@pytest.mark.asyncio
async def test_truncate_with_content_that_has_zero_tokens(self) -> None:
"""Test truncation when content estimates to zero tokens."""
strategy = TruncationStrategy()
# Very short content that might estimate to 0 tokens
result = await strategy.truncate_to_tokens("a", max_tokens=100)
# Should not raise ZeroDivisionError
assert result.content in ("a", "a" + strategy.truncation_marker)
@pytest.mark.asyncio
async def test_get_content_for_tokens_zero_target(self) -> None:
"""Test _get_content_for_tokens with zero target tokens."""
strategy = TruncationStrategy()
result = await strategy._get_content_for_tokens(
content="Some content",
target_tokens=0,
from_start=True,
)
assert result == ""
@pytest.mark.asyncio
async def test_sentence_truncation_with_no_sentences(self) -> None:
"""Test sentence truncation with content that has no sentence boundaries."""
strategy = TruncationStrategy()
content = "this is content without any sentence ending punctuation"
result = await strategy.truncate_to_tokens(
content, max_tokens=5, strategy="sentence"
)
# Should handle gracefully
assert result is not None
@pytest.mark.asyncio
async def test_middle_truncation_very_short_content(self) -> None:
"""Test middle truncation with content shorter than preserved portions."""
strategy = TruncationStrategy(preserve_ratio_start=0.7)
content = "ab" # Very short
result = await strategy.truncate_to_tokens(
content, max_tokens=1, strategy="middle"
)
# Should handle gracefully without negative indices
assert result is not None
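
The edge-case tests above pin down how truncation should behave around tiny budgets, missing sentence boundaries, and very short content. As a rough illustration of the middle strategy they exercise, here is a sketch under the roughly four-characters-per-token estimate the test comments use; the marker text and the way the real TruncationStrategy counts tokens are assumptions, not the library's actual values:

    # Illustrative middle truncation only; not the library implementation.
    # Tokens are approximated as len(text) // 4, and the marker is a placeholder.
    def middle_truncate(text: str, max_tokens: int,
                        marker: str = " ...[truncated]... ",
                        preserve_ratio_start: float = 0.7) -> str:
        if len(text) // 4 <= max_tokens:
            return text  # already within budget, nothing to cut
        budget_chars = max(max_tokens * 4 - len(marker), 0)
        head = int(budget_chars * preserve_ratio_start)  # keep ~70% from the start
        tail = budget_chars - head                        # and the rest from the end
        return text[:head] + marker + (text[-tail:] if tail > 0 else "")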


@@ -0,0 +1,456 @@
"""Tests for ContextEngine."""
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.services.context.config import ContextSettings
from app.services.context.engine import ContextEngine, create_context_engine
from app.services.context.types import (
AssembledContext,
ConversationContext,
KnowledgeContext,
MessageRole,
ToolContext,
)
class TestContextEngineCreation:
"""Tests for ContextEngine creation."""
def test_creation_minimal(self) -> None:
"""Test creating engine with minimal config."""
engine = ContextEngine()
assert engine._mcp is None
assert engine._settings is not None
assert engine._calculator is not None
assert engine._scorer is not None
assert engine._ranker is not None
assert engine._compressor is not None
assert engine._cache is not None
assert engine._pipeline is not None
def test_creation_with_settings(self) -> None:
"""Test creating engine with custom settings."""
settings = ContextSettings(
compression_threshold=0.7,
cache_enabled=False,
)
engine = ContextEngine(settings=settings)
assert engine._settings.compression_threshold == 0.7
assert engine._settings.cache_enabled is False
def test_creation_with_redis(self) -> None:
"""Test creating engine with Redis."""
mock_redis = MagicMock()
settings = ContextSettings(cache_enabled=True)
engine = ContextEngine(redis=mock_redis, settings=settings)
assert engine._cache.is_enabled
def test_set_mcp_manager(self) -> None:
"""Test setting MCP manager."""
engine = ContextEngine()
mock_mcp = MagicMock()
engine.set_mcp_manager(mock_mcp)
assert engine._mcp is mock_mcp
def test_set_redis(self) -> None:
"""Test setting Redis connection."""
engine = ContextEngine()
mock_redis = MagicMock()
engine.set_redis(mock_redis)
assert engine._cache._redis is mock_redis
class TestContextEngineHelpers:
"""Tests for ContextEngine helper methods."""
def test_convert_conversation(self) -> None:
"""Test converting conversation history."""
engine = ContextEngine()
history = [
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "How are you?"},
]
contexts = engine._convert_conversation(history)
assert len(contexts) == 3
assert all(isinstance(c, ConversationContext) for c in contexts)
assert contexts[0].role == MessageRole.USER
assert contexts[1].role == MessageRole.ASSISTANT
assert contexts[0].content == "Hello!"
assert contexts[0].metadata["turn"] == 0
def test_convert_tool_results(self) -> None:
"""Test converting tool results."""
engine = ContextEngine()
results = [
{"tool_name": "search", "content": "Result 1", "status": "success"},
{"tool_name": "read", "result": {"file": "test.txt"}, "status": "success"},
]
contexts = engine._convert_tool_results(results)
assert len(contexts) == 2
assert all(isinstance(c, ToolContext) for c in contexts)
assert contexts[0].content == "Result 1"
assert contexts[0].metadata["tool_name"] == "search"
# Dict content should be JSON serialized
assert "file" in contexts[1].content
assert "test.txt" in contexts[1].content
class TestContextEngineAssembly:
"""Tests for context assembly."""
@pytest.mark.asyncio
async def test_assemble_minimal(self) -> None:
"""Test assembling with minimal inputs."""
engine = ContextEngine()
result = await engine.assemble_context(
project_id="proj-123",
agent_id="agent-456",
query="test query",
model="claude-3-sonnet",
use_cache=False, # Disable cache for test
)
assert isinstance(result, AssembledContext)
assert result.context_count == 0 # No contexts provided
@pytest.mark.asyncio
async def test_assemble_with_system_prompt(self) -> None:
"""Test assembling with system prompt."""
engine = ContextEngine()
result = await engine.assemble_context(
project_id="proj-123",
agent_id="agent-456",
query="test query",
model="claude-3-sonnet",
system_prompt="You are a helpful assistant.",
use_cache=False,
)
assert result.context_count == 1
assert "helpful assistant" in result.content
@pytest.mark.asyncio
async def test_assemble_with_task(self) -> None:
"""Test assembling with task description."""
engine = ContextEngine()
result = await engine.assemble_context(
project_id="proj-123",
agent_id="agent-456",
query="implement feature",
model="claude-3-sonnet",
task_description="Implement user authentication",
use_cache=False,
)
assert result.context_count == 1
assert "authentication" in result.content
@pytest.mark.asyncio
async def test_assemble_with_conversation(self) -> None:
"""Test assembling with conversation history."""
engine = ContextEngine()
result = await engine.assemble_context(
project_id="proj-123",
agent_id="agent-456",
query="continue",
model="claude-3-sonnet",
conversation_history=[
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Hi!"},
],
use_cache=False,
)
assert result.context_count == 2
assert "Hello" in result.content
@pytest.mark.asyncio
async def test_assemble_with_tool_results(self) -> None:
"""Test assembling with tool results."""
engine = ContextEngine()
result = await engine.assemble_context(
project_id="proj-123",
agent_id="agent-456",
query="continue",
model="claude-3-sonnet",
tool_results=[
{"tool_name": "search", "content": "Found 5 results"},
],
use_cache=False,
)
assert result.context_count == 1
assert "Found 5 results" in result.content
@pytest.mark.asyncio
async def test_assemble_with_custom_contexts(self) -> None:
"""Test assembling with custom contexts."""
engine = ContextEngine()
custom = [
KnowledgeContext(
content="Custom knowledge.",
source="custom",
relevance_score=0.9,
)
]
result = await engine.assemble_context(
project_id="proj-123",
agent_id="agent-456",
query="test",
model="claude-3-sonnet",
custom_contexts=custom,
use_cache=False,
)
assert result.context_count == 1
assert "Custom knowledge" in result.content
@pytest.mark.asyncio
async def test_assemble_full_workflow(self) -> None:
"""Test full assembly workflow."""
engine = ContextEngine()
result = await engine.assemble_context(
project_id="proj-123",
agent_id="agent-456",
query="implement login",
model="claude-3-sonnet",
system_prompt="You are an expert Python developer.",
task_description="Implement user authentication.",
conversation_history=[
{"role": "user", "content": "Can you help me implement JWT auth?"},
],
tool_results=[
{"tool_name": "file_create", "content": "Created auth.py"},
],
use_cache=False,
)
assert result.context_count >= 4
assert result.total_tokens > 0
assert result.model == "claude-3-sonnet"
# Check for expected content
assert "expert Python developer" in result.content
assert "authentication" in result.content
class TestContextEngineKnowledge:
"""Tests for knowledge fetching."""
@pytest.mark.asyncio
async def test_fetch_knowledge_no_mcp(self) -> None:
"""Test fetching knowledge without MCP returns empty."""
engine = ContextEngine()
result = await engine._fetch_knowledge(
project_id="proj-123",
agent_id="agent-456",
query="test",
)
assert result == []
@pytest.mark.asyncio
async def test_fetch_knowledge_with_mcp(self) -> None:
"""Test fetching knowledge with MCP."""
mock_mcp = AsyncMock()
mock_mcp.call_tool.return_value.data = {
"results": [
{
"content": "Document content",
"source_path": "docs/api.md",
"score": 0.9,
"chunk_id": "chunk-1",
},
{
"content": "Another document",
"source_path": "docs/auth.md",
"score": 0.8,
},
]
}
engine = ContextEngine(mcp_manager=mock_mcp)
result = await engine._fetch_knowledge(
project_id="proj-123",
agent_id="agent-456",
query="authentication",
)
assert len(result) == 2
assert all(isinstance(c, KnowledgeContext) for c in result)
assert result[0].content == "Document content"
assert result[0].source == "docs/api.md"
assert result[0].relevance_score == 0.9
@pytest.mark.asyncio
async def test_fetch_knowledge_error_handling(self) -> None:
"""Test knowledge fetch error handling."""
mock_mcp = AsyncMock()
mock_mcp.call_tool.side_effect = Exception("MCP error")
engine = ContextEngine(mcp_manager=mock_mcp)
# Should not raise, returns empty
result = await engine._fetch_knowledge(
project_id="proj-123",
agent_id="agent-456",
query="test",
)
assert result == []
class TestContextEngineCaching:
"""Tests for caching behavior."""
@pytest.mark.asyncio
async def test_cache_disabled(self) -> None:
"""Test assembly with cache disabled."""
engine = ContextEngine()
result = await engine.assemble_context(
project_id="proj-123",
agent_id="agent-456",
query="test",
model="claude-3-sonnet",
system_prompt="Test prompt",
use_cache=False,
)
assert not result.cache_hit
@pytest.mark.asyncio
async def test_cache_hit(self) -> None:
"""Test cache hit."""
mock_redis = AsyncMock()
settings = ContextSettings(cache_enabled=True)
engine = ContextEngine(redis=mock_redis, settings=settings)
# First call - cache miss
mock_redis.get.return_value = None
result1 = await engine.assemble_context(
project_id="proj-123",
agent_id="agent-456",
query="test",
model="claude-3-sonnet",
system_prompt="Test prompt",
)
# Second call - mock cache hit
mock_redis.get.return_value = result1.to_json()
result2 = await engine.assemble_context(
project_id="proj-123",
agent_id="agent-456",
query="test",
model="claude-3-sonnet",
system_prompt="Test prompt",
)
assert result2.cache_hit
class TestContextEngineUtilities:
"""Tests for utility methods."""
@pytest.mark.asyncio
async def test_get_budget_for_model(self) -> None:
"""Test getting budget for model."""
engine = ContextEngine()
budget = await engine.get_budget_for_model("claude-3-sonnet")
assert budget.total > 0
assert budget.system > 0
assert budget.knowledge > 0
@pytest.mark.asyncio
async def test_get_budget_with_max_tokens(self) -> None:
"""Test getting budget with max tokens."""
engine = ContextEngine()
budget = await engine.get_budget_for_model("claude-3-sonnet", max_tokens=5000)
assert budget.total == 5000
@pytest.mark.asyncio
async def test_count_tokens(self) -> None:
"""Test token counting."""
engine = ContextEngine()
count = await engine.count_tokens("Hello world")
assert count > 0
@pytest.mark.asyncio
async def test_invalidate_cache(self) -> None:
"""Test cache invalidation."""
mock_redis = AsyncMock()
async def mock_scan_iter(match=None):
for key in ["ctx:1", "ctx:2"]:
yield key
mock_redis.scan_iter = mock_scan_iter
settings = ContextSettings(cache_enabled=True)
engine = ContextEngine(redis=mock_redis, settings=settings)
deleted = await engine.invalidate_cache(pattern="*test*")
assert deleted >= 0
@pytest.mark.asyncio
async def test_get_stats(self) -> None:
"""Test getting engine stats."""
engine = ContextEngine()
stats = await engine.get_stats()
assert "cache" in stats
assert "settings" in stats
assert "compression_threshold" in stats["settings"]
class TestCreateContextEngine:
"""Tests for factory function."""
def test_create_context_engine(self) -> None:
"""Test factory function."""
engine = create_context_engine()
assert isinstance(engine, ContextEngine)
def test_create_context_engine_with_settings(self) -> None:
"""Test factory with settings."""
settings = ContextSettings(cache_enabled=False)
engine = create_context_engine(settings=settings)
assert engine._settings.cache_enabled is False
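
Taken together, the assembly tests above trace the public surface of the engine: a factory, one high-level assemble_context() call, and a handful of utilities. A minimal usage sketch assuming only the constructor and keyword arguments those tests exercise (the IDs, model name, and prompt text are placeholders):

    # Sketch only: wiring create_context_engine() into a small script.
    import asyncio

    from app.services.context.config import ContextSettings
    from app.services.context.engine import create_context_engine

    async def main() -> None:
        engine = create_context_engine(settings=ContextSettings(cache_enabled=False))
        result = await engine.assemble_context(
            project_id="proj-123",
            agent_id="agent-456",
            query="implement login",
            model="claude-3-sonnet",
            system_prompt="You are an expert Python developer.",
            conversation_history=[{"role": "user", "content": "Help me add JWT auth."}],
            use_cache=False,
        )
        print(result.context_count, result.total_tokens, result.model)

    asyncio.run(main())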


@@ -1,7 +1,5 @@
"""Tests for context management exceptions.""" """Tests for context management exceptions."""
import pytest
from app.services.context.exceptions import ( from app.services.context.exceptions import (
AssemblyTimeoutError, AssemblyTimeoutError,
BudgetExceededError, BudgetExceededError,


@@ -0,0 +1,499 @@
"""Tests for context ranking module."""
import pytest
from app.services.context.budget import BudgetAllocator, TokenBudget
from app.services.context.prioritization import ContextRanker, RankingResult
from app.services.context.scoring import CompositeScorer, ScoredContext
from app.services.context.types import (
ContextPriority,
ContextType,
ConversationContext,
KnowledgeContext,
MessageRole,
SystemContext,
TaskContext,
)
class TestRankingResult:
"""Tests for RankingResult dataclass."""
def test_creation(self) -> None:
"""Test RankingResult creation."""
ctx = TaskContext(content="Test", source="task")
scored = ScoredContext(context=ctx, composite_score=0.8)
result = RankingResult(
selected=[scored],
excluded=[],
total_tokens=100,
selection_stats={"total": 1},
)
assert len(result.selected) == 1
assert result.total_tokens == 100
def test_selected_contexts_property(self) -> None:
"""Test selected_contexts property extracts contexts."""
ctx1 = TaskContext(content="Test 1", source="task")
ctx2 = TaskContext(content="Test 2", source="task")
scored1 = ScoredContext(context=ctx1, composite_score=0.8)
scored2 = ScoredContext(context=ctx2, composite_score=0.6)
result = RankingResult(
selected=[scored1, scored2],
excluded=[],
total_tokens=200,
)
selected = result.selected_contexts
assert len(selected) == 2
assert ctx1 in selected
assert ctx2 in selected
class TestContextRanker:
"""Tests for ContextRanker."""
def test_creation(self) -> None:
"""Test ranker creation."""
ranker = ContextRanker()
assert ranker._scorer is not None
assert ranker._calculator is not None
def test_creation_with_scorer(self) -> None:
"""Test ranker creation with custom scorer."""
scorer = CompositeScorer(relevance_weight=0.8)
ranker = ContextRanker(scorer=scorer)
assert ranker._scorer is scorer
@pytest.mark.asyncio
async def test_rank_empty_contexts(self) -> None:
"""Test ranking empty context list."""
ranker = ContextRanker()
allocator = BudgetAllocator()
budget = allocator.create_budget(10000)
result = await ranker.rank([], "query", budget)
assert len(result.selected) == 0
assert len(result.excluded) == 0
assert result.total_tokens == 0
@pytest.mark.asyncio
async def test_rank_single_context_fits(self) -> None:
"""Test ranking single context that fits budget."""
ranker = ContextRanker()
allocator = BudgetAllocator()
budget = allocator.create_budget(10000)
context = KnowledgeContext(
content="Short content",
source="docs",
relevance_score=0.8,
)
result = await ranker.rank([context], "query", budget)
assert len(result.selected) == 1
assert len(result.excluded) == 0
assert result.selected[0].context is context
@pytest.mark.asyncio
async def test_rank_respects_budget(self) -> None:
"""Test that ranking respects token budget."""
ranker = ContextRanker()
# Create a very small budget
budget = TokenBudget(
total=100,
knowledge=50, # Only 50 tokens for knowledge
)
# Create contexts that exceed budget
contexts = [
KnowledgeContext(
content="A" * 200, # ~50 tokens
source="docs",
relevance_score=0.9,
),
KnowledgeContext(
content="B" * 200, # ~50 tokens
source="docs",
relevance_score=0.8,
),
KnowledgeContext(
content="C" * 200, # ~50 tokens
source="docs",
relevance_score=0.7,
),
]
result = await ranker.rank(contexts, "query", budget)
# Not all should fit
assert len(result.selected) < len(contexts)
assert len(result.excluded) > 0
@pytest.mark.asyncio
async def test_rank_selects_highest_scores(self) -> None:
"""Test that ranking selects highest scored contexts."""
ranker = ContextRanker()
allocator = BudgetAllocator()
budget = allocator.create_budget(1000)
# Small budget for knowledge
budget.knowledge = 100
contexts = [
KnowledgeContext(
content="Low score",
source="docs",
relevance_score=0.2,
),
KnowledgeContext(
content="High score",
source="docs",
relevance_score=0.9,
),
KnowledgeContext(
content="Medium score",
source="docs",
relevance_score=0.5,
),
]
result = await ranker.rank(contexts, "query", budget)
# Should have selected some
if result.selected:
# The highest scored should be selected first
scores = [s.composite_score for s in result.selected]
assert scores == sorted(scores, reverse=True)
@pytest.mark.asyncio
async def test_rank_critical_priority_always_included(self) -> None:
"""Test that CRITICAL priority contexts are always included."""
ranker = ContextRanker()
# Very small budget
budget = TokenBudget(
total=100,
system=10, # Very small
knowledge=10,
)
contexts = [
SystemContext(
content="Critical system prompt that must be included",
source="system",
priority=ContextPriority.CRITICAL.value,
),
KnowledgeContext(
content="Optional knowledge",
source="docs",
relevance_score=0.9,
),
]
result = await ranker.rank(contexts, "query", budget, ensure_required=True)
# Critical context should be in selected
selected_priorities = [s.context.priority for s in result.selected]
assert ContextPriority.CRITICAL.value in selected_priorities
@pytest.mark.asyncio
async def test_rank_without_ensure_required(self) -> None:
"""Test ranking without ensuring required contexts."""
ranker = ContextRanker()
budget = TokenBudget(
total=100,
system=50,
knowledge=50,
)
contexts = [
SystemContext(
content="A" * 500, # Large content
source="system",
priority=ContextPriority.CRITICAL.value,
),
KnowledgeContext(
content="Short",
source="docs",
relevance_score=0.9,
),
]
result = await ranker.rank(contexts, "query", budget, ensure_required=False)
# Without ensure_required, CRITICAL contexts can be excluded
# if budget doesn't allow
assert len(result.selected) + len(result.excluded) == len(contexts)
@pytest.mark.asyncio
async def test_rank_selection_stats(self) -> None:
"""Test that ranking provides useful statistics."""
ranker = ContextRanker()
allocator = BudgetAllocator()
budget = allocator.create_budget(10000)
contexts = [
KnowledgeContext(content="Knowledge 1", source="docs", relevance_score=0.8),
KnowledgeContext(content="Knowledge 2", source="docs", relevance_score=0.6),
TaskContext(content="Task", source="task"),
]
result = await ranker.rank(contexts, "query", budget)
stats = result.selection_stats
assert "total_contexts" in stats
assert "selected_count" in stats
assert "excluded_count" in stats
assert "total_tokens" in stats
assert "by_type" in stats
@pytest.mark.asyncio
async def test_rank_simple(self) -> None:
"""Test simple ranking without budget per type."""
ranker = ContextRanker()
contexts = [
KnowledgeContext(
content="A",
source="docs",
relevance_score=0.9,
),
KnowledgeContext(
content="B",
source="docs",
relevance_score=0.7,
),
KnowledgeContext(
content="C",
source="docs",
relevance_score=0.5,
),
]
result = await ranker.rank_simple(contexts, "query", max_tokens=1000)
# Should return contexts sorted by score
assert len(result) > 0
@pytest.mark.asyncio
async def test_rank_simple_respects_max_tokens(self) -> None:
"""Test that simple ranking respects max tokens."""
ranker = ContextRanker()
# Create contexts with known token counts
contexts = [
KnowledgeContext(
content="A" * 100, # ~25 tokens
source="docs",
relevance_score=0.9,
),
KnowledgeContext(
content="B" * 100,
source="docs",
relevance_score=0.8,
),
KnowledgeContext(
content="C" * 100,
source="docs",
relevance_score=0.7,
),
]
# Very small limit
result = await ranker.rank_simple(contexts, "query", max_tokens=30)
# Should only fit a limited number
assert len(result) <= len(contexts)
@pytest.mark.asyncio
async def test_rank_simple_empty(self) -> None:
"""Test simple ranking with empty list."""
ranker = ContextRanker()
result = await ranker.rank_simple([], "query", max_tokens=1000)
assert result == []
@pytest.mark.asyncio
async def test_rerank_for_diversity(self) -> None:
"""Test diversity reranking."""
ranker = ContextRanker()
# Create scored contexts from same source
contexts = [
ScoredContext(
context=KnowledgeContext(
content=f"Content {i}",
source="same-source",
relevance_score=0.9 - i * 0.1,
),
composite_score=0.9 - i * 0.1,
)
for i in range(5)
]
# Limit to 2 per source
result = await ranker.rerank_for_diversity(contexts, max_per_source=2)
assert len(result) == 5
# First 2 keep their rank; further same-source items are deferred, not dropped
first_two_sources = [r.context.source for r in result[:2]]
assert all(s == "same-source" for s in first_two_sources)
@pytest.mark.asyncio
async def test_rerank_for_diversity_multiple_sources(self) -> None:
"""Test diversity reranking with multiple sources."""
ranker = ContextRanker()
contexts = [
ScoredContext(
context=KnowledgeContext(
content="Source A - 1",
source="source-a",
relevance_score=0.9,
),
composite_score=0.9,
),
ScoredContext(
context=KnowledgeContext(
content="Source A - 2",
source="source-a",
relevance_score=0.8,
),
composite_score=0.8,
),
ScoredContext(
context=KnowledgeContext(
content="Source B - 1",
source="source-b",
relevance_score=0.7,
),
composite_score=0.7,
),
ScoredContext(
context=KnowledgeContext(
content="Source A - 3",
source="source-a",
relevance_score=0.6,
),
composite_score=0.6,
),
]
result = await ranker.rerank_for_diversity(contexts, max_per_source=2)
# Should not have more than 2 from source-a in first 3
source_a_in_first_3 = sum(
1 for r in result[:3] if r.context.source == "source-a"
)
assert source_a_in_first_3 <= 2
@pytest.mark.asyncio
async def test_token_counts_set(self) -> None:
"""Test that token counts are set during ranking."""
ranker = ContextRanker()
allocator = BudgetAllocator()
budget = allocator.create_budget(10000)
context = KnowledgeContext(
content="Test content",
source="docs",
relevance_score=0.8,
)
# Token count should be None initially
assert context.token_count is None
await ranker.rank([context], "query", budget)
# Token count should be set after ranking
assert context.token_count is not None
assert context.token_count > 0
class TestContextRankerIntegration:
"""Integration tests for context ranking."""
@pytest.mark.asyncio
async def test_full_ranking_workflow(self) -> None:
"""Test complete ranking workflow."""
ranker = ContextRanker()
allocator = BudgetAllocator()
budget = allocator.create_budget(10000)
# Create diverse context types
contexts = [
SystemContext(
content="You are a helpful assistant.",
source="system",
priority=ContextPriority.CRITICAL.value,
),
TaskContext(
content="Help the user with their coding question.",
source="task",
priority=ContextPriority.HIGH.value,
),
KnowledgeContext(
content="Python is a programming language.",
source="docs/python.md",
relevance_score=0.9,
),
KnowledgeContext(
content="Java is also a programming language.",
source="docs/java.md",
relevance_score=0.4,
),
ConversationContext(
content="Hello, can you help me?",
source="chat",
role=MessageRole.USER,
),
]
result = await ranker.rank(contexts, "Python help", budget)
# System (CRITICAL) should be included
selected_types = [s.context.get_type() for s in result.selected]
assert ContextType.SYSTEM in selected_types
# Stats should be populated
assert result.selection_stats["total_contexts"] == 5
@pytest.mark.asyncio
async def test_ranking_preserves_context_order_by_score(self) -> None:
"""Test that ranking orders by score correctly."""
ranker = ContextRanker()
allocator = BudgetAllocator()
budget = allocator.create_budget(100000)
contexts = [
KnowledgeContext(
content="Low",
source="docs",
relevance_score=0.2,
),
KnowledgeContext(
content="High",
source="docs",
relevance_score=0.9,
),
KnowledgeContext(
content="Medium",
source="docs",
relevance_score=0.5,
),
]
result = await ranker.rank(contexts, "query", budget)
# Verify ordering is by score
scores = [s.composite_score for s in result.selected]
assert scores == sorted(scores, reverse=True)
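
The ranking tests above fix the selection behaviour: score every context, sort by composite score, admit contexts while they fit the budget, and (with ensure_required) always keep CRITICAL entries. A sketch of just the greedy core, leaving out the per-type budgets and required-context handling the real ContextRanker adds:

    # Illustrative greedy selection only; not the ContextRanker implementation.
    # `scored` is an iterable of (composite_score, token_count, context) tuples.
    def greedy_select(scored, max_tokens):
        selected, excluded, used = [], [], 0
        for score, tokens, ctx in sorted(scored, key=lambda t: t[0], reverse=True):
            if used + tokens <= max_tokens:
                selected.append(ctx)   # fits: admit and spend budget
                used += tokens
            else:
                excluded.append(ctx)   # over budget: record as excluded
        return selected, excluded, used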


@@ -0,0 +1,760 @@
"""Tests for context scoring module."""
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.services.context.scoring import (
CompositeScorer,
PriorityScorer,
RecencyScorer,
RelevanceScorer,
ScoredContext,
)
from app.services.context.types import (
ContextPriority,
ContextType,
ConversationContext,
KnowledgeContext,
MessageRole,
SystemContext,
TaskContext,
)
class TestRelevanceScorer:
"""Tests for RelevanceScorer."""
def test_creation(self) -> None:
"""Test scorer creation."""
scorer = RelevanceScorer()
assert scorer.weight == 1.0
def test_creation_with_weight(self) -> None:
"""Test scorer creation with custom weight."""
scorer = RelevanceScorer(weight=0.5)
assert scorer.weight == 0.5
@pytest.mark.asyncio
async def test_score_with_precomputed_relevance(self) -> None:
"""Test scoring with pre-computed relevance score."""
scorer = RelevanceScorer()
# KnowledgeContext with pre-computed score
context = KnowledgeContext(
content="Test content about Python",
source="docs/python.md",
relevance_score=0.85,
)
score = await scorer.score(context, "Python programming")
assert score == 0.85
@pytest.mark.asyncio
async def test_score_with_metadata_score(self) -> None:
"""Test scoring with metadata-provided score."""
scorer = RelevanceScorer()
context = SystemContext(
content="System prompt",
source="system",
metadata={"relevance_score": 0.9},
)
score = await scorer.score(context, "anything")
assert score == 0.9
@pytest.mark.asyncio
async def test_score_fallback_to_keyword_matching(self) -> None:
"""Test fallback to keyword matching when no score available."""
scorer = RelevanceScorer(keyword_fallback_weight=0.5)
context = TaskContext(
content="Implement authentication with JWT tokens",
source="task",
)
# Query has matching keywords
score = await scorer.score(context, "JWT authentication")
assert score > 0
@pytest.mark.asyncio
async def test_keyword_matching_no_overlap(self) -> None:
"""Test keyword matching with no query overlap."""
scorer = RelevanceScorer()
context = TaskContext(
content="Implement database migration",
source="task",
)
score = await scorer.score(context, "xyz abc 123")
assert score == 0.0
@pytest.mark.asyncio
async def test_keyword_matching_full_overlap(self) -> None:
"""Test keyword matching with high overlap."""
scorer = RelevanceScorer(keyword_fallback_weight=1.0)
context = TaskContext(
content="python programming language",
source="task",
)
score = await scorer.score(context, "python programming")
# Should have high score due to keyword overlap
assert score > 0.5
@pytest.mark.asyncio
async def test_score_with_mcp_success(self) -> None:
"""Test scoring with MCP semantic similarity."""
mock_mcp = MagicMock()
mock_result = MagicMock()
mock_result.success = True
mock_result.data = {"similarity": 0.75}
mock_mcp.call_tool = AsyncMock(return_value=mock_result)
scorer = RelevanceScorer(mcp_manager=mock_mcp)
context = TaskContext(
content="Test content",
source="task",
)
score = await scorer.score(context, "test query")
assert score == 0.75
@pytest.mark.asyncio
async def test_score_with_mcp_failure_fallback(self) -> None:
"""Test fallback when MCP fails."""
mock_mcp = MagicMock()
mock_mcp.call_tool = AsyncMock(side_effect=Exception("Connection failed"))
scorer = RelevanceScorer(mcp_manager=mock_mcp, keyword_fallback_weight=0.5)
context = TaskContext(
content="Python programming code",
source="task",
)
# Should fall back to keyword matching
score = await scorer.score(context, "Python code")
assert score > 0
@pytest.mark.asyncio
async def test_score_batch(self) -> None:
"""Test batch scoring."""
scorer = RelevanceScorer()
contexts = [
KnowledgeContext(content="Python", source="1", relevance_score=0.8),
KnowledgeContext(content="Java", source="2", relevance_score=0.6),
KnowledgeContext(content="Go", source="3", relevance_score=0.9),
]
scores = await scorer.score_batch(contexts, "test")
assert len(scores) == 3
assert scores[0] == 0.8
assert scores[1] == 0.6
assert scores[2] == 0.9
def test_set_mcp_manager(self) -> None:
"""Test setting MCP manager."""
scorer = RelevanceScorer()
assert scorer._mcp is None
mock_mcp = MagicMock()
scorer.set_mcp_manager(mock_mcp)
assert scorer._mcp is mock_mcp
class TestRecencyScorer:
"""Tests for RecencyScorer."""
def test_creation(self) -> None:
"""Test scorer creation."""
scorer = RecencyScorer()
assert scorer.weight == 1.0
assert scorer._half_life_hours == 24.0
def test_creation_with_custom_half_life(self) -> None:
"""Test scorer creation with custom half-life."""
scorer = RecencyScorer(half_life_hours=12.0)
assert scorer._half_life_hours == 12.0
@pytest.mark.asyncio
async def test_score_recent_context(self) -> None:
"""Test scoring a very recent context."""
scorer = RecencyScorer(half_life_hours=24.0)
now = datetime.now(UTC)
context = TaskContext(
content="Recent task",
source="task",
timestamp=now,
)
score = await scorer.score(context, "query", reference_time=now)
# Very recent should have score near 1.0
assert score > 0.99
@pytest.mark.asyncio
async def test_score_at_half_life(self) -> None:
"""Test scoring at exactly half-life age."""
scorer = RecencyScorer(half_life_hours=24.0)
now = datetime.now(UTC)
half_life_ago = now - timedelta(hours=24)
context = TaskContext(
content="Day old task",
source="task",
timestamp=half_life_ago,
)
score = await scorer.score(context, "query", reference_time=now)
# At half-life, score should be ~0.5
assert 0.49 <= score <= 0.51
@pytest.mark.asyncio
async def test_score_old_context(self) -> None:
"""Test scoring a very old context."""
scorer = RecencyScorer(half_life_hours=24.0)
now = datetime.now(UTC)
week_ago = now - timedelta(days=7)
context = TaskContext(
content="Week old task",
source="task",
timestamp=week_ago,
)
score = await scorer.score(context, "query", reference_time=now)
# 7 days with 24h half-life = very low score
assert score < 0.01
@pytest.mark.asyncio
async def test_type_specific_half_lives(self) -> None:
"""Test that different context types have different half-lives."""
scorer = RecencyScorer()
now = datetime.now(UTC)
one_hour_ago = now - timedelta(hours=1)
# Conversation has 1 hour half-life by default
conv_context = ConversationContext(
content="Hello",
source="chat",
role=MessageRole.USER,
timestamp=one_hour_ago,
)
# Knowledge has 168 hour (1 week) half-life by default
knowledge_context = KnowledgeContext(
content="Documentation",
source="docs",
timestamp=one_hour_ago,
)
conv_score = await scorer.score(conv_context, "query", reference_time=now)
knowledge_score = await scorer.score(
knowledge_context, "query", reference_time=now
)
# Conversation should decay much faster
assert conv_score < knowledge_score
def test_get_half_life(self) -> None:
"""Test getting half-life for context type."""
scorer = RecencyScorer()
assert scorer.get_half_life(ContextType.CONVERSATION) == 1.0
assert scorer.get_half_life(ContextType.KNOWLEDGE) == 168.0
assert scorer.get_half_life(ContextType.SYSTEM) == 720.0
def test_set_half_life(self) -> None:
"""Test setting custom half-life."""
scorer = RecencyScorer()
scorer.set_half_life(ContextType.TASK, 48.0)
assert scorer.get_half_life(ContextType.TASK) == 48.0
def test_set_half_life_invalid(self) -> None:
"""Test setting invalid half-life."""
scorer = RecencyScorer()
with pytest.raises(ValueError):
scorer.set_half_life(ContextType.TASK, 0)
with pytest.raises(ValueError):
scorer.set_half_life(ContextType.TASK, -1)
@pytest.mark.asyncio
async def test_score_batch(self) -> None:
"""Test batch scoring."""
scorer = RecencyScorer()
now = datetime.now(UTC)
contexts = [
TaskContext(content="1", source="t", timestamp=now),
TaskContext(content="2", source="t", timestamp=now - timedelta(hours=24)),
TaskContext(content="3", source="t", timestamp=now - timedelta(hours=48)),
]
scores = await scorer.score_batch(contexts, "query", reference_time=now)
assert len(scores) == 3
# Scores should be in descending order (more recent = higher)
assert scores[0] > scores[1] > scores[2]
class TestPriorityScorer:
"""Tests for PriorityScorer."""
def test_creation(self) -> None:
"""Test scorer creation."""
scorer = PriorityScorer()
assert scorer.weight == 1.0
@pytest.mark.asyncio
async def test_score_critical_priority(self) -> None:
"""Test scoring CRITICAL priority context."""
scorer = PriorityScorer()
context = SystemContext(
content="Critical system prompt",
source="system",
priority=ContextPriority.CRITICAL.value,
)
score = await scorer.score(context, "query")
# CRITICAL (100) + type bonus should be > 1.0, normalized to 1.0
assert score == 1.0
@pytest.mark.asyncio
async def test_score_normal_priority(self) -> None:
"""Test scoring NORMAL priority context."""
scorer = PriorityScorer()
context = TaskContext(
content="Normal task",
source="task",
priority=ContextPriority.NORMAL.value,
)
score = await scorer.score(context, "query")
# NORMAL (50) = 0.5, plus TASK bonus (0.15) = 0.65
assert 0.6 <= score <= 0.7
@pytest.mark.asyncio
async def test_score_low_priority(self) -> None:
"""Test scoring LOW priority context."""
scorer = PriorityScorer()
context = KnowledgeContext(
content="Low priority knowledge",
source="docs",
priority=ContextPriority.LOW.value,
)
score = await scorer.score(context, "query")
# LOW (20) = 0.2, no bonus for KNOWLEDGE
assert 0.15 <= score <= 0.25
@pytest.mark.asyncio
async def test_type_bonuses(self) -> None:
"""Test type-specific priority bonuses."""
scorer = PriorityScorer()
# All with same base priority
system_ctx = SystemContext(
content="System",
source="system",
priority=50,
)
task_ctx = TaskContext(
content="Task",
source="task",
priority=50,
)
knowledge_ctx = KnowledgeContext(
content="Knowledge",
source="docs",
priority=50,
)
system_score = await scorer.score(system_ctx, "query")
task_score = await scorer.score(task_ctx, "query")
knowledge_score = await scorer.score(knowledge_ctx, "query")
# System has highest bonus (0.2), task next (0.15), knowledge has none
assert system_score > task_score > knowledge_score
def test_get_type_bonus(self) -> None:
"""Test getting type bonus."""
scorer = PriorityScorer()
assert scorer.get_type_bonus(ContextType.SYSTEM) == 0.2
assert scorer.get_type_bonus(ContextType.TASK) == 0.15
assert scorer.get_type_bonus(ContextType.KNOWLEDGE) == 0.0
def test_set_type_bonus(self) -> None:
"""Test setting custom type bonus."""
scorer = PriorityScorer()
scorer.set_type_bonus(ContextType.KNOWLEDGE, 0.1)
assert scorer.get_type_bonus(ContextType.KNOWLEDGE) == 0.1
def test_set_type_bonus_invalid(self) -> None:
"""Test setting invalid type bonus."""
scorer = PriorityScorer()
with pytest.raises(ValueError):
scorer.set_type_bonus(ContextType.KNOWLEDGE, 1.5)
with pytest.raises(ValueError):
scorer.set_type_bonus(ContextType.KNOWLEDGE, -0.1)
class TestCompositeScorer:
"""Tests for CompositeScorer."""
def test_creation(self) -> None:
"""Test scorer creation with default weights."""
scorer = CompositeScorer()
weights = scorer.weights
assert weights["relevance"] == 0.5
assert weights["recency"] == 0.3
assert weights["priority"] == 0.2
def test_creation_with_custom_weights(self) -> None:
"""Test scorer creation with custom weights."""
scorer = CompositeScorer(
relevance_weight=0.6,
recency_weight=0.2,
priority_weight=0.2,
)
weights = scorer.weights
assert weights["relevance"] == 0.6
assert weights["recency"] == 0.2
assert weights["priority"] == 0.2
def test_update_weights(self) -> None:
"""Test updating weights."""
scorer = CompositeScorer()
scorer.update_weights(relevance=0.7, recency=0.2, priority=0.1)
weights = scorer.weights
assert weights["relevance"] == 0.7
assert weights["recency"] == 0.2
assert weights["priority"] == 0.1
def test_update_weights_partial(self) -> None:
"""Test partially updating weights."""
scorer = CompositeScorer()
original_recency = scorer.weights["recency"]
scorer.update_weights(relevance=0.8)
assert scorer.weights["relevance"] == 0.8
assert scorer.weights["recency"] == original_recency
@pytest.mark.asyncio
async def test_score_basic(self) -> None:
"""Test basic composite scoring."""
scorer = CompositeScorer()
context = KnowledgeContext(
content="Test content",
source="docs",
relevance_score=0.8,
timestamp=datetime.now(UTC),
priority=ContextPriority.NORMAL.value,
)
score = await scorer.score(context, "test query")
assert 0.0 <= score <= 1.0
@pytest.mark.asyncio
async def test_score_with_details(self) -> None:
"""Test scoring with detailed breakdown."""
scorer = CompositeScorer()
context = KnowledgeContext(
content="Test content",
source="docs",
relevance_score=0.8,
timestamp=datetime.now(UTC),
priority=ContextPriority.HIGH.value,
)
scored = await scorer.score_with_details(context, "test query")
assert isinstance(scored, ScoredContext)
assert scored.context is context
assert 0.0 <= scored.composite_score <= 1.0
assert scored.relevance_score == 0.8
assert scored.recency_score > 0.9 # Very recent
assert scored.priority_score > 0.5 # HIGH priority
@pytest.mark.asyncio
async def test_score_not_cached_on_context(self) -> None:
"""Test that scores are NOT cached on the context.
Scores are query-dependent, so caching them on the context would serve a stale
score to a later call made with a different query.
"""
scorer = CompositeScorer()
context = KnowledgeContext(
content="Test",
source="docs",
relevance_score=0.5,
)
# After scoring, context._score should remain None
# (we don't cache on context because scores are query-dependent)
await scorer.score(context, "query")
# The scorer should compute fresh scores each time
# rather than caching on the context object
# Score again with different query - should compute fresh score
score1 = await scorer.score(context, "query 1")
score2 = await scorer.score(context, "query 2")
# Both should be valid scores (not necessarily equal since queries differ)
assert 0.0 <= score1 <= 1.0
assert 0.0 <= score2 <= 1.0
@pytest.mark.asyncio
async def test_score_batch(self) -> None:
"""Test batch scoring."""
scorer = CompositeScorer()
contexts = [
KnowledgeContext(
content="High relevance",
source="docs",
relevance_score=0.9,
),
KnowledgeContext(
content="Low relevance",
source="docs",
relevance_score=0.2,
),
]
scored = await scorer.score_batch(contexts, "query")
assert len(scored) == 2
assert scored[0].relevance_score > scored[1].relevance_score
@pytest.mark.asyncio
async def test_rank(self) -> None:
"""Test ranking contexts."""
scorer = CompositeScorer()
contexts = [
KnowledgeContext(content="Low", source="docs", relevance_score=0.2),
KnowledgeContext(content="High", source="docs", relevance_score=0.9),
KnowledgeContext(content="Medium", source="docs", relevance_score=0.5),
]
ranked = await scorer.rank(contexts, "query")
# Should be sorted by score (highest first)
assert len(ranked) == 3
assert ranked[0].relevance_score == 0.9
assert ranked[1].relevance_score == 0.5
assert ranked[2].relevance_score == 0.2
@pytest.mark.asyncio
async def test_rank_with_limit(self) -> None:
"""Test ranking with limit."""
scorer = CompositeScorer()
contexts = [
KnowledgeContext(content=str(i), source="docs", relevance_score=i / 10)
for i in range(10)
]
ranked = await scorer.rank(contexts, "query", limit=3)
assert len(ranked) == 3
@pytest.mark.asyncio
async def test_rank_with_min_score(self) -> None:
"""Test ranking with minimum score threshold."""
scorer = CompositeScorer()
contexts = [
KnowledgeContext(content="Low", source="docs", relevance_score=0.1),
KnowledgeContext(content="High", source="docs", relevance_score=0.9),
]
ranked = await scorer.rank(contexts, "query", min_score=0.5)
# min_score may filter out the low-relevance context, so at most 2 remain
assert len(ranked) <= 2 # Could be 1 if min_score filters
def test_set_mcp_manager(self) -> None:
"""Test setting MCP manager."""
scorer = CompositeScorer()
mock_mcp = MagicMock()
scorer.set_mcp_manager(mock_mcp)
assert scorer._relevance_scorer._mcp is mock_mcp
@pytest.mark.asyncio
async def test_concurrent_scoring_same_context_no_race(self) -> None:
"""Test that concurrent scoring of the same context doesn't cause race conditions.
This verifies that the per-context locking mechanism keeps results consistent
when the same context is scored many times concurrently.
"""
import asyncio
# Use scorer with recency_weight=0 to eliminate time-dependent variation
# (recency scores change as time passes between calls)
scorer = CompositeScorer(
relevance_weight=0.5,
recency_weight=0.0, # Disable recency to get deterministic results
priority_weight=0.5,
)
# Create a single context that will be scored multiple times concurrently
context = KnowledgeContext(
content="Test content for race condition test",
source="docs",
relevance_score=0.75,
)
# Score the same context many times in parallel
num_concurrent = 50
tasks = [scorer.score(context, "test query") for _ in range(num_concurrent)]
scores = await asyncio.gather(*tasks)
# All scores should be identical (deterministic scoring without recency)
assert all(s == scores[0] for s in scores)
# Note: We don't cache _score on context because scores are query-dependent
@pytest.mark.asyncio
async def test_concurrent_scoring_different_contexts(self) -> None:
"""Test that concurrent scoring of different contexts works correctly.
Different contexts should not interfere with each other during parallel scoring.
"""
import asyncio
scorer = CompositeScorer()
# Create many different contexts
contexts = [
KnowledgeContext(
content=f"Test content {i}",
source="docs",
relevance_score=i / 10,
)
for i in range(10)
]
# Score all contexts concurrently
tasks = [scorer.score(ctx, "test query") for ctx in contexts]
scores = await asyncio.gather(*tasks)
# Each context should have a different score based on its relevance
assert len(set(scores)) > 1 # Not all the same
# Note: We don't cache _score on context because scores are query-dependent
class TestScoredContext:
"""Tests for ScoredContext dataclass."""
def test_creation(self) -> None:
"""Test ScoredContext creation."""
context = TaskContext(content="Test", source="task")
scored = ScoredContext(
context=context,
composite_score=0.75,
relevance_score=0.8,
recency_score=0.7,
priority_score=0.5,
)
assert scored.context is context
assert scored.composite_score == 0.75
def test_comparison_operators(self) -> None:
"""Test comparison operators for sorting."""
ctx1 = TaskContext(content="1", source="task")
ctx2 = TaskContext(content="2", source="task")
scored1 = ScoredContext(context=ctx1, composite_score=0.5)
scored2 = ScoredContext(context=ctx2, composite_score=0.8)
assert scored1 < scored2
assert scored2 > scored1
def test_sorting(self) -> None:
"""Test sorting scored contexts."""
contexts = [
ScoredContext(
context=TaskContext(content="Low", source="task"),
composite_score=0.3,
),
ScoredContext(
context=TaskContext(content="High", source="task"),
composite_score=0.9,
),
ScoredContext(
context=TaskContext(content="Medium", source="task"),
composite_score=0.6,
),
]
sorted_contexts = sorted(contexts, reverse=True)
assert sorted_contexts[0].composite_score == 0.9
assert sorted_contexts[1].composite_score == 0.6
assert sorted_contexts[2].composite_score == 0.3
class TestBaseScorer:
"""Tests for BaseScorer abstract class."""
def test_weight_property(self) -> None:
"""Test weight property."""
# Use a concrete implementation
scorer = RelevanceScorer(weight=0.7)
assert scorer.weight == 0.7
def test_weight_setter_valid(self) -> None:
"""Test weight setter with valid values."""
scorer = RelevanceScorer()
scorer.weight = 0.5
assert scorer.weight == 0.5
def test_weight_setter_invalid(self) -> None:
"""Test weight setter with invalid values."""
scorer = RelevanceScorer()
with pytest.raises(ValueError):
scorer.weight = -0.1
with pytest.raises(ValueError):
scorer.weight = 1.5
def test_normalize_score(self) -> None:
"""Test score normalization."""
scorer = RelevanceScorer()
# Normal range
assert scorer.normalize_score(0.5) == 0.5
# Below 0
assert scorer.normalize_score(-0.5) == 0.0
# Above 1
assert scorer.normalize_score(1.5) == 1.0
# Boundaries
assert scorer.normalize_score(0.0) == 0.0
assert scorer.normalize_score(1.0) == 1.0
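
The numeric expectations in the scoring tests above are consistent with a simple half-life decay for recency and a clamped weighted sum for the composite score (default weights 0.5 / 0.3 / 0.2). A sketch of that arithmetic, not the scorers' actual code; the real classes add type bonuses, MCP-backed relevance, and configurable per-type half-lives:

    # Sketch of the scoring math implied by the assertions above.
    from datetime import timedelta

    def recency_score(age: timedelta, half_life_hours: float) -> float:
        # 1.0 when fresh, 0.5 at one half-life, ~0.008 after 7 days with a 24h half-life
        return 0.5 ** (age.total_seconds() / 3600.0 / half_life_hours)

    def composite_score(relevance: float, recency: float, priority: float,
                        w_rel: float = 0.5, w_rec: float = 0.3, w_pri: float = 0.2) -> float:
        raw = w_rel * relevance + w_rec * recency + w_pri * priority
        return min(max(raw, 0.0), 1.0)  # clamped to [0, 1], as normalize_score does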


@@ -1,20 +1,17 @@
 """Tests for context types."""

-import json
 from datetime import UTC, datetime, timedelta

 import pytest

 from app.services.context.types import (
     AssembledContext,
-    BaseContext,
     ContextPriority,
     ContextType,
     ConversationContext,
     KnowledgeContext,
     MessageRole,
     SystemContext,
-    TaskComplexity,
     TaskContext,
     TaskStatus,
     ToolContext,
@@ -181,24 +178,16 @@ class TestKnowledgeContext:
     def test_is_code(self) -> None:
         """Test is_code method."""
-        code_ctx = KnowledgeContext(
-            content="code", source="test", file_type="python"
-        )
-        doc_ctx = KnowledgeContext(
-            content="docs", source="test", file_type="markdown"
-        )
+        code_ctx = KnowledgeContext(content="code", source="test", file_type="python")
+        doc_ctx = KnowledgeContext(content="docs", source="test", file_type="markdown")
         assert code_ctx.is_code() is True
         assert doc_ctx.is_code() is False

     def test_is_documentation(self) -> None:
         """Test is_documentation method."""
-        doc_ctx = KnowledgeContext(
-            content="docs", source="test", file_type="markdown"
-        )
-        code_ctx = KnowledgeContext(
-            content="code", source="test", file_type="python"
-        )
+        doc_ctx = KnowledgeContext(content="docs", source="test", file_type="markdown")
+        code_ctx = KnowledgeContext(content="code", source="test", file_type="python")
         assert doc_ctx.is_documentation() is True
         assert code_ctx.is_documentation() is False
@@ -286,9 +275,20 @@ class TestTaskContext:
         assert ctx.title == "Login Feature"
         assert ctx.get_type() == ContextType.TASK

-    def test_default_high_priority(self) -> None:
-        """Test that task context defaults to high priority."""
+    def test_default_normal_priority(self) -> None:
+        """Test that task context uses NORMAL priority from base class."""
         ctx = TaskContext(content="Test", source="test")
+        # TaskContext inherits NORMAL priority from BaseContext
+        # Use TaskContext.create() for default HIGH priority behavior
+        assert ctx.priority == ContextPriority.NORMAL.value
+
+    def test_explicit_high_priority(self) -> None:
+        """Test setting explicit HIGH priority."""
+        ctx = TaskContext(
+            content="Test",
+            source="test",
+            priority=ContextPriority.HIGH.value,
+        )
         assert ctx.priority == ContextPriority.HIGH.value

     def test_create_factory(self) -> None:
@@ -322,15 +322,11 @@ class TestTaskContext:
     def test_status_checks(self) -> None:
         """Test status check methods."""
-        pending = TaskContext(
-            content="test", source="test", status=TaskStatus.PENDING
-        )
+        pending = TaskContext(content="test", source="test", status=TaskStatus.PENDING)
         completed = TaskContext(
             content="test", source="test", status=TaskStatus.COMPLETED
         )
-        blocked = TaskContext(
-            content="test", source="test", status=TaskStatus.BLOCKED
-        )
+        blocked = TaskContext(content="test", source="test", status=TaskStatus.BLOCKED)

         assert pending.is_active() is True
         assert completed.is_complete() is True
@@ -384,12 +380,8 @@ class TestToolContext:
     def test_is_successful(self) -> None:
         """Test is_successful method."""
-        success = ToolContext.from_tool_result(
-            "test", "ok", ToolResultStatus.SUCCESS
-        )
-        error = ToolContext.from_tool_result(
-            "test", "error", ToolResultStatus.ERROR
-        )
+        success = ToolContext.from_tool_result("test", "ok", ToolResultStatus.SUCCESS)
+        error = ToolContext.from_tool_result("test", "error", ToolResultStatus.ERROR)

         assert success.is_successful() is True
         assert error.is_successful() is False
@@ -415,11 +407,14 @@ class TestAssembledContext:
         """Test basic creation."""
         ctx = AssembledContext(
             content="Assembled content here",
-            token_count=500,
-            contexts_included=5,
+            total_tokens=500,
+            context_count=5,
         )
         assert ctx.content == "Assembled content here"
+        assert ctx.total_tokens == 500
+        assert ctx.context_count == 5
+        # Test backward compatibility aliases
         assert ctx.token_count == 500
         assert ctx.contexts_included == 5
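This hunk renames the canonical fields to total_tokens and context_count while keeping the old attribute names working. Read-only property aliases are one straightforward way to get that behaviour; the sketch below assumes AssembledContext is a dataclass, which the diff does not show.

from dataclasses import dataclass


@dataclass
class AssembledContextSketch:
    content: str
    total_tokens: int = 0
    context_count: int = 0

    @property
    def token_count(self) -> int:
        # Backward-compatibility alias for callers written before the rename.
        return self.total_tokens

    @property
    def contexts_included(self) -> int:
        return self.context_count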
@@ -427,8 +422,8 @@ class TestAssembledContext:
         """Test budget_utilization property."""
         ctx = AssembledContext(
             content="test",
-            token_count=800,
-            contexts_included=5,
+            total_tokens=800,
+            context_count=5,
             budget_total=1000,
             budget_used=800,
         )
@@ -439,8 +434,8 @@ class TestAssembledContext:
         """Test budget_utilization with zero budget."""
         ctx = AssembledContext(
             content="test",
-            token_count=0,
-            contexts_included=0,
+            total_tokens=0,
+            context_count=0,
             budget_total=0,
             budget_used=0,
         )
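Both of these hunks feed the budget_utilization property, including the degenerate zero-budget case. A guarded ratio is the natural shape; whether the real property returns a 0-1 fraction or a percentage is not visible in these hunks, so the fraction below is an assumption.

class BudgetSketch:
    def __init__(self, budget_total: int = 0, budget_used: int = 0) -> None:
        self.budget_total = budget_total
        self.budget_used = budget_used

    @property
    def budget_utilization(self) -> float:
        # Guard the zero-budget case instead of raising ZeroDivisionError.
        if self.budget_total <= 0:
            return 0.0
        return self.budget_used / self.budget_total

With budget_total=1000 and budget_used=800 this sketch yields 0.8; with both at zero it falls back to 0.0.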
@@ -451,24 +446,26 @@ class TestAssembledContext:
         """Test to_dict method."""
         ctx = AssembledContext(
             content="test",
-            token_count=100,
-            contexts_included=2,
+            total_tokens=100,
+            context_count=2,
             assembly_time_ms=50.123,
         )
         data = ctx.to_dict()
         assert data["content"] == "test"
-        assert data["token_count"] == 100
+        assert data["total_tokens"] == 100
+        assert data["context_count"] == 2
         assert data["assembly_time_ms"] == 50.12  # Rounded

     def test_to_json_and_from_json(self) -> None:
         """Test JSON serialization round-trip."""
         original = AssembledContext(
             content="Test content",
-            token_count=100,
-            contexts_included=3,
-            contexts_excluded=2,
+            total_tokens=100,
+            context_count=3,
+            excluded_count=2,
             assembly_time_ms=45.5,
+            model="claude-3-sonnet",
             budget_total=1000,
             budget_used=100,
             by_type={"system": 20, "knowledge": 80},
@@ -480,8 +477,10 @@ class TestAssembledContext:
         restored = AssembledContext.from_json(json_str)

         assert restored.content == original.content
-        assert restored.token_count == original.token_count
-        assert restored.contexts_included == original.contexts_included
+        assert restored.total_tokens == original.total_tokens
+        assert restored.context_count == original.context_count
+        assert restored.excluded_count == original.excluded_count
+        assert restored.model == original.model
         assert restored.cache_hit == original.cache_hit
         assert restored.cache_key == original.cache_key
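The round-trip test covers every field asserted above, including the new excluded_count and model. Assuming a dataclass whose to_dict() rounds assembly_time_ms to two decimals (as the earlier to_dict test expects), the JSON helpers reduce to a thin wrapper; the field names come from the asserts, everything else here is assumed.

import json
from dataclasses import asdict, dataclass, field


@dataclass
class SerializableContextSketch:
    content: str = ""
    total_tokens: int = 0
    context_count: int = 0
    excluded_count: int = 0
    assembly_time_ms: float = 0.0
    model: str | None = None
    budget_total: int = 0
    budget_used: int = 0
    by_type: dict[str, int] = field(default_factory=dict)
    cache_hit: bool = False
    cache_key: str | None = None

    def to_dict(self) -> dict:
        data = asdict(self)
        data["assembly_time_ms"] = round(self.assembly_time_ms, 2)
        return data

    def to_json(self) -> str:
        return json.dumps(self.to_dict())

    @classmethod
    def from_json(cls, raw: str) -> "SerializableContextSketch":
        return cls(**json.loads(raw))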
@@ -492,9 +491,7 @@ class TestBaseContextMethods:
     def test_get_age_seconds(self) -> None:
         """Test get_age_seconds method."""
         old_time = datetime.now(UTC) - timedelta(hours=2)
-        ctx = SystemContext(
-            content="test", source="test", timestamp=old_time
-        )
+        ctx = SystemContext(content="test", source="test", timestamp=old_time)

         age = ctx.get_age_seconds()
         # Should be approximately 2 hours in seconds
@@ -503,9 +500,7 @@ class TestBaseContextMethods:
     def test_get_age_hours(self) -> None:
         """Test get_age_hours method."""
         old_time = datetime.now(UTC) - timedelta(hours=5)
-        ctx = SystemContext(
-            content="test", source="test", timestamp=old_time
-        )
+        ctx = SystemContext(content="test", source="test", timestamp=old_time)

         age = ctx.get_age_hours()
         assert 4.9 < age < 5.1
@@ -515,12 +510,8 @@ class TestBaseContextMethods:
         old_time = datetime.now(UTC) - timedelta(days=10)
         new_time = datetime.now(UTC) - timedelta(hours=1)
-        old_ctx = SystemContext(
-            content="test", source="test", timestamp=old_time
-        )
-        new_ctx = SystemContext(
-            content="test", source="test", timestamp=new_time
-        )
+        old_ctx = SystemContext(content="test", source="test", timestamp=old_time)
+        new_ctx = SystemContext(content="test", source="test", timestamp=new_time)

         # Default max_age is 168 hours (7 days)
         assert old_ctx.is_stale() is True
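The age and staleness helpers exercised in these hunks follow directly from the stored timestamp. Below is a sketch under the 168-hour default mentioned in the comment above; the parameter name and exact signature of is_stale() are assumptions.

from datetime import UTC, datetime


class BaseContextAgeSketch:
    def __init__(self, timestamp: datetime | None = None) -> None:
        self.timestamp = timestamp or datetime.now(UTC)

    def get_age_seconds(self) -> float:
        return (datetime.now(UTC) - self.timestamp).total_seconds()

    def get_age_hours(self) -> float:
        return self.get_age_seconds() / 3600

    def is_stale(self, max_age_hours: float = 168.0) -> bool:
        # 168 hours == 7 days, matching the default noted in the test comment.
        return self.get_age_hours() > max_age_hours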