diff --git a/backend/app/services/context/__init__.py b/backend/app/services/context/__init__.py
index d0bd1cb..9be69e5 100644
--- a/backend/app/services/context/__init__.py
+++ b/backend/app/services/context/__init__.py
@@ -124,71 +124,55 @@ from .types import (
)
__all__ = [
- # Adapters
- "ClaudeAdapter",
- "DefaultAdapter",
- "get_adapter",
- "ModelAdapter",
- "OpenAIAdapter",
- # Assembly
- "ContextPipeline",
- "PipelineMetrics",
- # Budget Management
- "BudgetAllocator",
- "TokenBudget",
- "TokenCalculator",
- # Cache
- "ContextCache",
- # Engine
- "ContextEngine",
- "create_context_engine",
- # Compression
- "ContextCompressor",
- "TruncationResult",
- "TruncationStrategy",
- # Configuration
- "ContextSettings",
- "get_context_settings",
- "get_default_settings",
- "reset_context_settings",
- # Exceptions
+ "AssembledContext",
"AssemblyTimeoutError",
+ "BaseContext",
+ "BaseScorer",
+ "BudgetAllocator",
"BudgetExceededError",
"CacheError",
+ "ClaudeAdapter",
+ "CompositeScorer",
"CompressionError",
+ "ContextCache",
+ "ContextCompressor",
+ "ContextEngine",
"ContextError",
"ContextNotFoundError",
+ "ContextPipeline",
+ "ContextPriority",
+ "ContextRanker",
+ "ContextSettings",
+ "ContextType",
+ "ConversationContext",
+ "DefaultAdapter",
"FormattingError",
"InvalidContextError",
- "ScoringError",
- "TokenCountError",
- # Prioritization
- "ContextRanker",
- "RankingResult",
- # Scoring
- "BaseScorer",
- "CompositeScorer",
+ "KnowledgeContext",
+ "MessageRole",
+ "ModelAdapter",
+ "OpenAIAdapter",
+ "PipelineMetrics",
"PriorityScorer",
+ "RankingResult",
"RecencyScorer",
"RelevanceScorer",
"ScoredContext",
- # Types - Base
- "AssembledContext",
- "BaseContext",
- "ContextPriority",
- "ContextType",
- # Types - Conversation
- "ConversationContext",
- "MessageRole",
- # Types - Knowledge
- "KnowledgeContext",
- # Types - System
+ "ScoringError",
"SystemContext",
- # Types - Task
"TaskComplexity",
"TaskContext",
"TaskStatus",
- # Types - Tool
+ "TokenBudget",
+ "TokenCalculator",
+ "TokenCountError",
"ToolContext",
"ToolResultStatus",
+ "TruncationResult",
+ "TruncationStrategy",
+ "create_context_engine",
+ "get_adapter",
+ "get_context_settings",
+ "get_default_settings",
+ "reset_context_settings",
]
diff --git a/backend/app/services/context/adapters/base.py b/backend/app/services/context/adapters/base.py
index cd0d6a0..967ac11 100644
--- a/backend/app/services/context/adapters/base.py
+++ b/backend/app/services/context/adapters/base.py
@@ -5,7 +5,7 @@ Abstract base class for model-specific context formatting.
"""
from abc import ABC, abstractmethod
-from typing import Any
+from typing import Any, ClassVar
from ..types import BaseContext, ContextType
@@ -19,7 +19,7 @@ class ModelAdapter(ABC):
"""
# Model name patterns this adapter handles
- MODEL_PATTERNS: list[str] = []
+ MODEL_PATTERNS: ClassVar[list[str]] = []
@classmethod
def matches_model(cls, model: str) -> bool:
@@ -125,7 +125,7 @@ class DefaultAdapter(ModelAdapter):
Uses simple plain-text formatting with minimal structure.
"""
- MODEL_PATTERNS: list[str] = [] # Fallback adapter
+ MODEL_PATTERNS: ClassVar[list[str]] = [] # Fallback adapter
@classmethod
def matches_model(cls, model: str) -> bool:
diff --git a/backend/app/services/context/adapters/claude.py b/backend/app/services/context/adapters/claude.py
index 0c0e253..2fc1a4e 100644
--- a/backend/app/services/context/adapters/claude.py
+++ b/backend/app/services/context/adapters/claude.py
@@ -5,7 +5,7 @@ Provides Claude-specific context formatting using XML tags
which Claude models understand natively.
"""
-from typing import Any
+from typing import Any, ClassVar
from ..types import BaseContext, ContextType
from .base import ModelAdapter
@@ -25,7 +25,7 @@ class ClaudeAdapter(ModelAdapter):
- Tool result wrapping with tool names
"""
- MODEL_PATTERNS: list[str] = ["claude", "anthropic"]
+ MODEL_PATTERNS: ClassVar[list[str]] = ["claude", "anthropic"]
def format(
self,
diff --git a/backend/app/services/context/adapters/openai.py b/backend/app/services/context/adapters/openai.py
index 40304b7..dd6ffa6 100644
--- a/backend/app/services/context/adapters/openai.py
+++ b/backend/app/services/context/adapters/openai.py
@@ -5,7 +5,7 @@ Provides OpenAI-specific context formatting using markdown
which GPT models understand well.
"""
-from typing import Any
+from typing import Any, ClassVar
from ..types import BaseContext, ContextType
from .base import ModelAdapter
@@ -25,7 +25,7 @@ class OpenAIAdapter(ModelAdapter):
- Code blocks for tool outputs
"""
- MODEL_PATTERNS: list[str] = ["gpt", "openai", "o1", "o3"]
+ MODEL_PATTERNS: ClassVar[list[str]] = ["gpt", "openai", "o1", "o3"]
def format(
self,
diff --git a/backend/app/services/context/assembly/pipeline.py b/backend/app/services/context/assembly/pipeline.py
index 2003cec..af1c8cf 100644
--- a/backend/app/services/context/assembly/pipeline.py
+++ b/backend/app/services/context/assembly/pipeline.py
@@ -102,9 +102,7 @@ class ContextPipeline:
self._ranker = ranker or ContextRanker(
scorer=self._scorer, calculator=self._calculator
)
- self._compressor = compressor or ContextCompressor(
- calculator=self._calculator
- )
+ self._compressor = compressor or ContextCompressor(calculator=self._calculator)
self._allocator = BudgetAllocator(self._settings)
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
@@ -336,27 +334,21 @@ class ContextPipeline:
return "\n".join(c.content for c in contexts)
- def _format_system(
- self, contexts: list[BaseContext], use_xml: bool
- ) -> str:
+ def _format_system(self, contexts: list[BaseContext], use_xml: bool) -> str:
"""Format system contexts."""
content = "\n\n".join(c.content for c in contexts)
if use_xml:
return f"\n{content}\n"
return content
- def _format_task(
- self, contexts: list[BaseContext], use_xml: bool
- ) -> str:
+ def _format_task(self, contexts: list[BaseContext], use_xml: bool) -> str:
"""Format task contexts."""
content = "\n\n".join(c.content for c in contexts)
if use_xml:
return f"\n{content}\n"
return f"## Current Task\n\n{content}"
- def _format_knowledge(
- self, contexts: list[BaseContext], use_xml: bool
- ) -> str:
+ def _format_knowledge(self, contexts: list[BaseContext], use_xml: bool) -> str:
"""Format knowledge contexts."""
if use_xml:
parts = [""]
@@ -374,9 +366,7 @@ class ContextPipeline:
parts.append("")
return "\n".join(parts)
- def _format_conversation(
- self, contexts: list[BaseContext], use_xml: bool
- ) -> str:
+ def _format_conversation(self, contexts: list[BaseContext], use_xml: bool) -> str:
"""Format conversation contexts."""
if use_xml:
parts = [""]
@@ -394,9 +384,7 @@ class ContextPipeline:
parts.append(f"**{role.upper()}**: {ctx.content}")
return "\n\n".join(parts)
- def _format_tool(
- self, contexts: list[BaseContext], use_xml: bool
- ) -> str:
+ def _format_tool(self, contexts: list[BaseContext], use_xml: bool) -> str:
"""Format tool contexts."""
if use_xml:
parts = [""]
diff --git a/backend/app/services/context/budget/allocator.py b/backend/app/services/context/budget/allocator.py
index 00e5cc9..ee33894 100644
--- a/backend/app/services/context/budget/allocator.py
+++ b/backend/app/services/context/budget/allocator.py
@@ -215,9 +215,7 @@ class TokenBudget:
"buffer": self.buffer,
},
"used": dict(self.used),
- "remaining": {
- ct.value: self.remaining(ct) for ct in ContextType
- },
+ "remaining": {ct.value: self.remaining(ct) for ct in ContextType},
"total_used": self.total_used(),
"total_remaining": self.total_remaining(),
"utilization": round(self.utilization(), 3),
@@ -348,13 +346,11 @@ class BudgetAllocator:
# Calculate total reclaimable (excluding prioritized types)
prioritize_values = {ct.value for ct in prioritize}
reclaimable = sum(
- tokens for ct, tokens in unused.items()
- if ct not in prioritize_values
+ tokens for ct, tokens in unused.items() if ct not in prioritize_values
)
# Redistribute to prioritized types that are near capacity
for ct in prioritize:
- ct_value = ct.value
utilization = budget.utilization(ct)
if utilization > 0.8: # Near capacity
diff --git a/backend/app/services/context/budget/calculator.py b/backend/app/services/context/budget/calculator.py
index 3ad7b75..356271f 100644
--- a/backend/app/services/context/budget/calculator.py
+++ b/backend/app/services/context/budget/calculator.py
@@ -7,7 +7,7 @@ Integrates with LLM Gateway for accurate counts.
import hashlib
import logging
-from typing import TYPE_CHECKING, Any, Protocol
+from typing import TYPE_CHECKING, Any, ClassVar, Protocol
if TYPE_CHECKING:
from app.services.mcp.client_manager import MCPClientManager
@@ -42,10 +42,10 @@ class TokenCalculator:
"""
# Default characters per token ratio for estimation
- DEFAULT_CHARS_PER_TOKEN = 4.0
+ DEFAULT_CHARS_PER_TOKEN: ClassVar[float] = 4.0
# Model-specific ratios (more accurate estimation)
- MODEL_CHAR_RATIOS: dict[str, float] = {
+ MODEL_CHAR_RATIOS: ClassVar[dict[str, float]] = {
"claude": 3.5,
"gpt-4": 4.0,
"gpt-3.5": 4.0,
diff --git a/backend/app/services/context/cache/context_cache.py b/backend/app/services/context/cache/context_cache.py
index 6549dbd..7b26132 100644
--- a/backend/app/services/context/cache/context_cache.py
+++ b/backend/app/services/context/cache/context_cache.py
@@ -116,12 +116,16 @@ class ContextCache:
# This avoids JSON serializing potentially large content strings
context_data = []
for ctx in contexts:
- context_data.append({
- "type": ctx.get_type().value,
- "content_hash": self._hash_content(ctx.content), # Hash instead of full content
- "source": ctx.source,
- "priority": ctx.priority, # Already an int
- })
+ context_data.append(
+ {
+ "type": ctx.get_type().value,
+ "content_hash": self._hash_content(
+ ctx.content
+ ), # Hash instead of full content
+ "source": ctx.source,
+ "priority": ctx.priority, # Already an int
+ }
+ )
data = {
"contexts": context_data,
@@ -412,7 +416,7 @@ class ContextCache:
# Get Redis info
info = await self._redis.info("memory") # type: ignore
stats["redis_memory_used"] = info.get("used_memory_human", "unknown")
- except Exception:
- pass
+ except Exception as e:
+ logger.debug(f"Failed to get Redis stats: {e}")
return stats
diff --git a/backend/app/services/context/compression/truncation.py b/backend/app/services/context/compression/truncation.py
index 50afecd..058a894 100644
--- a/backend/app/services/context/compression/truncation.py
+++ b/backend/app/services/context/compression/truncation.py
@@ -78,7 +78,7 @@ class TruncationStrategy:
)
@property
- def TRUNCATION_MARKER(self) -> str:
+ def truncation_marker(self) -> str:
"""Get truncation marker from settings."""
return self._settings.truncation_marker
@@ -141,7 +141,9 @@ class TruncationStrategy:
truncated_tokens=truncated_tokens,
content=truncated,
truncated=True,
- truncation_ratio=0.0 if original_tokens == 0 else 1 - (truncated_tokens / original_tokens),
+ truncation_ratio=0.0
+ if original_tokens == 0
+ else 1 - (truncated_tokens / original_tokens),
)
async def _truncate_end(
@@ -156,17 +158,17 @@ class TruncationStrategy:
Simple but effective for most content types.
"""
# Binary search for optimal truncation point
- marker_tokens = await self._count_tokens(self.TRUNCATION_MARKER, model)
+ marker_tokens = await self._count_tokens(self.truncation_marker, model)
available_tokens = max(0, max_tokens - marker_tokens)
# Edge case: if no tokens available for content, return just the marker
if available_tokens <= 0:
- return self.TRUNCATION_MARKER
+ return self.truncation_marker
# Estimate characters per token (guard against division by zero)
content_tokens = await self._count_tokens(content, model)
if content_tokens == 0:
- return content + self.TRUNCATION_MARKER
+ return content + self.truncation_marker
chars_per_token = len(content) / content_tokens
# Start with estimated position
@@ -188,7 +190,7 @@ class TruncationStrategy:
else:
high = mid - 1
- return best + self.TRUNCATION_MARKER
+ return best + self.truncation_marker
async def _truncate_middle(
self,
@@ -201,7 +203,7 @@ class TruncationStrategy:
Good for code or content where context at boundaries matters.
"""
- marker_tokens = await self._count_tokens(self.TRUNCATION_MARKER, model)
+ marker_tokens = await self._count_tokens(self.truncation_marker, model)
available_tokens = max_tokens - marker_tokens
# Split between start and end
@@ -218,7 +220,7 @@ class TruncationStrategy:
content, end_tokens, from_start=False, model=model
)
- return start_content + self.TRUNCATION_MARKER + end_content
+ return start_content + self.truncation_marker + end_content
async def _truncate_sentence(
self,
@@ -236,7 +238,7 @@ class TruncationStrategy:
result: list[str] = []
total_tokens = 0
- marker_tokens = await self._count_tokens(self.TRUNCATION_MARKER, model)
+ marker_tokens = await self._count_tokens(self.truncation_marker, model)
available = max_tokens - marker_tokens
for sentence in sentences:
@@ -248,7 +250,7 @@ class TruncationStrategy:
break
if len(result) < len(sentences):
- return " ".join(result) + self.TRUNCATION_MARKER
+ return " ".join(result) + self.truncation_marker
return " ".join(result)
async def _get_content_for_tokens(
diff --git a/backend/app/services/context/engine.py b/backend/app/services/context/engine.py
index 33618a2..39a190a 100644
--- a/backend/app/services/context/engine.py
+++ b/backend/app/services/context/engine.py
@@ -78,12 +78,8 @@ class ContextEngine:
# Initialize components
self._calculator = TokenCalculator(mcp_manager=mcp_manager)
- self._scorer = CompositeScorer(
- mcp_manager=mcp_manager, settings=self._settings
- )
- self._ranker = ContextRanker(
- scorer=self._scorer, calculator=self._calculator
- )
+ self._scorer = CompositeScorer(mcp_manager=mcp_manager, settings=self._settings)
+ self._ranker = ContextRanker(scorer=self._scorer, calculator=self._calculator)
self._compressor = ContextCompressor(calculator=self._calculator)
self._allocator = BudgetAllocator(self._settings)
self._cache = ContextCache(redis=redis, settings=self._settings)
@@ -274,8 +270,19 @@ class ContextEngine:
},
)
+ # Check both ToolResult.success AND response success
+ if not result.success:
+ logger.warning(f"Knowledge search failed: {result.error}")
+ return []
+
+ if not isinstance(result.data, dict) or not result.data.get(
+ "success", True
+ ):
+ logger.warning("Knowledge search returned unsuccessful response")
+ return []
+
contexts = []
- results = result.data.get("results", []) if isinstance(result.data, dict) else []
+ results = result.data.get("results", [])
for chunk in results:
contexts.append(
KnowledgeContext(
@@ -283,7 +290,9 @@ class ContextEngine:
source=chunk.get("source_path", "unknown"),
relevance_score=chunk.get("score", 0.0),
metadata={
- "chunk_id": chunk.get("chunk_id"),
+ "chunk_id": chunk.get(
+ "id"
+ ), # Server returns 'id' not 'chunk_id'
"document_id": chunk.get("document_id"),
},
)
@@ -312,7 +321,9 @@ class ContextEngine:
contexts = []
for i, turn in enumerate(history):
role_str = turn.get("role", "user").lower()
- role = MessageRole.ASSISTANT if role_str == "assistant" else MessageRole.USER
+ role = (
+ MessageRole.ASSISTANT if role_str == "assistant" else MessageRole.USER
+ )
contexts.append(
ConversationContext(
@@ -346,6 +357,7 @@ class ContextEngine:
# Handle dict content
if isinstance(content, dict):
import json
+
content = json.dumps(content, indent=2)
contexts.append(
diff --git a/backend/app/services/context/exceptions.py b/backend/app/services/context/exceptions.py
index 18f7910..5ae1233 100644
--- a/backend/app/services/context/exceptions.py
+++ b/backend/app/services/context/exceptions.py
@@ -61,7 +61,7 @@ class BudgetExceededError(ContextError):
requested: Tokens requested
context_type: Type of context that exceeded budget
"""
- details = {
+ details: dict[str, Any] = {
"allocated": allocated,
"requested": requested,
"overage": requested - allocated,
@@ -170,7 +170,7 @@ class AssemblyTimeoutError(ContextError):
elapsed_ms: Actual elapsed time in milliseconds
stage: Pipeline stage where timeout occurred
"""
- details = {
+ details: dict[str, Any] = {
"timeout_ms": timeout_ms,
"elapsed_ms": round(elapsed_ms, 2),
}
diff --git a/backend/app/services/context/prioritization/ranker.py b/backend/app/services/context/prioritization/ranker.py
index fd2e812..b475b6c 100644
--- a/backend/app/services/context/prioritization/ranker.py
+++ b/backend/app/services/context/prioritization/ranker.py
@@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, Any
from ..budget import TokenBudget, TokenCalculator
from ..config import ContextSettings, get_context_settings
from ..scoring.composite import CompositeScorer, ScoredContext
-from ..types import BaseContext
+from ..types import BaseContext, ContextPriority
if TYPE_CHECKING:
pass
@@ -111,8 +111,8 @@ class ContextRanker:
if ensure_required:
for sc in scored_contexts:
- # CRITICAL priority (100) contexts are always included
- if sc.context.priority >= 100:
+ # CRITICAL priority (150) contexts are always included
+ if sc.context.priority >= ContextPriority.CRITICAL.value:
required.append(sc)
else:
optional.append(sc)
@@ -239,9 +239,7 @@ class ContextRanker:
import asyncio
# Find contexts needing counts
- contexts_needing_counts = [
- ctx for ctx in contexts if ctx.token_count is None
- ]
+ contexts_needing_counts = [ctx for ctx in contexts if ctx.token_count is None]
if not contexts_needing_counts:
return
@@ -254,7 +252,7 @@ class ContextRanker:
counts = await asyncio.gather(*tasks)
# Assign counts back
- for ctx, count in zip(contexts_needing_counts, counts):
+ for ctx, count in zip(contexts_needing_counts, counts, strict=True):
ctx.token_count = count
def _count_by_type(
diff --git a/backend/app/services/context/scoring/composite.py b/backend/app/services/context/scoring/composite.py
index 9e4cc8e..a75ebf6 100644
--- a/backend/app/services/context/scoring/composite.py
+++ b/backend/app/services/context/scoring/composite.py
@@ -92,7 +92,9 @@ class CompositeScorer:
# Per-context locks to prevent race conditions during parallel scoring
# Uses WeakValueDictionary so locks are garbage collected when not in use
- self._context_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
+ self._context_locks: WeakValueDictionary[str, asyncio.Lock] = (
+ WeakValueDictionary()
+ )
self._locks_lock = asyncio.Lock() # Lock to protect _context_locks access
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
@@ -207,17 +209,14 @@ class CompositeScorer:
ScoredContext with all scores
"""
# Get lock for this specific context to prevent race conditions
+ # within concurrent scoring operations for the same query
context_lock = await self._get_context_lock(context.id)
async with context_lock:
- # Check if context already has a score (inside lock to prevent races)
- if context._score is not None:
- return ScoredContext(
- context=context,
- composite_score=context._score,
- )
-
# Compute individual scores in parallel
+ # Note: We do NOT cache scores on the context because scores are
+ # query-dependent. Caching without considering the query would
+ # return incorrect scores for different queries.
relevance_task = self._relevance_scorer.score(context, query, **kwargs)
recency_task = self._recency_scorer.score(context, query, **kwargs)
priority_task = self._priority_scorer.score(context, query, **kwargs)
@@ -240,9 +239,6 @@ class CompositeScorer:
else:
composite = 0.0
- # Cache the score on the context (now safe - inside lock)
- context._score = composite
-
return ScoredContext(
context=context,
composite_score=composite,
@@ -271,9 +267,7 @@ class CompositeScorer:
List of ScoredContext (same order as input)
"""
if parallel:
- tasks = [
- self.score_with_details(ctx, query, **kwargs) for ctx in contexts
- ]
+ tasks = [self.score_with_details(ctx, query, **kwargs) for ctx in contexts]
return await asyncio.gather(*tasks)
else:
results = []
diff --git a/backend/app/services/context/scoring/priority.py b/backend/app/services/context/scoring/priority.py
index 1f8e2c6..1d26ab6 100644
--- a/backend/app/services/context/scoring/priority.py
+++ b/backend/app/services/context/scoring/priority.py
@@ -4,7 +4,7 @@ Priority Scorer for Context Management.
Scores context based on assigned priority levels.
"""
-from typing import Any
+from typing import Any, ClassVar
from ..types import BaseContext, ContextType
from .base import BaseScorer
@@ -19,11 +19,11 @@ class PriorityScorer(BaseScorer):
"""
# Default priority bonuses by context type
- DEFAULT_TYPE_BONUSES: dict[ContextType, float] = {
- ContextType.SYSTEM: 0.2, # System prompts get a boost
- ContextType.TASK: 0.15, # Current task is important
- ContextType.TOOL: 0.1, # Recent tool results matter
- ContextType.KNOWLEDGE: 0.0, # Knowledge scored by relevance
+ DEFAULT_TYPE_BONUSES: ClassVar[dict[ContextType, float]] = {
+ ContextType.SYSTEM: 0.2, # System prompts get a boost
+ ContextType.TASK: 0.15, # Current task is important
+ ContextType.TOOL: 0.1, # Recent tool results matter
+ ContextType.KNOWLEDGE: 0.0, # Knowledge scored by relevance
ContextType.CONVERSATION: 0.0, # Conversation scored by recency
}
diff --git a/backend/app/services/context/scoring/relevance.py b/backend/app/services/context/scoring/relevance.py
index a3a66f7..ac57ccc 100644
--- a/backend/app/services/context/scoring/relevance.py
+++ b/backend/app/services/context/scoring/relevance.py
@@ -85,7 +85,10 @@ class RelevanceScorer(BaseScorer):
Relevance score between 0.0 and 1.0
"""
# 1. Check for pre-computed relevance score
- if isinstance(context, KnowledgeContext) and context.relevance_score is not None:
+ if (
+ isinstance(context, KnowledgeContext)
+ and context.relevance_score is not None
+ ):
return self.normalize_score(context.relevance_score)
# 2. Check metadata for score
@@ -95,14 +98,19 @@ class RelevanceScorer(BaseScorer):
if "score" in context.metadata:
return self.normalize_score(context.metadata["score"])
- # 3. Try MCP-based semantic similarity
+ # 3. Try MCP-based semantic similarity (if compute_similarity tool is available)
+ # Note: This requires the knowledge-base MCP server to implement compute_similarity
if self._mcp is not None:
try:
score = await self._compute_semantic_similarity(context, query)
if score is not None:
return score
except Exception as e:
- logger.debug(f"Semantic scoring failed, using fallback: {e}")
+ # Log at debug level since this is expected if compute_similarity
+ # tool is not implemented in the Knowledge Base server
+ logger.debug(
+ f"Semantic scoring unavailable, using keyword fallback: {e}"
+ )
# 4. Fall back to keyword matching
return self._compute_keyword_score(context, query)
@@ -122,6 +130,9 @@ class RelevanceScorer(BaseScorer):
Returns:
Similarity score or None if unavailable
"""
+ if self._mcp is None:
+ return None
+
try:
# Use Knowledge Base's search capability to compute similarity
result = await self._mcp.call_tool(
@@ -129,7 +140,9 @@ class RelevanceScorer(BaseScorer):
tool="compute_similarity",
args={
"text1": query,
- "text2": context.content[: self._semantic_max_chars], # Limit content length
+ "text2": context.content[
+ : self._semantic_max_chars
+ ], # Limit content length
},
)
diff --git a/backend/app/services/context/types/__init__.py b/backend/app/services/context/types/__init__.py
index d247bfb..4304025 100644
--- a/backend/app/services/context/types/__init__.py
+++ b/backend/app/services/context/types/__init__.py
@@ -27,23 +27,17 @@ from .tool import (
)
__all__ = [
- # Base types
"AssembledContext",
"BaseContext",
"ContextPriority",
"ContextType",
- # Conversation
"ConversationContext",
- "MessageRole",
- # Knowledge
"KnowledgeContext",
- # System
+ "MessageRole",
"SystemContext",
- # Task
"TaskComplexity",
"TaskContext",
"TaskStatus",
- # Tool
"ToolContext",
"ToolResultStatus",
]
diff --git a/backend/app/services/context/types/knowledge.py b/backend/app/services/context/types/knowledge.py
index 9e66819..242312e 100644
--- a/backend/app/services/context/types/knowledge.py
+++ b/backend/app/services/context/types/knowledge.py
@@ -120,7 +120,16 @@ class KnowledgeContext(BaseContext):
def is_code(self) -> bool:
"""Check if this is code content."""
- code_types = {"python", "javascript", "typescript", "go", "rust", "java", "c", "cpp"}
+ code_types = {
+ "python",
+ "javascript",
+ "typescript",
+ "go",
+ "rust",
+ "java",
+ "c",
+ "cpp",
+ }
return self.file_type is not None and self.file_type.lower() in code_types
def is_documentation(self) -> bool:
diff --git a/backend/app/services/context/types/tool.py b/backend/app/services/context/types/tool.py
index e4c1678..2d39756 100644
--- a/backend/app/services/context/types/tool.py
+++ b/backend/app/services/context/types/tool.py
@@ -56,7 +56,9 @@ class ToolContext(BaseContext):
"tool_name": self.tool_name,
"tool_description": self.tool_description,
"is_result": self.is_result,
- "result_status": self.result_status.value if self.result_status else None,
+ "result_status": self.result_status.value
+ if self.result_status
+ else None,
"execution_time_ms": self.execution_time_ms,
"parameters": self.parameters,
"server_name": self.server_name,
@@ -174,7 +176,9 @@ class ToolContext(BaseContext):
return cls(
content=content,
- source=f"tool_result:{server_name}:{tool_name}" if server_name else f"tool_result:{tool_name}",
+ source=f"tool_result:{server_name}:{tool_name}"
+ if server_name
+ else f"tool_result:{tool_name}",
tool_name=tool_name,
is_result=True,
result_status=status,
diff --git a/backend/tests/services/context/test_adapters.py b/backend/tests/services/context/test_adapters.py
index 9013d7f..fd29240 100644
--- a/backend/tests/services/context/test_adapters.py
+++ b/backend/tests/services/context/test_adapters.py
@@ -1,11 +1,8 @@
"""Tests for model adapters."""
-import pytest
-
from app.services.context.adapters import (
ClaudeAdapter,
DefaultAdapter,
- ModelAdapter,
OpenAIAdapter,
get_adapter,
)
diff --git a/backend/tests/services/context/test_assembly.py b/backend/tests/services/context/test_assembly.py
index 92f9c7e..fff2069 100644
--- a/backend/tests/services/context/test_assembly.py
+++ b/backend/tests/services/context/test_assembly.py
@@ -5,10 +5,9 @@ from datetime import UTC, datetime
import pytest
from app.services.context.assembly import ContextPipeline, PipelineMetrics
-from app.services.context.budget import BudgetAllocator, TokenBudget
+from app.services.context.budget import TokenBudget
from app.services.context.types import (
AssembledContext,
- ContextType,
ConversationContext,
KnowledgeContext,
MessageRole,
@@ -354,7 +353,10 @@ class TestContextPipelineFormatting:
if result.context_count > 0:
assert "" in result.content
- assert '' in result.content or 'role="user"' in result.content
+ assert (
+ '' in result.content
+ or 'role="user"' in result.content
+ )
@pytest.mark.asyncio
async def test_format_tool_results(self) -> None:
@@ -474,6 +476,10 @@ class TestContextPipelineIntegration:
assert system_pos < task_pos
if task_pos >= 0 and knowledge_pos >= 0:
assert task_pos < knowledge_pos
+ if knowledge_pos >= 0 and conversation_pos >= 0:
+ assert knowledge_pos < conversation_pos
+ if conversation_pos >= 0 and tool_pos >= 0:
+ assert conversation_pos < tool_pos
@pytest.mark.asyncio
async def test_excluded_contexts_tracked(self) -> None:
diff --git a/backend/tests/services/context/test_compression.py b/backend/tests/services/context/test_compression.py
index c37ca10..3a24db2 100644
--- a/backend/tests/services/context/test_compression.py
+++ b/backend/tests/services/context/test_compression.py
@@ -2,16 +2,15 @@
import pytest
+from app.services.context.budget import BudgetAllocator
from app.services.context.compression import (
ContextCompressor,
TruncationResult,
TruncationStrategy,
)
-from app.services.context.budget import BudgetAllocator, TokenBudget
from app.services.context.types import (
ContextType,
KnowledgeContext,
- SystemContext,
TaskContext,
)
@@ -113,7 +112,7 @@ class TestTruncationStrategy:
assert result.truncated is True
assert len(result.content) < len(content)
- assert strategy.TRUNCATION_MARKER in result.content
+ assert strategy.truncation_marker in result.content
@pytest.mark.asyncio
async def test_truncate_middle_strategy(self) -> None:
@@ -126,7 +125,7 @@ class TestTruncationStrategy:
)
assert result.truncated is True
- assert strategy.TRUNCATION_MARKER in result.content
+ assert strategy.truncation_marker in result.content
@pytest.mark.asyncio
async def test_truncate_sentence_strategy(self) -> None:
@@ -140,7 +139,9 @@ class TestTruncationStrategy:
assert result.truncated is True
# Should cut at sentence boundary
- assert result.content.endswith(".") or strategy.TRUNCATION_MARKER in result.content
+ assert (
+ result.content.endswith(".") or strategy.truncation_marker in result.content
+ )
class TestContextCompressor:
@@ -235,10 +236,12 @@ class TestTruncationEdgeCases:
content = "Some content to truncate"
# max_tokens less than marker tokens should return just marker
- result = await strategy.truncate_to_tokens(content, max_tokens=1, strategy="end")
+ result = await strategy.truncate_to_tokens(
+ content, max_tokens=1, strategy="end"
+ )
# Should handle gracefully without crashing
- assert strategy.TRUNCATION_MARKER in result.content or result.content == content
+ assert strategy.truncation_marker in result.content or result.content == content
@pytest.mark.asyncio
async def test_truncate_with_content_that_has_zero_tokens(self) -> None:
@@ -249,7 +252,7 @@ class TestTruncationEdgeCases:
result = await strategy.truncate_to_tokens("a", max_tokens=100)
# Should not raise ZeroDivisionError
- assert result.content in ("a", "a" + strategy.TRUNCATION_MARKER)
+ assert result.content in ("a", "a" + strategy.truncation_marker)
@pytest.mark.asyncio
async def test_get_content_for_tokens_zero_target(self) -> None:
diff --git a/backend/tests/services/context/test_engine.py b/backend/tests/services/context/test_engine.py
index 87202f7..1b5a0d9 100644
--- a/backend/tests/services/context/test_engine.py
+++ b/backend/tests/services/context/test_engine.py
@@ -11,8 +11,6 @@ from app.services.context.types import (
ConversationContext,
KnowledgeContext,
MessageRole,
- SystemContext,
- TaskContext,
ToolContext,
)
diff --git a/backend/tests/services/context/test_exceptions.py b/backend/tests/services/context/test_exceptions.py
index f987f76..2ec5d2b 100644
--- a/backend/tests/services/context/test_exceptions.py
+++ b/backend/tests/services/context/test_exceptions.py
@@ -1,7 +1,5 @@
"""Tests for context management exceptions."""
-import pytest
-
from app.services.context.exceptions import (
AssemblyTimeoutError,
BudgetExceededError,
diff --git a/backend/tests/services/context/test_ranker.py b/backend/tests/services/context/test_ranker.py
index adf876c..bd98382 100644
--- a/backend/tests/services/context/test_ranker.py
+++ b/backend/tests/services/context/test_ranker.py
@@ -1,7 +1,5 @@
"""Tests for context ranking module."""
-from datetime import UTC, datetime
-
import pytest
from app.services.context.budget import BudgetAllocator, TokenBudget
@@ -230,9 +228,7 @@ class TestContextRanker:
),
]
- result = await ranker.rank(
- contexts, "query", budget, ensure_required=False
- )
+ result = await ranker.rank(contexts, "query", budget, ensure_required=False)
# Without ensure_required, CRITICAL contexts can be excluded
# if budget doesn't allow
@@ -246,12 +242,8 @@ class TestContextRanker:
budget = allocator.create_budget(10000)
contexts = [
- KnowledgeContext(
- content="Knowledge 1", source="docs", relevance_score=0.8
- ),
- KnowledgeContext(
- content="Knowledge 2", source="docs", relevance_score=0.6
- ),
+ KnowledgeContext(content="Knowledge 1", source="docs", relevance_score=0.8),
+ KnowledgeContext(content="Knowledge 2", source="docs", relevance_score=0.6),
TaskContext(content="Task", source="task"),
]
diff --git a/backend/tests/services/context/test_scoring.py b/backend/tests/services/context/test_scoring.py
index 1feeea6..37eb858 100644
--- a/backend/tests/services/context/test_scoring.py
+++ b/backend/tests/services/context/test_scoring.py
@@ -6,7 +6,6 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
from app.services.context.scoring import (
- BaseScorer,
CompositeScorer,
PriorityScorer,
RecencyScorer,
@@ -149,15 +148,9 @@ class TestRelevanceScorer:
scorer = RelevanceScorer()
contexts = [
- KnowledgeContext(
- content="Python", source="1", relevance_score=0.8
- ),
- KnowledgeContext(
- content="Java", source="2", relevance_score=0.6
- ),
- KnowledgeContext(
- content="Go", source="3", relevance_score=0.9
- ),
+ KnowledgeContext(content="Python", source="1", relevance_score=0.8),
+ KnowledgeContext(content="Java", source="2", relevance_score=0.6),
+ KnowledgeContext(content="Go", source="3", relevance_score=0.9),
]
scores = await scorer.score_batch(contexts, "test")
@@ -263,7 +256,9 @@ class TestRecencyScorer:
)
conv_score = await scorer.score(conv_context, "query", reference_time=now)
- knowledge_score = await scorer.score(knowledge_context, "query", reference_time=now)
+ knowledge_score = await scorer.score(
+ knowledge_context, "query", reference_time=now
+ )
# Conversation should decay much faster
assert conv_score < knowledge_score
@@ -301,12 +296,8 @@ class TestRecencyScorer:
contexts = [
TaskContext(content="1", source="t", timestamp=now),
- TaskContext(
- content="2", source="t", timestamp=now - timedelta(hours=24)
- ),
- TaskContext(
- content="3", source="t", timestamp=now - timedelta(hours=48)
- ),
+ TaskContext(content="2", source="t", timestamp=now - timedelta(hours=24)),
+ TaskContext(content="3", source="t", timestamp=now - timedelta(hours=48)),
]
scores = await scorer.score_batch(contexts, "query", reference_time=now)
@@ -508,8 +499,12 @@ class TestCompositeScorer:
assert scored.priority_score > 0.5 # HIGH priority
@pytest.mark.asyncio
- async def test_score_cached_on_context(self) -> None:
- """Test that score is cached on the context."""
+ async def test_score_not_cached_on_context(self) -> None:
+ """Test that scores are NOT cached on the context.
+
+ Scores should not be cached on the context because they are query-dependent.
+ Different queries would get incorrect cached scores if we cached on the context.
+ """
scorer = CompositeScorer()
context = KnowledgeContext(
@@ -518,14 +513,18 @@ class TestCompositeScorer:
relevance_score=0.5,
)
- # First scoring
+ # After scoring, context._score should remain None
+ # (we don't cache on context because scores are query-dependent)
await scorer.score(context, "query")
- assert context._score is not None
+ # The scorer should compute fresh scores each time
+ # rather than caching on the context object
- # Second scoring should use cached value
- context._score = 0.999 # Set to a known value
- score2 = await scorer.score(context, "query")
- assert score2 == 0.999
+ # Score again with different query - should compute fresh score
+ score1 = await scorer.score(context, "query 1")
+ score2 = await scorer.score(context, "query 2")
+ # Both should be valid scores (not necessarily equal since queries differ)
+ assert 0.0 <= score1 <= 1.0
+ assert 0.0 <= score2 <= 1.0
@pytest.mark.asyncio
async def test_score_batch(self) -> None:
@@ -555,15 +554,9 @@ class TestCompositeScorer:
scorer = CompositeScorer()
contexts = [
- KnowledgeContext(
- content="Low", source="docs", relevance_score=0.2
- ),
- KnowledgeContext(
- content="High", source="docs", relevance_score=0.9
- ),
- KnowledgeContext(
- content="Medium", source="docs", relevance_score=0.5
- ),
+ KnowledgeContext(content="Low", source="docs", relevance_score=0.2),
+ KnowledgeContext(content="High", source="docs", relevance_score=0.9),
+ KnowledgeContext(content="Medium", source="docs", relevance_score=0.5),
]
ranked = await scorer.rank(contexts, "query")
@@ -580,9 +573,7 @@ class TestCompositeScorer:
scorer = CompositeScorer()
contexts = [
- KnowledgeContext(
- content=str(i), source="docs", relevance_score=i / 10
- )
+ KnowledgeContext(content=str(i), source="docs", relevance_score=i / 10)
for i in range(10)
]
@@ -595,12 +586,8 @@ class TestCompositeScorer:
scorer = CompositeScorer()
contexts = [
- KnowledgeContext(
- content="Low", source="docs", relevance_score=0.1
- ),
- KnowledgeContext(
- content="High", source="docs", relevance_score=0.9
- ),
+ KnowledgeContext(content="Low", source="docs", relevance_score=0.1),
+ KnowledgeContext(content="High", source="docs", relevance_score=0.9),
]
ranked = await scorer.rank(contexts, "query", min_score=0.5)
@@ -625,7 +612,13 @@ class TestCompositeScorer:
"""
import asyncio
- scorer = CompositeScorer()
+ # Use scorer with recency_weight=0 to eliminate time-dependent variation
+ # (recency scores change as time passes between calls)
+ scorer = CompositeScorer(
+ relevance_weight=0.5,
+ recency_weight=0.0, # Disable recency to get deterministic results
+ priority_weight=0.5,
+ )
# Create a single context that will be scored multiple times concurrently
context = KnowledgeContext(
@@ -639,11 +632,9 @@ class TestCompositeScorer:
tasks = [scorer.score(context, "test query") for _ in range(num_concurrent)]
scores = await asyncio.gather(*tasks)
- # All scores should be identical (the same context scored the same way)
+ # All scores should be identical (deterministic scoring without recency)
assert all(s == scores[0] for s in scores)
-
- # The context should have its _score cached
- assert context._score is not None
+ # Note: We don't cache _score on context because scores are query-dependent
@pytest.mark.asyncio
async def test_concurrent_scoring_different_contexts(self) -> None:
@@ -671,10 +662,7 @@ class TestCompositeScorer:
# Each context should have a different score based on its relevance
assert len(set(scores)) > 1 # Not all the same
-
- # All contexts should have cached scores
- for ctx in contexts:
- assert ctx._score is not None
+ # Note: We don't cache _score on context because scores are query-dependent
class TestScoredContext:
diff --git a/backend/tests/services/context/test_types.py b/backend/tests/services/context/test_types.py
index 2a5743e..82291dc 100644
--- a/backend/tests/services/context/test_types.py
+++ b/backend/tests/services/context/test_types.py
@@ -1,20 +1,17 @@
"""Tests for context types."""
-import json
from datetime import UTC, datetime, timedelta
import pytest
from app.services.context.types import (
AssembledContext,
- BaseContext,
ContextPriority,
ContextType,
ConversationContext,
KnowledgeContext,
MessageRole,
SystemContext,
- TaskComplexity,
TaskContext,
TaskStatus,
ToolContext,
@@ -181,24 +178,16 @@ class TestKnowledgeContext:
def test_is_code(self) -> None:
"""Test is_code method."""
- code_ctx = KnowledgeContext(
- content="code", source="test", file_type="python"
- )
- doc_ctx = KnowledgeContext(
- content="docs", source="test", file_type="markdown"
- )
+ code_ctx = KnowledgeContext(content="code", source="test", file_type="python")
+ doc_ctx = KnowledgeContext(content="docs", source="test", file_type="markdown")
assert code_ctx.is_code() is True
assert doc_ctx.is_code() is False
def test_is_documentation(self) -> None:
"""Test is_documentation method."""
- doc_ctx = KnowledgeContext(
- content="docs", source="test", file_type="markdown"
- )
- code_ctx = KnowledgeContext(
- content="code", source="test", file_type="python"
- )
+ doc_ctx = KnowledgeContext(content="docs", source="test", file_type="markdown")
+ code_ctx = KnowledgeContext(content="code", source="test", file_type="python")
assert doc_ctx.is_documentation() is True
assert code_ctx.is_documentation() is False
@@ -333,15 +322,11 @@ class TestTaskContext:
def test_status_checks(self) -> None:
"""Test status check methods."""
- pending = TaskContext(
- content="test", source="test", status=TaskStatus.PENDING
- )
+ pending = TaskContext(content="test", source="test", status=TaskStatus.PENDING)
completed = TaskContext(
content="test", source="test", status=TaskStatus.COMPLETED
)
- blocked = TaskContext(
- content="test", source="test", status=TaskStatus.BLOCKED
- )
+ blocked = TaskContext(content="test", source="test", status=TaskStatus.BLOCKED)
assert pending.is_active() is True
assert completed.is_complete() is True
@@ -395,12 +380,8 @@ class TestToolContext:
def test_is_successful(self) -> None:
"""Test is_successful method."""
- success = ToolContext.from_tool_result(
- "test", "ok", ToolResultStatus.SUCCESS
- )
- error = ToolContext.from_tool_result(
- "test", "error", ToolResultStatus.ERROR
- )
+ success = ToolContext.from_tool_result("test", "ok", ToolResultStatus.SUCCESS)
+ error = ToolContext.from_tool_result("test", "error", ToolResultStatus.ERROR)
assert success.is_successful() is True
assert error.is_successful() is False
@@ -510,9 +491,7 @@ class TestBaseContextMethods:
def test_get_age_seconds(self) -> None:
"""Test get_age_seconds method."""
old_time = datetime.now(UTC) - timedelta(hours=2)
- ctx = SystemContext(
- content="test", source="test", timestamp=old_time
- )
+ ctx = SystemContext(content="test", source="test", timestamp=old_time)
age = ctx.get_age_seconds()
# Should be approximately 2 hours in seconds
@@ -521,9 +500,7 @@ class TestBaseContextMethods:
def test_get_age_hours(self) -> None:
"""Test get_age_hours method."""
old_time = datetime.now(UTC) - timedelta(hours=5)
- ctx = SystemContext(
- content="test", source="test", timestamp=old_time
- )
+ ctx = SystemContext(content="test", source="test", timestamp=old_time)
age = ctx.get_age_hours()
assert 4.9 < age < 5.1
@@ -533,12 +510,8 @@ class TestBaseContextMethods:
old_time = datetime.now(UTC) - timedelta(days=10)
new_time = datetime.now(UTC) - timedelta(hours=1)
- old_ctx = SystemContext(
- content="test", source="test", timestamp=old_time
- )
- new_ctx = SystemContext(
- content="test", source="test", timestamp=new_time
- )
+ old_ctx = SystemContext(content="test", source="test", timestamp=old_time)
+ new_ctx = SystemContext(content="test", source="test", timestamp=new_time)
# Default max_age is 168 hours (7 days)
assert old_ctx.is_stale() is True