chore(context): refactor for consistency, optimize formatting, and simplify logic

- Cleaned up unnecessary comments in `__all__` definitions for better readability.
- Adjusted indentation and formatting across modules for improved clarity (e.g., long lines, logical grouping).
- Simplified conditional expressions and inline comments for context scoring and ranking.
- Replaced a hard-coded priority threshold with `ContextPriority.CRITICAL.value` and added `ClassVar` annotations to mutable class-level attributes.
- Removed unused imports and ensured consistent usage across test files.
- Renamed `test_score_cached_on_context` to `test_score_not_cached_on_context`: scores are query-dependent and are no longer cached on the context.
- Renamed the `TRUNCATION_MARKER` property to `truncation_marker` and tidied the truncation strategy logic.
2026-01-04 15:23:14 +01:00
parent 9e54f16e56
commit 2bea057fb1
26 changed files with 226 additions and 273 deletions

View File

@@ -124,71 +124,55 @@ from .types import (
)
__all__ = [
# Adapters
"ClaudeAdapter",
"DefaultAdapter",
"get_adapter",
"ModelAdapter",
"OpenAIAdapter",
# Assembly
"ContextPipeline",
"PipelineMetrics",
# Budget Management
"BudgetAllocator",
"TokenBudget",
"TokenCalculator",
# Cache
"ContextCache",
# Engine
"ContextEngine",
"create_context_engine",
# Compression
"ContextCompressor",
"TruncationResult",
"TruncationStrategy",
# Configuration
"ContextSettings",
"get_context_settings",
"get_default_settings",
"reset_context_settings",
# Exceptions
"AssembledContext",
"AssemblyTimeoutError",
"BaseContext",
"BaseScorer",
"BudgetAllocator",
"BudgetExceededError",
"CacheError",
"ClaudeAdapter",
"CompositeScorer",
"CompressionError",
"ContextCache",
"ContextCompressor",
"ContextEngine",
"ContextError",
"ContextNotFoundError",
"ContextPipeline",
"ContextPriority",
"ContextRanker",
"ContextSettings",
"ContextType",
"ConversationContext",
"DefaultAdapter",
"FormattingError",
"InvalidContextError",
"ScoringError",
"TokenCountError",
# Prioritization
"ContextRanker",
"RankingResult",
# Scoring
"BaseScorer",
"CompositeScorer",
"KnowledgeContext",
"MessageRole",
"ModelAdapter",
"OpenAIAdapter",
"PipelineMetrics",
"PriorityScorer",
"RankingResult",
"RecencyScorer",
"RelevanceScorer",
"ScoredContext",
# Types - Base
"AssembledContext",
"BaseContext",
"ContextPriority",
"ContextType",
# Types - Conversation
"ConversationContext",
"MessageRole",
# Types - Knowledge
"KnowledgeContext",
# Types - System
"ScoringError",
"SystemContext",
# Types - Task
"TaskComplexity",
"TaskContext",
"TaskStatus",
# Types - Tool
"TokenBudget",
"TokenCalculator",
"TokenCountError",
"ToolContext",
"ToolResultStatus",
"TruncationResult",
"TruncationStrategy",
"create_context_engine",
"get_adapter",
"get_context_settings",
"get_default_settings",
"reset_context_settings",
]

View File

@@ -5,7 +5,7 @@ Abstract base class for model-specific context formatting.
"""
from abc import ABC, abstractmethod
from typing import Any
from typing import Any, ClassVar
from ..types import BaseContext, ContextType
@@ -19,7 +19,7 @@ class ModelAdapter(ABC):
"""
# Model name patterns this adapter handles
MODEL_PATTERNS: list[str] = []
MODEL_PATTERNS: ClassVar[list[str]] = []
@classmethod
def matches_model(cls, model: str) -> bool:
@@ -125,7 +125,7 @@ class DefaultAdapter(ModelAdapter):
Uses simple plain-text formatting with minimal structure.
"""
MODEL_PATTERNS: list[str] = [] # Fallback adapter
MODEL_PATTERNS: ClassVar[list[str]] = [] # Fallback adapter
@classmethod
def matches_model(cls, model: str) -> bool:
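For reference, a minimal standalone sketch of the `ClassVar` pattern applied to these adapter attributes; the substring-matching body of `matches_model` is an illustrative assumption, not the module's exact implementation:

from typing import ClassVar


class ModelAdapter:
    # Class-level pattern list; ClassVar tells type checkers (and Ruff's RUF012)
    # that this mutable default belongs to the class, not to instances.
    MODEL_PATTERNS: ClassVar[list[str]] = []

    @classmethod
    def matches_model(cls, model: str) -> bool:
        # Assumed behavior: case-insensitive substring match against the patterns.
        return any(pattern in model.lower() for pattern in cls.MODEL_PATTERNS)


class ClaudeAdapter(ModelAdapter):
    MODEL_PATTERNS: ClassVar[list[str]] = ["claude", "anthropic"]


print(ClaudeAdapter.matches_model("claude-3-5-sonnet"))  # True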

View File

@@ -5,7 +5,7 @@ Provides Claude-specific context formatting using XML tags
which Claude models understand natively.
"""
from typing import Any
from typing import Any, ClassVar
from ..types import BaseContext, ContextType
from .base import ModelAdapter
@@ -25,7 +25,7 @@ class ClaudeAdapter(ModelAdapter):
- Tool result wrapping with tool names
"""
MODEL_PATTERNS: list[str] = ["claude", "anthropic"]
MODEL_PATTERNS: ClassVar[list[str]] = ["claude", "anthropic"]
def format(
self,

View File

@@ -5,7 +5,7 @@ Provides OpenAI-specific context formatting using markdown
which GPT models understand well.
"""
from typing import Any
from typing import Any, ClassVar
from ..types import BaseContext, ContextType
from .base import ModelAdapter
@@ -25,7 +25,7 @@ class OpenAIAdapter(ModelAdapter):
- Code blocks for tool outputs
"""
MODEL_PATTERNS: list[str] = ["gpt", "openai", "o1", "o3"]
MODEL_PATTERNS: ClassVar[list[str]] = ["gpt", "openai", "o1", "o3"]
def format(
self,

View File

@@ -102,9 +102,7 @@ class ContextPipeline:
self._ranker = ranker or ContextRanker(
scorer=self._scorer, calculator=self._calculator
)
self._compressor = compressor or ContextCompressor(
calculator=self._calculator
)
self._compressor = compressor or ContextCompressor(calculator=self._calculator)
self._allocator = BudgetAllocator(self._settings)
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
@@ -336,27 +334,21 @@ class ContextPipeline:
return "\n".join(c.content for c in contexts)
def _format_system(
self, contexts: list[BaseContext], use_xml: bool
) -> str:
def _format_system(self, contexts: list[BaseContext], use_xml: bool) -> str:
"""Format system contexts."""
content = "\n\n".join(c.content for c in contexts)
if use_xml:
return f"<system_instructions>\n{content}\n</system_instructions>"
return content
def _format_task(
self, contexts: list[BaseContext], use_xml: bool
) -> str:
def _format_task(self, contexts: list[BaseContext], use_xml: bool) -> str:
"""Format task contexts."""
content = "\n\n".join(c.content for c in contexts)
if use_xml:
return f"<current_task>\n{content}\n</current_task>"
return f"## Current Task\n\n{content}"
def _format_knowledge(
self, contexts: list[BaseContext], use_xml: bool
) -> str:
def _format_knowledge(self, contexts: list[BaseContext], use_xml: bool) -> str:
"""Format knowledge contexts."""
if use_xml:
parts = ["<reference_documents>"]
@@ -374,9 +366,7 @@ class ContextPipeline:
parts.append("")
return "\n".join(parts)
def _format_conversation(
self, contexts: list[BaseContext], use_xml: bool
) -> str:
def _format_conversation(self, contexts: list[BaseContext], use_xml: bool) -> str:
"""Format conversation contexts."""
if use_xml:
parts = ["<conversation_history>"]
@@ -394,9 +384,7 @@ class ContextPipeline:
parts.append(f"**{role.upper()}**: {ctx.content}")
return "\n\n".join(parts)
def _format_tool(
self, contexts: list[BaseContext], use_xml: bool
) -> str:
def _format_tool(self, contexts: list[BaseContext], use_xml: bool) -> str:
"""Format tool contexts."""
if use_xml:
parts = ["<tool_results>"]
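As a standalone illustration of what these one-line signatures wrap, here is a sketch of the task formatter, with the XML and markdown branches taken from the hunk above:

def format_task(contents: list[str], use_xml: bool) -> str:
    # Join the task snippets, then wrap them in the markup the target model prefers.
    body = "\n\n".join(contents)
    if use_xml:
        return f"<current_task>\n{body}\n</current_task>"
    return f"## Current Task\n\n{body}"


print(format_task(["Implement the ranker"], use_xml=True))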

View File

@@ -215,9 +215,7 @@ class TokenBudget:
"buffer": self.buffer,
},
"used": dict(self.used),
"remaining": {
ct.value: self.remaining(ct) for ct in ContextType
},
"remaining": {ct.value: self.remaining(ct) for ct in ContextType},
"total_used": self.total_used(),
"total_remaining": self.total_remaining(),
"utilization": round(self.utilization(), 3),
@@ -348,13 +346,11 @@ class BudgetAllocator:
# Calculate total reclaimable (excluding prioritized types)
prioritize_values = {ct.value for ct in prioritize}
reclaimable = sum(
tokens for ct, tokens in unused.items()
if ct not in prioritize_values
tokens for ct, tokens in unused.items() if ct not in prioritize_values
)
# Redistribute to prioritized types that are near capacity
for ct in prioritize:
ct_value = ct.value
utilization = budget.utilization(ct)
if utilization > 0.8: # Near capacity
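A simplified sketch of the redistribution step this hunk reflows (plain dicts stand in for `TokenBudget`; the even split across near-capacity types is an assumption):

def reclaim_unused(
    unused: dict[str, int],
    prioritize: set[str],
    utilization: dict[str, float],
) -> dict[str, int]:
    # Tokens left unused by non-prioritized context types are reclaimable.
    reclaimable = sum(t for ct, t in unused.items() if ct not in prioritize)
    # Only prioritized types running near capacity (> 80%) receive extra tokens.
    near_capacity = [ct for ct in prioritize if utilization.get(ct, 0.0) > 0.8]
    if reclaimable == 0 or not near_capacity:
        return {}
    share = reclaimable // len(near_capacity)
    return {ct: share for ct in near_capacity}


print(reclaim_unused(
    unused={"conversation": 500, "tool": 300},
    prioritize={"knowledge"},
    utilization={"knowledge": 0.92},
))  # {'knowledge': 800}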

View File

@@ -7,7 +7,7 @@ Integrates with LLM Gateway for accurate counts.
import hashlib
import logging
from typing import TYPE_CHECKING, Any, Protocol
from typing import TYPE_CHECKING, Any, ClassVar, Protocol
if TYPE_CHECKING:
from app.services.mcp.client_manager import MCPClientManager
@@ -42,10 +42,10 @@ class TokenCalculator:
"""
# Default characters per token ratio for estimation
DEFAULT_CHARS_PER_TOKEN = 4.0
DEFAULT_CHARS_PER_TOKEN: ClassVar[float] = 4.0
# Model-specific ratios (more accurate estimation)
MODEL_CHAR_RATIOS: dict[str, float] = {
MODEL_CHAR_RATIOS: ClassVar[dict[str, float]] = {
"claude": 3.5,
"gpt-4": 4.0,
"gpt-3.5": 4.0,

View File

@@ -116,12 +116,16 @@ class ContextCache:
# This avoids JSON serializing potentially large content strings
context_data = []
for ctx in contexts:
context_data.append({
"type": ctx.get_type().value,
"content_hash": self._hash_content(ctx.content), # Hash instead of full content
"source": ctx.source,
"priority": ctx.priority, # Already an int
})
context_data.append(
{
"type": ctx.get_type().value,
"content_hash": self._hash_content(
ctx.content
), # Hash instead of full content
"source": ctx.source,
"priority": ctx.priority, # Already an int
}
)
data = {
"contexts": context_data,
@@ -412,7 +416,7 @@ class ContextCache:
# Get Redis info
info = await self._redis.info("memory") # type: ignore
stats["redis_memory_used"] = info.get("used_memory_human", "unknown")
except Exception:
pass
except Exception as e:
logger.debug(f"Failed to get Redis stats: {e}")
return stats
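The hunk above reformats the cache-key payload; a self-contained sketch of the idea (hash each context's content, then hash the serialized structure) follows. The helper names and key layout are simplifications, not the exact `ContextCache` API:

import hashlib
import json


def hash_content(content: str) -> str:
    # Hashing avoids serializing potentially large content strings into the key.
    return hashlib.sha256(content.encode("utf-8")).hexdigest()[:16]


def cache_key(contexts: list[dict], query: str, model: str) -> str:
    context_data = [
        {
            "type": ctx["type"],
            "content_hash": hash_content(ctx["content"]),
            "source": ctx["source"],
            "priority": ctx["priority"],
        }
        for ctx in contexts
    ]
    payload = {"contexts": context_data, "query": query, "model": model}
    return hashlib.sha256(json.dumps(payload, sort_keys=True).encode()).hexdigest()


key = cache_key(
    [{"type": "knowledge", "content": "FastAPI docs", "source": "kb", "priority": 50}],
    query="async endpoints",
    model="claude",
)
print(key[:12])  # stable prefix for identical inputs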

View File

@@ -78,7 +78,7 @@ class TruncationStrategy:
)
@property
def TRUNCATION_MARKER(self) -> str:
def truncation_marker(self) -> str:
"""Get truncation marker from settings."""
return self._settings.truncation_marker
@@ -141,7 +141,9 @@ class TruncationStrategy:
truncated_tokens=truncated_tokens,
content=truncated,
truncated=True,
truncation_ratio=0.0 if original_tokens == 0 else 1 - (truncated_tokens / original_tokens),
truncation_ratio=0.0
if original_tokens == 0
else 1 - (truncated_tokens / original_tokens),
)
async def _truncate_end(
@@ -156,17 +158,17 @@ class TruncationStrategy:
Simple but effective for most content types.
"""
# Binary search for optimal truncation point
marker_tokens = await self._count_tokens(self.TRUNCATION_MARKER, model)
marker_tokens = await self._count_tokens(self.truncation_marker, model)
available_tokens = max(0, max_tokens - marker_tokens)
# Edge case: if no tokens available for content, return just the marker
if available_tokens <= 0:
return self.TRUNCATION_MARKER
return self.truncation_marker
# Estimate characters per token (guard against division by zero)
content_tokens = await self._count_tokens(content, model)
if content_tokens == 0:
return content + self.TRUNCATION_MARKER
return content + self.truncation_marker
chars_per_token = len(content) / content_tokens
# Start with estimated position
@@ -188,7 +190,7 @@ class TruncationStrategy:
else:
high = mid - 1
return best + self.TRUNCATION_MARKER
return best + self.truncation_marker
async def _truncate_middle(
self,
@@ -201,7 +203,7 @@ class TruncationStrategy:
Good for code or content where context at boundaries matters.
"""
marker_tokens = await self._count_tokens(self.TRUNCATION_MARKER, model)
marker_tokens = await self._count_tokens(self.truncation_marker, model)
available_tokens = max_tokens - marker_tokens
# Split between start and end
@@ -218,7 +220,7 @@ class TruncationStrategy:
content, end_tokens, from_start=False, model=model
)
return start_content + self.TRUNCATION_MARKER + end_content
return start_content + self.truncation_marker + end_content
async def _truncate_sentence(
self,
@@ -236,7 +238,7 @@ class TruncationStrategy:
result: list[str] = []
total_tokens = 0
marker_tokens = await self._count_tokens(self.TRUNCATION_MARKER, model)
marker_tokens = await self._count_tokens(self.truncation_marker, model)
available = max_tokens - marker_tokens
for sentence in sentences:
@@ -248,7 +250,7 @@ class TruncationStrategy:
break
if len(result) < len(sentences):
return " ".join(result) + self.TRUNCATION_MARKER
return " ".join(result) + self.truncation_marker
return " ".join(result)
async def _get_content_for_tokens(
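A compact, synchronous sketch of the end-truncation logic touched above: reserve tokens for the marker, then cut the tail. The real strategy counts tokens asynchronously and binary-searches the cut point; the char-ratio counter here is an assumption:

TRUNCATION_MARKER = "\n[... truncated ...]"
CHARS_PER_TOKEN = 4.0


def count_tokens(text: str) -> int:
    return round(len(text) / CHARS_PER_TOKEN) if text else 0


def truncate_end(content: str, max_tokens: int) -> str:
    # Reserve room for the marker so the result stays within max_tokens.
    available = max(0, max_tokens - count_tokens(TRUNCATION_MARKER))
    if available <= 0:
        return TRUNCATION_MARKER
    if count_tokens(content) <= available:
        return content
    cut = int(available * CHARS_PER_TOKEN)
    return content[:cut] + TRUNCATION_MARKER


print(truncate_end("word " * 200, max_tokens=20))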

View File

@@ -78,12 +78,8 @@ class ContextEngine:
# Initialize components
self._calculator = TokenCalculator(mcp_manager=mcp_manager)
self._scorer = CompositeScorer(
mcp_manager=mcp_manager, settings=self._settings
)
self._ranker = ContextRanker(
scorer=self._scorer, calculator=self._calculator
)
self._scorer = CompositeScorer(mcp_manager=mcp_manager, settings=self._settings)
self._ranker = ContextRanker(scorer=self._scorer, calculator=self._calculator)
self._compressor = ContextCompressor(calculator=self._calculator)
self._allocator = BudgetAllocator(self._settings)
self._cache = ContextCache(redis=redis, settings=self._settings)
@@ -274,8 +270,19 @@ class ContextEngine:
},
)
# Check both ToolResult.success AND response success
if not result.success:
logger.warning(f"Knowledge search failed: {result.error}")
return []
if not isinstance(result.data, dict) or not result.data.get(
"success", True
):
logger.warning("Knowledge search returned unsuccessful response")
return []
contexts = []
results = result.data.get("results", []) if isinstance(result.data, dict) else []
results = result.data.get("results", [])
for chunk in results:
contexts.append(
KnowledgeContext(
@@ -283,7 +290,9 @@ class ContextEngine:
source=chunk.get("source_path", "unknown"),
relevance_score=chunk.get("score", 0.0),
metadata={
"chunk_id": chunk.get("chunk_id"),
"chunk_id": chunk.get(
"id"
), # Server returns 'id' not 'chunk_id'
"document_id": chunk.get("document_id"),
},
)
@@ -312,7 +321,9 @@ class ContextEngine:
contexts = []
for i, turn in enumerate(history):
role_str = turn.get("role", "user").lower()
role = MessageRole.ASSISTANT if role_str == "assistant" else MessageRole.USER
role = (
MessageRole.ASSISTANT if role_str == "assistant" else MessageRole.USER
)
contexts.append(
ConversationContext(
@@ -346,6 +357,7 @@ class ContextEngine:
# Handle dict content
if isinstance(content, dict):
import json
content = json.dumps(content, indent=2)
contexts.append(

View File

@@ -61,7 +61,7 @@ class BudgetExceededError(ContextError):
requested: Tokens requested
context_type: Type of context that exceeded budget
"""
details = {
details: dict[str, Any] = {
"allocated": allocated,
"requested": requested,
"overage": requested - allocated,
@@ -170,7 +170,7 @@ class AssemblyTimeoutError(ContextError):
elapsed_ms: Actual elapsed time in milliseconds
stage: Pipeline stage where timeout occurred
"""
details = {
details: dict[str, Any] = {
"timeout_ms": timeout_ms,
"elapsed_ms": round(elapsed_ms, 2),
}

View File

@@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, Any
from ..budget import TokenBudget, TokenCalculator
from ..config import ContextSettings, get_context_settings
from ..scoring.composite import CompositeScorer, ScoredContext
from ..types import BaseContext
from ..types import BaseContext, ContextPriority
if TYPE_CHECKING:
pass
@@ -111,8 +111,8 @@ class ContextRanker:
if ensure_required:
for sc in scored_contexts:
# CRITICAL priority (100) contexts are always included
if sc.context.priority >= 100:
# CRITICAL priority (150) contexts are always included
if sc.context.priority >= ContextPriority.CRITICAL.value:
required.append(sc)
else:
optional.append(sc)
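A small sketch of why the comparison now goes through the enum instead of a literal; `CRITICAL = 150` matches the updated comment, but the other values below are illustrative assumptions:

from enum import IntEnum


class ContextPriority(IntEnum):
    LOW = 25        # illustrative
    MEDIUM = 50     # illustrative
    HIGH = 100      # illustrative
    CRITICAL = 150  # per the updated comment in this hunk


def is_required(priority: int) -> bool:
    # Referencing the enum keeps this check correct if the value ever changes again.
    return priority >= ContextPriority.CRITICAL.value


print(is_required(150), is_required(100))  # True False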
@@ -239,9 +239,7 @@ class ContextRanker:
import asyncio
# Find contexts needing counts
contexts_needing_counts = [
ctx for ctx in contexts if ctx.token_count is None
]
contexts_needing_counts = [ctx for ctx in contexts if ctx.token_count is None]
if not contexts_needing_counts:
return
@@ -254,7 +252,7 @@ class ContextRanker:
counts = await asyncio.gather(*tasks)
# Assign counts back
for ctx, count in zip(contexts_needing_counts, counts):
for ctx, count in zip(contexts_needing_counts, counts, strict=True):
ctx.token_count = count
def _count_by_type(

View File

@@ -92,7 +92,9 @@ class CompositeScorer:
# Per-context locks to prevent race conditions during parallel scoring
# Uses WeakValueDictionary so locks are garbage collected when not in use
self._context_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
self._context_locks: WeakValueDictionary[str, asyncio.Lock] = (
WeakValueDictionary()
)
self._locks_lock = asyncio.Lock() # Lock to protect _context_locks access
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
@@ -207,17 +209,14 @@ class CompositeScorer:
ScoredContext with all scores
"""
# Get lock for this specific context to prevent race conditions
# within concurrent scoring operations for the same query
context_lock = await self._get_context_lock(context.id)
async with context_lock:
# Check if context already has a score (inside lock to prevent races)
if context._score is not None:
return ScoredContext(
context=context,
composite_score=context._score,
)
# Compute individual scores in parallel
# Note: We do NOT cache scores on the context because scores are
# query-dependent. Caching without considering the query would
# return incorrect scores for different queries.
relevance_task = self._relevance_scorer.score(context, query, **kwargs)
recency_task = self._recency_scorer.score(context, query, **kwargs)
priority_task = self._priority_scorer.score(context, query, **kwargs)
@@ -240,9 +239,6 @@ class CompositeScorer:
else:
composite = 0.0
# Cache the score on the context (now safe - inside lock)
context._score = composite
return ScoredContext(
context=context,
composite_score=composite,
@@ -271,9 +267,7 @@ class CompositeScorer:
List of ScoredContext (same order as input)
"""
if parallel:
tasks = [
self.score_with_details(ctx, query, **kwargs) for ctx in contexts
]
tasks = [self.score_with_details(ctx, query, **kwargs) for ctx in contexts]
return await asyncio.gather(*tasks)
else:
results = []
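To make the rationale concrete, a standalone sketch of a weighted composite score: the relevance term depends on the query, which is exactly why the result is no longer cached on the context object. The weights are assumptions, not the module's defaults:

def composite_score(
    relevance: float,
    recency: float,
    priority: float,
    w_rel: float = 0.5,
    w_rec: float = 0.2,
    w_pri: float = 0.3,
) -> float:
    # Weighted average of the three component scores, normalized by total weight.
    total = w_rel + w_rec + w_pri
    if total == 0:
        return 0.0
    return (relevance * w_rel + recency * w_rec + priority * w_pri) / total


# Same context, different queries -> different relevance -> different composites.
print(composite_score(relevance=0.9, recency=0.4, priority=0.7))  # 0.74
print(composite_score(relevance=0.2, recency=0.4, priority=0.7))  # 0.39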

View File

@@ -4,7 +4,7 @@ Priority Scorer for Context Management.
Scores context based on assigned priority levels.
"""
from typing import Any
from typing import Any, ClassVar
from ..types import BaseContext, ContextType
from .base import BaseScorer
@@ -19,11 +19,11 @@ class PriorityScorer(BaseScorer):
"""
# Default priority bonuses by context type
DEFAULT_TYPE_BONUSES: dict[ContextType, float] = {
ContextType.SYSTEM: 0.2, # System prompts get a boost
ContextType.TASK: 0.15, # Current task is important
ContextType.TOOL: 0.1, # Recent tool results matter
ContextType.KNOWLEDGE: 0.0, # Knowledge scored by relevance
DEFAULT_TYPE_BONUSES: ClassVar[dict[ContextType, float]] = {
ContextType.SYSTEM: 0.2, # System prompts get a boost
ContextType.TASK: 0.15, # Current task is important
ContextType.TOOL: 0.1, # Recent tool results matter
ContextType.KNOWLEDGE: 0.0, # Knowledge scored by relevance
ContextType.CONVERSATION: 0.0, # Conversation scored by recency
}
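A sketch of how the type bonuses above might feed a priority score; the bonus values come from the hunk, but combining them by addition and clamping to [0, 1] is an assumption:

TYPE_BONUSES = {
    "system": 0.2,
    "task": 0.15,
    "tool": 0.1,
    "knowledge": 0.0,
    "conversation": 0.0,
}


def priority_score(base: float, context_type: str) -> float:
    # Add the per-type bonus, then clamp the result into [0.0, 1.0].
    return min(1.0, max(0.0, base + TYPE_BONUSES.get(context_type, 0.0)))


print(priority_score(0.5, "task"))    # 0.65
print(priority_score(0.9, "system"))  # 1.0 (clamped)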

View File

@@ -85,7 +85,10 @@ class RelevanceScorer(BaseScorer):
Relevance score between 0.0 and 1.0
"""
# 1. Check for pre-computed relevance score
if isinstance(context, KnowledgeContext) and context.relevance_score is not None:
if (
isinstance(context, KnowledgeContext)
and context.relevance_score is not None
):
return self.normalize_score(context.relevance_score)
# 2. Check metadata for score
@@ -95,14 +98,19 @@ class RelevanceScorer(BaseScorer):
if "score" in context.metadata:
return self.normalize_score(context.metadata["score"])
# 3. Try MCP-based semantic similarity
# 3. Try MCP-based semantic similarity (if compute_similarity tool is available)
# Note: This requires the knowledge-base MCP server to implement compute_similarity
if self._mcp is not None:
try:
score = await self._compute_semantic_similarity(context, query)
if score is not None:
return score
except Exception as e:
logger.debug(f"Semantic scoring failed, using fallback: {e}")
# Log at debug level since this is expected if compute_similarity
# tool is not implemented in the Knowledge Base server
logger.debug(
f"Semantic scoring unavailable, using keyword fallback: {e}"
)
# 4. Fall back to keyword matching
return self._compute_keyword_score(context, query)
@@ -122,6 +130,9 @@ class RelevanceScorer(BaseScorer):
Returns:
Similarity score or None if unavailable
"""
if self._mcp is None:
return None
try:
# Use Knowledge Base's search capability to compute similarity
result = await self._mcp.call_tool(
@@ -129,7 +140,9 @@ class RelevanceScorer(BaseScorer):
tool="compute_similarity",
args={
"text1": query,
"text2": context.content[: self._semantic_max_chars], # Limit content length
"text2": context.content[
: self._semantic_max_chars
], # Limit content length
},
)
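When neither a pre-computed score nor the `compute_similarity` tool is available, scoring falls back to keyword matching; a minimal sketch of such a fallback (the real `_compute_keyword_score` is not shown in this diff, so this overlap measure is an assumption):

def keyword_score(content: str, query: str) -> float:
    # Fraction of query words that also appear in the content (crude lexical overlap).
    query_words = {w for w in query.lower().split() if w}
    if not query_words:
        return 0.0
    content_words = set(content.lower().split())
    return len(query_words & content_words) / len(query_words)


print(keyword_score("FastAPI handles async endpoints", "async fastapi"))  # 1.0
print(keyword_score("Unrelated release notes", "async fastapi"))          # 0.0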

View File

@@ -27,23 +27,17 @@ from .tool import (
)
__all__ = [
# Base types
"AssembledContext",
"BaseContext",
"ContextPriority",
"ContextType",
# Conversation
"ConversationContext",
"MessageRole",
# Knowledge
"KnowledgeContext",
# System
"MessageRole",
"SystemContext",
# Task
"TaskComplexity",
"TaskContext",
"TaskStatus",
# Tool
"ToolContext",
"ToolResultStatus",
]

View File

@@ -120,7 +120,16 @@ class KnowledgeContext(BaseContext):
def is_code(self) -> bool:
"""Check if this is code content."""
code_types = {"python", "javascript", "typescript", "go", "rust", "java", "c", "cpp"}
code_types = {
"python",
"javascript",
"typescript",
"go",
"rust",
"java",
"c",
"cpp",
}
return self.file_type is not None and self.file_type.lower() in code_types
def is_documentation(self) -> bool:

View File

@@ -56,7 +56,9 @@ class ToolContext(BaseContext):
"tool_name": self.tool_name,
"tool_description": self.tool_description,
"is_result": self.is_result,
"result_status": self.result_status.value if self.result_status else None,
"result_status": self.result_status.value
if self.result_status
else None,
"execution_time_ms": self.execution_time_ms,
"parameters": self.parameters,
"server_name": self.server_name,
@@ -174,7 +176,9 @@ class ToolContext(BaseContext):
return cls(
content=content,
source=f"tool_result:{server_name}:{tool_name}" if server_name else f"tool_result:{tool_name}",
source=f"tool_result:{server_name}:{tool_name}"
if server_name
else f"tool_result:{tool_name}",
tool_name=tool_name,
is_result=True,
result_status=status,

View File

@@ -1,11 +1,8 @@
"""Tests for model adapters."""
import pytest
from app.services.context.adapters import (
ClaudeAdapter,
DefaultAdapter,
ModelAdapter,
OpenAIAdapter,
get_adapter,
)

View File

@@ -5,10 +5,9 @@ from datetime import UTC, datetime
import pytest
from app.services.context.assembly import ContextPipeline, PipelineMetrics
from app.services.context.budget import BudgetAllocator, TokenBudget
from app.services.context.budget import TokenBudget
from app.services.context.types import (
AssembledContext,
ContextType,
ConversationContext,
KnowledgeContext,
MessageRole,
@@ -354,7 +353,10 @@ class TestContextPipelineFormatting:
if result.context_count > 0:
assert "<conversation_history>" in result.content
assert '<message role="user">' in result.content or 'role="user"' in result.content
assert (
'<message role="user">' in result.content
or 'role="user"' in result.content
)
@pytest.mark.asyncio
async def test_format_tool_results(self) -> None:
@@ -474,6 +476,10 @@ class TestContextPipelineIntegration:
assert system_pos < task_pos
if task_pos >= 0 and knowledge_pos >= 0:
assert task_pos < knowledge_pos
if knowledge_pos >= 0 and conversation_pos >= 0:
assert knowledge_pos < conversation_pos
if conversation_pos >= 0 and tool_pos >= 0:
assert conversation_pos < tool_pos
@pytest.mark.asyncio
async def test_excluded_contexts_tracked(self) -> None:

View File

@@ -2,16 +2,15 @@
import pytest
from app.services.context.budget import BudgetAllocator
from app.services.context.compression import (
ContextCompressor,
TruncationResult,
TruncationStrategy,
)
from app.services.context.budget import BudgetAllocator, TokenBudget
from app.services.context.types import (
ContextType,
KnowledgeContext,
SystemContext,
TaskContext,
)
@@ -113,7 +112,7 @@ class TestTruncationStrategy:
assert result.truncated is True
assert len(result.content) < len(content)
assert strategy.TRUNCATION_MARKER in result.content
assert strategy.truncation_marker in result.content
@pytest.mark.asyncio
async def test_truncate_middle_strategy(self) -> None:
@@ -126,7 +125,7 @@ class TestTruncationStrategy:
)
assert result.truncated is True
assert strategy.TRUNCATION_MARKER in result.content
assert strategy.truncation_marker in result.content
@pytest.mark.asyncio
async def test_truncate_sentence_strategy(self) -> None:
@@ -140,7 +139,9 @@ class TestTruncationStrategy:
assert result.truncated is True
# Should cut at sentence boundary
assert result.content.endswith(".") or strategy.TRUNCATION_MARKER in result.content
assert (
result.content.endswith(".") or strategy.truncation_marker in result.content
)
class TestContextCompressor:
@@ -235,10 +236,12 @@ class TestTruncationEdgeCases:
content = "Some content to truncate"
# max_tokens less than marker tokens should return just marker
result = await strategy.truncate_to_tokens(content, max_tokens=1, strategy="end")
result = await strategy.truncate_to_tokens(
content, max_tokens=1, strategy="end"
)
# Should handle gracefully without crashing
assert strategy.TRUNCATION_MARKER in result.content or result.content == content
assert strategy.truncation_marker in result.content or result.content == content
@pytest.mark.asyncio
async def test_truncate_with_content_that_has_zero_tokens(self) -> None:
@@ -249,7 +252,7 @@ class TestTruncationEdgeCases:
result = await strategy.truncate_to_tokens("a", max_tokens=100)
# Should not raise ZeroDivisionError
assert result.content in ("a", "a" + strategy.TRUNCATION_MARKER)
assert result.content in ("a", "a" + strategy.truncation_marker)
@pytest.mark.asyncio
async def test_get_content_for_tokens_zero_target(self) -> None:

View File

@@ -11,8 +11,6 @@ from app.services.context.types import (
ConversationContext,
KnowledgeContext,
MessageRole,
SystemContext,
TaskContext,
ToolContext,
)

View File

@@ -1,7 +1,5 @@
"""Tests for context management exceptions."""
import pytest
from app.services.context.exceptions import (
AssemblyTimeoutError,
BudgetExceededError,

View File

@@ -1,7 +1,5 @@
"""Tests for context ranking module."""
from datetime import UTC, datetime
import pytest
from app.services.context.budget import BudgetAllocator, TokenBudget
@@ -230,9 +228,7 @@ class TestContextRanker:
),
]
result = await ranker.rank(
contexts, "query", budget, ensure_required=False
)
result = await ranker.rank(contexts, "query", budget, ensure_required=False)
# Without ensure_required, CRITICAL contexts can be excluded
# if budget doesn't allow
@@ -246,12 +242,8 @@ class TestContextRanker:
budget = allocator.create_budget(10000)
contexts = [
KnowledgeContext(
content="Knowledge 1", source="docs", relevance_score=0.8
),
KnowledgeContext(
content="Knowledge 2", source="docs", relevance_score=0.6
),
KnowledgeContext(content="Knowledge 1", source="docs", relevance_score=0.8),
KnowledgeContext(content="Knowledge 2", source="docs", relevance_score=0.6),
TaskContext(content="Task", source="task"),
]

View File

@@ -6,7 +6,6 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
from app.services.context.scoring import (
BaseScorer,
CompositeScorer,
PriorityScorer,
RecencyScorer,
@@ -149,15 +148,9 @@ class TestRelevanceScorer:
scorer = RelevanceScorer()
contexts = [
KnowledgeContext(
content="Python", source="1", relevance_score=0.8
),
KnowledgeContext(
content="Java", source="2", relevance_score=0.6
),
KnowledgeContext(
content="Go", source="3", relevance_score=0.9
),
KnowledgeContext(content="Python", source="1", relevance_score=0.8),
KnowledgeContext(content="Java", source="2", relevance_score=0.6),
KnowledgeContext(content="Go", source="3", relevance_score=0.9),
]
scores = await scorer.score_batch(contexts, "test")
@@ -263,7 +256,9 @@ class TestRecencyScorer:
)
conv_score = await scorer.score(conv_context, "query", reference_time=now)
knowledge_score = await scorer.score(knowledge_context, "query", reference_time=now)
knowledge_score = await scorer.score(
knowledge_context, "query", reference_time=now
)
# Conversation should decay much faster
assert conv_score < knowledge_score
@@ -301,12 +296,8 @@ class TestRecencyScorer:
contexts = [
TaskContext(content="1", source="t", timestamp=now),
TaskContext(
content="2", source="t", timestamp=now - timedelta(hours=24)
),
TaskContext(
content="3", source="t", timestamp=now - timedelta(hours=48)
),
TaskContext(content="2", source="t", timestamp=now - timedelta(hours=24)),
TaskContext(content="3", source="t", timestamp=now - timedelta(hours=48)),
]
scores = await scorer.score_batch(contexts, "query", reference_time=now)
@@ -508,8 +499,12 @@ class TestCompositeScorer:
assert scored.priority_score > 0.5 # HIGH priority
@pytest.mark.asyncio
async def test_score_cached_on_context(self) -> None:
"""Test that score is cached on the context."""
async def test_score_not_cached_on_context(self) -> None:
"""Test that scores are NOT cached on the context.
Scores should not be cached on the context because they are query-dependent.
Different queries would get incorrect cached scores if we cached on the context.
"""
scorer = CompositeScorer()
context = KnowledgeContext(
@@ -518,14 +513,18 @@ class TestCompositeScorer:
relevance_score=0.5,
)
# First scoring
# After scoring, context._score should remain None
# (we don't cache on context because scores are query-dependent)
await scorer.score(context, "query")
assert context._score is not None
# The scorer should compute fresh scores each time
# rather than caching on the context object
# Second scoring should use cached value
context._score = 0.999 # Set to a known value
score2 = await scorer.score(context, "query")
assert score2 == 0.999
# Score again with different query - should compute fresh score
score1 = await scorer.score(context, "query 1")
score2 = await scorer.score(context, "query 2")
# Both should be valid scores (not necessarily equal since queries differ)
assert 0.0 <= score1 <= 1.0
assert 0.0 <= score2 <= 1.0
@pytest.mark.asyncio
async def test_score_batch(self) -> None:
@@ -555,15 +554,9 @@ class TestCompositeScorer:
scorer = CompositeScorer()
contexts = [
KnowledgeContext(
content="Low", source="docs", relevance_score=0.2
),
KnowledgeContext(
content="High", source="docs", relevance_score=0.9
),
KnowledgeContext(
content="Medium", source="docs", relevance_score=0.5
),
KnowledgeContext(content="Low", source="docs", relevance_score=0.2),
KnowledgeContext(content="High", source="docs", relevance_score=0.9),
KnowledgeContext(content="Medium", source="docs", relevance_score=0.5),
]
ranked = await scorer.rank(contexts, "query")
@@ -580,9 +573,7 @@ class TestCompositeScorer:
scorer = CompositeScorer()
contexts = [
KnowledgeContext(
content=str(i), source="docs", relevance_score=i / 10
)
KnowledgeContext(content=str(i), source="docs", relevance_score=i / 10)
for i in range(10)
]
@@ -595,12 +586,8 @@ class TestCompositeScorer:
scorer = CompositeScorer()
contexts = [
KnowledgeContext(
content="Low", source="docs", relevance_score=0.1
),
KnowledgeContext(
content="High", source="docs", relevance_score=0.9
),
KnowledgeContext(content="Low", source="docs", relevance_score=0.1),
KnowledgeContext(content="High", source="docs", relevance_score=0.9),
]
ranked = await scorer.rank(contexts, "query", min_score=0.5)
@@ -625,7 +612,13 @@ class TestCompositeScorer:
"""
import asyncio
scorer = CompositeScorer()
# Use scorer with recency_weight=0 to eliminate time-dependent variation
# (recency scores change as time passes between calls)
scorer = CompositeScorer(
relevance_weight=0.5,
recency_weight=0.0, # Disable recency to get deterministic results
priority_weight=0.5,
)
# Create a single context that will be scored multiple times concurrently
context = KnowledgeContext(
@@ -639,11 +632,9 @@ class TestCompositeScorer:
tasks = [scorer.score(context, "test query") for _ in range(num_concurrent)]
scores = await asyncio.gather(*tasks)
# All scores should be identical (the same context scored the same way)
# All scores should be identical (deterministic scoring without recency)
assert all(s == scores[0] for s in scores)
# The context should have its _score cached
assert context._score is not None
# Note: We don't cache _score on context because scores are query-dependent
@pytest.mark.asyncio
async def test_concurrent_scoring_different_contexts(self) -> None:
@@ -671,10 +662,7 @@ class TestCompositeScorer:
# Each context should have a different score based on its relevance
assert len(set(scores)) > 1 # Not all the same
# All contexts should have cached scores
for ctx in contexts:
assert ctx._score is not None
# Note: We don't cache _score on context because scores are query-dependent
class TestScoredContext:

View File

@@ -1,20 +1,17 @@
"""Tests for context types."""
import json
from datetime import UTC, datetime, timedelta
import pytest
from app.services.context.types import (
AssembledContext,
BaseContext,
ContextPriority,
ContextType,
ConversationContext,
KnowledgeContext,
MessageRole,
SystemContext,
TaskComplexity,
TaskContext,
TaskStatus,
ToolContext,
@@ -181,24 +178,16 @@ class TestKnowledgeContext:
def test_is_code(self) -> None:
"""Test is_code method."""
code_ctx = KnowledgeContext(
content="code", source="test", file_type="python"
)
doc_ctx = KnowledgeContext(
content="docs", source="test", file_type="markdown"
)
code_ctx = KnowledgeContext(content="code", source="test", file_type="python")
doc_ctx = KnowledgeContext(content="docs", source="test", file_type="markdown")
assert code_ctx.is_code() is True
assert doc_ctx.is_code() is False
def test_is_documentation(self) -> None:
"""Test is_documentation method."""
doc_ctx = KnowledgeContext(
content="docs", source="test", file_type="markdown"
)
code_ctx = KnowledgeContext(
content="code", source="test", file_type="python"
)
doc_ctx = KnowledgeContext(content="docs", source="test", file_type="markdown")
code_ctx = KnowledgeContext(content="code", source="test", file_type="python")
assert doc_ctx.is_documentation() is True
assert code_ctx.is_documentation() is False
@@ -333,15 +322,11 @@ class TestTaskContext:
def test_status_checks(self) -> None:
"""Test status check methods."""
pending = TaskContext(
content="test", source="test", status=TaskStatus.PENDING
)
pending = TaskContext(content="test", source="test", status=TaskStatus.PENDING)
completed = TaskContext(
content="test", source="test", status=TaskStatus.COMPLETED
)
blocked = TaskContext(
content="test", source="test", status=TaskStatus.BLOCKED
)
blocked = TaskContext(content="test", source="test", status=TaskStatus.BLOCKED)
assert pending.is_active() is True
assert completed.is_complete() is True
@@ -395,12 +380,8 @@ class TestToolContext:
def test_is_successful(self) -> None:
"""Test is_successful method."""
success = ToolContext.from_tool_result(
"test", "ok", ToolResultStatus.SUCCESS
)
error = ToolContext.from_tool_result(
"test", "error", ToolResultStatus.ERROR
)
success = ToolContext.from_tool_result("test", "ok", ToolResultStatus.SUCCESS)
error = ToolContext.from_tool_result("test", "error", ToolResultStatus.ERROR)
assert success.is_successful() is True
assert error.is_successful() is False
@@ -510,9 +491,7 @@ class TestBaseContextMethods:
def test_get_age_seconds(self) -> None:
"""Test get_age_seconds method."""
old_time = datetime.now(UTC) - timedelta(hours=2)
ctx = SystemContext(
content="test", source="test", timestamp=old_time
)
ctx = SystemContext(content="test", source="test", timestamp=old_time)
age = ctx.get_age_seconds()
# Should be approximately 2 hours in seconds
@@ -521,9 +500,7 @@ class TestBaseContextMethods:
def test_get_age_hours(self) -> None:
"""Test get_age_hours method."""
old_time = datetime.now(UTC) - timedelta(hours=5)
ctx = SystemContext(
content="test", source="test", timestamp=old_time
)
ctx = SystemContext(content="test", source="test", timestamp=old_time)
age = ctx.get_age_hours()
assert 4.9 < age < 5.1
@@ -533,12 +510,8 @@ class TestBaseContextMethods:
old_time = datetime.now(UTC) - timedelta(days=10)
new_time = datetime.now(UTC) - timedelta(hours=1)
old_ctx = SystemContext(
content="test", source="test", timestamp=old_time
)
new_ctx = SystemContext(
content="test", source="test", timestamp=new_time
)
old_ctx = SystemContext(content="test", source="test", timestamp=old_time)
new_ctx = SystemContext(content="test", source="test", timestamp=new_time)
# Default max_age is 168 hours (7 days)
assert old_ctx.is_stale() is True