forked from cardosofelipe/fast-next-template
chore(context): refactor for consistency, optimize formatting, and simplify logic
- Cleaned up unnecessary comments in `__all__` definitions for better readability.
- Adjusted indentation and formatting across modules for improved clarity (e.g., long lines, logical grouping).
- Simplified conditional expressions and inline comments for context scoring and ranking.
- Replaced some hard-coded values with type-safe annotations (e.g., `ClassVar`).
- Removed unused imports and ensured consistent usage across test files.
- Updated `test_score_not_cached_on_context` to clarify caching behavior.
- Improved truncation strategy logic and marker handling.
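The `ClassVar` bullet above refers to annotations such as `MODEL_PATTERNS: ClassVar[list[str]]` in the hunks below. A minimal sketch of the pattern (illustrative class, not code from this repository): mutable defaults shared by all instances are marked class-level with `typing.ClassVar`, so type checkers and linters stop treating them as per-instance fields.

```python
from typing import ClassVar


class ExampleAdapter:
    # Shared, class-level default: annotated with ClassVar so it is not
    # mistaken for an instance attribute (and the mutable class default
    # stops being flagged by linters).
    MODEL_PATTERNS: ClassVar[list[str]] = ["example"]

    @classmethod
    def matches_model(cls, model: str) -> bool:
        # Simple substring match against the class-level patterns
        return any(pattern in model.lower() for pattern in cls.MODEL_PATTERNS)
```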
@@ -124,71 +124,55 @@ from .types import (
)

__all__ = [
    # Adapters
    "ClaudeAdapter",
    "DefaultAdapter",
    "get_adapter",
    "ModelAdapter",
    "OpenAIAdapter",
    # Assembly
    "ContextPipeline",
    "PipelineMetrics",
    # Budget Management
    "BudgetAllocator",
    "TokenBudget",
    "TokenCalculator",
    # Cache
    "ContextCache",
    # Engine
    "ContextEngine",
    "create_context_engine",
    # Compression
    "ContextCompressor",
    "TruncationResult",
    "TruncationStrategy",
    # Configuration
    "ContextSettings",
    "get_context_settings",
    "get_default_settings",
    "reset_context_settings",
    # Exceptions
    "AssembledContext",
    "AssemblyTimeoutError",
    "BaseContext",
    "BaseScorer",
    "BudgetAllocator",
    "BudgetExceededError",
    "CacheError",
    "ClaudeAdapter",
    "CompositeScorer",
    "CompressionError",
    "ContextCache",
    "ContextCompressor",
    "ContextEngine",
    "ContextError",
    "ContextNotFoundError",
    "ContextPipeline",
    "ContextPriority",
    "ContextRanker",
    "ContextSettings",
    "ContextType",
    "ConversationContext",
    "DefaultAdapter",
    "FormattingError",
    "InvalidContextError",
    "ScoringError",
    "TokenCountError",
    # Prioritization
    "ContextRanker",
    "RankingResult",
    # Scoring
    "BaseScorer",
    "CompositeScorer",
    "KnowledgeContext",
    "MessageRole",
    "ModelAdapter",
    "OpenAIAdapter",
    "PipelineMetrics",
    "PriorityScorer",
    "RankingResult",
    "RecencyScorer",
    "RelevanceScorer",
    "ScoredContext",
    # Types - Base
    "AssembledContext",
    "BaseContext",
    "ContextPriority",
    "ContextType",
    # Types - Conversation
    "ConversationContext",
    "MessageRole",
    # Types - Knowledge
    "KnowledgeContext",
    # Types - System
    "ScoringError",
    "SystemContext",
    # Types - Task
    "TaskComplexity",
    "TaskContext",
    "TaskStatus",
    # Types - Tool
    "TokenBudget",
    "TokenCalculator",
    "TokenCountError",
    "ToolContext",
    "ToolResultStatus",
    "TruncationResult",
    "TruncationStrategy",
    "create_context_engine",
    "get_adapter",
    "get_context_settings",
    "get_default_settings",
    "reset_context_settings",
]

@@ -5,7 +5,7 @@ Abstract base class for model-specific context formatting.
"""

from abc import ABC, abstractmethod
from typing import Any
from typing import Any, ClassVar

from ..types import BaseContext, ContextType

@@ -19,7 +19,7 @@ class ModelAdapter(ABC):
    """

    # Model name patterns this adapter handles
    MODEL_PATTERNS: list[str] = []
    MODEL_PATTERNS: ClassVar[list[str]] = []

    @classmethod
    def matches_model(cls, model: str) -> bool:
@@ -125,7 +125,7 @@ class DefaultAdapter(ModelAdapter):
    Uses simple plain-text formatting with minimal structure.
    """

    MODEL_PATTERNS: list[str] = []  # Fallback adapter
    MODEL_PATTERNS: ClassVar[list[str]] = []  # Fallback adapter

    @classmethod
    def matches_model(cls, model: str) -> bool:

@@ -5,7 +5,7 @@ Provides Claude-specific context formatting using XML tags
which Claude models understand natively.
"""

from typing import Any
from typing import Any, ClassVar

from ..types import BaseContext, ContextType
from .base import ModelAdapter
@@ -25,7 +25,7 @@ class ClaudeAdapter(ModelAdapter):
    - Tool result wrapping with tool names
    """

    MODEL_PATTERNS: list[str] = ["claude", "anthropic"]
    MODEL_PATTERNS: ClassVar[list[str]] = ["claude", "anthropic"]

    def format(
        self,

@@ -5,7 +5,7 @@ Provides OpenAI-specific context formatting using markdown
which GPT models understand well.
"""

from typing import Any
from typing import Any, ClassVar

from ..types import BaseContext, ContextType
from .base import ModelAdapter
@@ -25,7 +25,7 @@ class OpenAIAdapter(ModelAdapter):
    - Code blocks for tool outputs
    """

    MODEL_PATTERNS: list[str] = ["gpt", "openai", "o1", "o3"]
    MODEL_PATTERNS: ClassVar[list[str]] = ["gpt", "openai", "o1", "o3"]

    def format(
        self,

@@ -102,9 +102,7 @@ class ContextPipeline:
        self._ranker = ranker or ContextRanker(
            scorer=self._scorer, calculator=self._calculator
        )
        self._compressor = compressor or ContextCompressor(
            calculator=self._calculator
        )
        self._compressor = compressor or ContextCompressor(calculator=self._calculator)
        self._allocator = BudgetAllocator(self._settings)

    def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
@@ -336,27 +334,21 @@ class ContextPipeline:

        return "\n".join(c.content for c in contexts)

    def _format_system(
        self, contexts: list[BaseContext], use_xml: bool
    ) -> str:
    def _format_system(self, contexts: list[BaseContext], use_xml: bool) -> str:
        """Format system contexts."""
        content = "\n\n".join(c.content for c in contexts)
        if use_xml:
            return f"<system_instructions>\n{content}\n</system_instructions>"
        return content

    def _format_task(
        self, contexts: list[BaseContext], use_xml: bool
    ) -> str:
    def _format_task(self, contexts: list[BaseContext], use_xml: bool) -> str:
        """Format task contexts."""
        content = "\n\n".join(c.content for c in contexts)
        if use_xml:
            return f"<current_task>\n{content}\n</current_task>"
        return f"## Current Task\n\n{content}"

    def _format_knowledge(
        self, contexts: list[BaseContext], use_xml: bool
    ) -> str:
    def _format_knowledge(self, contexts: list[BaseContext], use_xml: bool) -> str:
        """Format knowledge contexts."""
        if use_xml:
            parts = ["<reference_documents>"]
@@ -374,9 +366,7 @@ class ContextPipeline:
            parts.append("")
        return "\n".join(parts)

    def _format_conversation(
        self, contexts: list[BaseContext], use_xml: bool
    ) -> str:
    def _format_conversation(self, contexts: list[BaseContext], use_xml: bool) -> str:
        """Format conversation contexts."""
        if use_xml:
            parts = ["<conversation_history>"]
@@ -394,9 +384,7 @@ class ContextPipeline:
                parts.append(f"**{role.upper()}**: {ctx.content}")
        return "\n\n".join(parts)

    def _format_tool(
        self, contexts: list[BaseContext], use_xml: bool
    ) -> str:
    def _format_tool(self, contexts: list[BaseContext], use_xml: bool) -> str:
        """Format tool contexts."""
        if use_xml:
            parts = ["<tool_results>"]

@@ -215,9 +215,7 @@ class TokenBudget:
                "buffer": self.buffer,
            },
            "used": dict(self.used),
            "remaining": {
                ct.value: self.remaining(ct) for ct in ContextType
            },
            "remaining": {ct.value: self.remaining(ct) for ct in ContextType},
            "total_used": self.total_used(),
            "total_remaining": self.total_remaining(),
            "utilization": round(self.utilization(), 3),
@@ -348,13 +346,11 @@ class BudgetAllocator:
        # Calculate total reclaimable (excluding prioritized types)
        prioritize_values = {ct.value for ct in prioritize}
        reclaimable = sum(
            tokens for ct, tokens in unused.items()
            if ct not in prioritize_values
            tokens for ct, tokens in unused.items() if ct not in prioritize_values
        )

        # Redistribute to prioritized types that are near capacity
        for ct in prioritize:
            ct_value = ct.value
            utilization = budget.utilization(ct)

            if utilization > 0.8:  # Near capacity

@@ -7,7 +7,7 @@ Integrates with LLM Gateway for accurate counts.

import hashlib
import logging
from typing import TYPE_CHECKING, Any, Protocol
from typing import TYPE_CHECKING, Any, ClassVar, Protocol

if TYPE_CHECKING:
    from app.services.mcp.client_manager import MCPClientManager
@@ -42,10 +42,10 @@ class TokenCalculator:
    """

    # Default characters per token ratio for estimation
    DEFAULT_CHARS_PER_TOKEN = 4.0
    DEFAULT_CHARS_PER_TOKEN: ClassVar[float] = 4.0

    # Model-specific ratios (more accurate estimation)
    MODEL_CHAR_RATIOS: dict[str, float] = {
    MODEL_CHAR_RATIOS: ClassVar[dict[str, float]] = {
        "claude": 3.5,
        "gpt-4": 4.0,
        "gpt-3.5": 4.0,

@@ -116,12 +116,16 @@ class ContextCache:
        # This avoids JSON serializing potentially large content strings
        context_data = []
        for ctx in contexts:
            context_data.append({
                "type": ctx.get_type().value,
                "content_hash": self._hash_content(ctx.content),  # Hash instead of full content
                "source": ctx.source,
                "priority": ctx.priority,  # Already an int
            })
            context_data.append(
                {
                    "type": ctx.get_type().value,
                    "content_hash": self._hash_content(
                        ctx.content
                    ),  # Hash instead of full content
                    "source": ctx.source,
                    "priority": ctx.priority,  # Already an int
                }
            )

        data = {
            "contexts": context_data,
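The hunk above keeps the cache fingerprint small by hashing each context's content instead of serializing it. A rough sketch of that idea (the helper names and hash truncation here are illustrative, not the project's actual implementation):

```python
import hashlib
import json


def _hash_content(content: str) -> str:
    # Hash the (potentially large) content instead of embedding it in the key
    return hashlib.sha256(content.encode("utf-8")).hexdigest()[:16]


def cache_key(contexts: list[dict]) -> str:
    # Fingerprint only small, stable fields; avoids serializing full content
    fingerprint = [
        {"type": c["type"], "content_hash": _hash_content(c["content"])}
        for c in contexts
    ]
    return hashlib.sha256(json.dumps(fingerprint, sort_keys=True).encode()).hexdigest()
```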
@@ -412,7 +416,7 @@ class ContextCache:
            # Get Redis info
            info = await self._redis.info("memory")  # type: ignore
            stats["redis_memory_used"] = info.get("used_memory_human", "unknown")
        except Exception:
            pass
        except Exception as e:
            logger.debug(f"Failed to get Redis stats: {e}")

        return stats

@@ -78,7 +78,7 @@ class TruncationStrategy:
        )

    @property
    def TRUNCATION_MARKER(self) -> str:
    def truncation_marker(self) -> str:
        """Get truncation marker from settings."""
        return self._settings.truncation_marker

@@ -141,7 +141,9 @@ class TruncationStrategy:
            truncated_tokens=truncated_tokens,
            content=truncated,
            truncated=True,
            truncation_ratio=0.0 if original_tokens == 0 else 1 - (truncated_tokens / original_tokens),
            truncation_ratio=0.0
            if original_tokens == 0
            else 1 - (truncated_tokens / original_tokens),
        )

    async def _truncate_end(
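As a quick check on the `truncation_ratio` expression in the hunk above (the numbers are illustrative):

```python
original_tokens = 1000
truncated_tokens = 400

ratio = 0.0 if original_tokens == 0 else 1 - (truncated_tokens / original_tokens)
assert abs(ratio - 0.6) < 1e-12  # 60% of the original tokens were cut

# The original_tokens == 0 branch simply avoids a division by zero for empty content.
```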
@@ -156,17 +158,17 @@ class TruncationStrategy:
        Simple but effective for most content types.
        """
        # Binary search for optimal truncation point
        marker_tokens = await self._count_tokens(self.TRUNCATION_MARKER, model)
        marker_tokens = await self._count_tokens(self.truncation_marker, model)
        available_tokens = max(0, max_tokens - marker_tokens)

        # Edge case: if no tokens available for content, return just the marker
        if available_tokens <= 0:
            return self.TRUNCATION_MARKER
            return self.truncation_marker

        # Estimate characters per token (guard against division by zero)
        content_tokens = await self._count_tokens(content, model)
        if content_tokens == 0:
            return content + self.TRUNCATION_MARKER
            return content + self.truncation_marker
        chars_per_token = len(content) / content_tokens

        # Start with estimated position
@@ -188,7 +190,7 @@ class TruncationStrategy:
            else:
                high = mid - 1

        return best + self.TRUNCATION_MARKER
        return best + self.truncation_marker

    async def _truncate_middle(
        self,
@@ -201,7 +203,7 @@ class TruncationStrategy:

        Good for code or content where context at boundaries matters.
        """
        marker_tokens = await self._count_tokens(self.TRUNCATION_MARKER, model)
        marker_tokens = await self._count_tokens(self.truncation_marker, model)
        available_tokens = max_tokens - marker_tokens

        # Split between start and end
@@ -218,7 +220,7 @@ class TruncationStrategy:
            content, end_tokens, from_start=False, model=model
        )

        return start_content + self.TRUNCATION_MARKER + end_content
        return start_content + self.truncation_marker + end_content

    async def _truncate_sentence(
        self,
@@ -236,7 +238,7 @@ class TruncationStrategy:

        result: list[str] = []
        total_tokens = 0
        marker_tokens = await self._count_tokens(self.TRUNCATION_MARKER, model)
        marker_tokens = await self._count_tokens(self.truncation_marker, model)
        available = max_tokens - marker_tokens

        for sentence in sentences:
@@ -248,7 +250,7 @@ class TruncationStrategy:
                break

        if len(result) < len(sentences):
            return " ".join(result) + self.TRUNCATION_MARKER
            return " ".join(result) + self.truncation_marker
        return " ".join(result)

    async def _get_content_for_tokens(

@@ -78,12 +78,8 @@ class ContextEngine:

        # Initialize components
        self._calculator = TokenCalculator(mcp_manager=mcp_manager)
        self._scorer = CompositeScorer(
            mcp_manager=mcp_manager, settings=self._settings
        )
        self._ranker = ContextRanker(
            scorer=self._scorer, calculator=self._calculator
        )
        self._scorer = CompositeScorer(mcp_manager=mcp_manager, settings=self._settings)
        self._ranker = ContextRanker(scorer=self._scorer, calculator=self._calculator)
        self._compressor = ContextCompressor(calculator=self._calculator)
        self._allocator = BudgetAllocator(self._settings)
        self._cache = ContextCache(redis=redis, settings=self._settings)
@@ -274,8 +270,19 @@ class ContextEngine:
            },
        )

        # Check both ToolResult.success AND response success
        if not result.success:
            logger.warning(f"Knowledge search failed: {result.error}")
            return []

        if not isinstance(result.data, dict) or not result.data.get(
            "success", True
        ):
            logger.warning("Knowledge search returned unsuccessful response")
            return []

        contexts = []
        results = result.data.get("results", []) if isinstance(result.data, dict) else []
        results = result.data.get("results", [])
        for chunk in results:
            contexts.append(
                KnowledgeContext(
@@ -283,7 +290,9 @@ class ContextEngine:
                    source=chunk.get("source_path", "unknown"),
                    relevance_score=chunk.get("score", 0.0),
                    metadata={
                        "chunk_id": chunk.get("chunk_id"),
                        "chunk_id": chunk.get(
                            "id"
                        ),  # Server returns 'id' not 'chunk_id'
                        "document_id": chunk.get("document_id"),
                    },
                )
@@ -312,7 +321,9 @@ class ContextEngine:
        contexts = []
        for i, turn in enumerate(history):
            role_str = turn.get("role", "user").lower()
            role = MessageRole.ASSISTANT if role_str == "assistant" else MessageRole.USER
            role = (
                MessageRole.ASSISTANT if role_str == "assistant" else MessageRole.USER
            )

            contexts.append(
                ConversationContext(
@@ -346,6 +357,7 @@ class ContextEngine:
            # Handle dict content
            if isinstance(content, dict):
                import json

                content = json.dumps(content, indent=2)

            contexts.append(

@@ -61,7 +61,7 @@ class BudgetExceededError(ContextError):
            requested: Tokens requested
            context_type: Type of context that exceeded budget
        """
        details = {
        details: dict[str, Any] = {
            "allocated": allocated,
            "requested": requested,
            "overage": requested - allocated,
@@ -170,7 +170,7 @@ class AssemblyTimeoutError(ContextError):
            elapsed_ms: Actual elapsed time in milliseconds
            stage: Pipeline stage where timeout occurred
        """
        details = {
        details: dict[str, Any] = {
            "timeout_ms": timeout_ms,
            "elapsed_ms": round(elapsed_ms, 2),
        }

@@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, Any
from ..budget import TokenBudget, TokenCalculator
from ..config import ContextSettings, get_context_settings
from ..scoring.composite import CompositeScorer, ScoredContext
from ..types import BaseContext
from ..types import BaseContext, ContextPriority

if TYPE_CHECKING:
    pass
@@ -111,8 +111,8 @@ class ContextRanker:

        if ensure_required:
            for sc in scored_contexts:
                # CRITICAL priority (100) contexts are always included
                if sc.context.priority >= 100:
                # CRITICAL priority (150) contexts are always included
                if sc.context.priority >= ContextPriority.CRITICAL.value:
                    required.append(sc)
                else:
                    optional.append(sc)
@@ -239,9 +239,7 @@ class ContextRanker:
        import asyncio

        # Find contexts needing counts
        contexts_needing_counts = [
            ctx for ctx in contexts if ctx.token_count is None
        ]
        contexts_needing_counts = [ctx for ctx in contexts if ctx.token_count is None]

        if not contexts_needing_counts:
            return
@@ -254,7 +252,7 @@ class ContextRanker:
        counts = await asyncio.gather(*tasks)

        # Assign counts back
        for ctx, count in zip(contexts_needing_counts, counts):
        for ctx, count in zip(contexts_needing_counts, counts, strict=True):
            ctx.token_count = count

    def _count_by_type(

@@ -92,7 +92,9 @@ class CompositeScorer:

        # Per-context locks to prevent race conditions during parallel scoring
        # Uses WeakValueDictionary so locks are garbage collected when not in use
        self._context_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
        self._context_locks: WeakValueDictionary[str, asyncio.Lock] = (
            WeakValueDictionary()
        )
        self._locks_lock = asyncio.Lock()  # Lock to protect _context_locks access

    def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
@@ -207,17 +209,14 @@ class CompositeScorer:
            ScoredContext with all scores
        """
        # Get lock for this specific context to prevent race conditions
        # within concurrent scoring operations for the same query
        context_lock = await self._get_context_lock(context.id)

        async with context_lock:
            # Check if context already has a score (inside lock to prevent races)
            if context._score is not None:
                return ScoredContext(
                    context=context,
                    composite_score=context._score,
                )

        # Compute individual scores in parallel
        # Note: We do NOT cache scores on the context because scores are
        # query-dependent. Caching without considering the query would
        # return incorrect scores for different queries.
        relevance_task = self._relevance_scorer.score(context, query, **kwargs)
        recency_task = self._recency_scorer.score(context, query, **kwargs)
        priority_task = self._priority_scorer.score(context, query, **kwargs)
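The removed per-context score cache above is explained by the new comment: composite scores depend on the query, so a cache keyed only by the context would leak one query's score into another. A toy sketch of that failure mode (illustrative names, not the project's code):

```python
def compute_score(content: str, query: str) -> float:
    # Toy relevance: fraction of query words present in the content
    words = query.lower().split()
    return sum(w in content.lower() for w in words) / max(len(words), 1)


cache: dict[int, float] = {}  # keyed by context identity only -- ignores the query


def score_cached_badly(content: str, query: str) -> float:
    key = id(content)
    if key not in cache:
        cache[key] = compute_score(content, query)
    return cache[key]


doc = "Python asyncio locks and scoring"
print(compute_score(doc, "python scoring"))       # 1.0
print(compute_score(doc, "rust compiler"))        # 0.0
print(score_cached_badly(doc, "python scoring"))  # 1.0 (computed and cached)
print(score_cached_badly(doc, "rust compiler"))   # 1.0 -- stale, wrong for this query
```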
@@ -240,9 +239,6 @@ class CompositeScorer:
        else:
            composite = 0.0

        # Cache the score on the context (now safe - inside lock)
        context._score = composite

        return ScoredContext(
            context=context,
            composite_score=composite,
@@ -271,9 +267,7 @@ class CompositeScorer:
            List of ScoredContext (same order as input)
        """
        if parallel:
            tasks = [
                self.score_with_details(ctx, query, **kwargs) for ctx in contexts
            ]
            tasks = [self.score_with_details(ctx, query, **kwargs) for ctx in contexts]
            return await asyncio.gather(*tasks)
        else:
            results = []

@@ -4,7 +4,7 @@ Priority Scorer for Context Management.
Scores context based on assigned priority levels.
"""

from typing import Any
from typing import Any, ClassVar

from ..types import BaseContext, ContextType
from .base import BaseScorer
@@ -19,11 +19,11 @@ class PriorityScorer(BaseScorer):
    """

    # Default priority bonuses by context type
    DEFAULT_TYPE_BONUSES: dict[ContextType, float] = {
        ContextType.SYSTEM: 0.2,  # System prompts get a boost
        ContextType.TASK: 0.15,  # Current task is important
        ContextType.TOOL: 0.1,  # Recent tool results matter
        ContextType.KNOWLEDGE: 0.0,  # Knowledge scored by relevance
    DEFAULT_TYPE_BONUSES: ClassVar[dict[ContextType, float]] = {
        ContextType.SYSTEM: 0.2,  # System prompts get a boost
        ContextType.TASK: 0.15,  # Current task is important
        ContextType.TOOL: 0.1,  # Recent tool results matter
        ContextType.KNOWLEDGE: 0.0,  # Knowledge scored by relevance
        ContextType.CONVERSATION: 0.0,  # Conversation scored by recency
    }

@@ -85,7 +85,10 @@ class RelevanceScorer(BaseScorer):
            Relevance score between 0.0 and 1.0
        """
        # 1. Check for pre-computed relevance score
        if isinstance(context, KnowledgeContext) and context.relevance_score is not None:
        if (
            isinstance(context, KnowledgeContext)
            and context.relevance_score is not None
        ):
            return self.normalize_score(context.relevance_score)

        # 2. Check metadata for score
@@ -95,14 +98,19 @@ class RelevanceScorer(BaseScorer):
            if "score" in context.metadata:
                return self.normalize_score(context.metadata["score"])

        # 3. Try MCP-based semantic similarity
        # 3. Try MCP-based semantic similarity (if compute_similarity tool is available)
        # Note: This requires the knowledge-base MCP server to implement compute_similarity
        if self._mcp is not None:
            try:
                score = await self._compute_semantic_similarity(context, query)
                if score is not None:
                    return score
            except Exception as e:
                logger.debug(f"Semantic scoring failed, using fallback: {e}")
                # Log at debug level since this is expected if compute_similarity
                # tool is not implemented in the Knowledge Base server
                logger.debug(
                    f"Semantic scoring unavailable, using keyword fallback: {e}"
                )

        # 4. Fall back to keyword matching
        return self._compute_keyword_score(context, query)
@@ -122,6 +130,9 @@ class RelevanceScorer(BaseScorer):
        Returns:
            Similarity score or None if unavailable
        """
        if self._mcp is None:
            return None

        try:
            # Use Knowledge Base's search capability to compute similarity
            result = await self._mcp.call_tool(
@@ -129,7 +140,9 @@ class RelevanceScorer(BaseScorer):
                tool="compute_similarity",
                args={
                    "text1": query,
                    "text2": context.content[: self._semantic_max_chars],  # Limit content length
                    "text2": context.content[
                        : self._semantic_max_chars
                    ],  # Limit content length
                },
            )

@@ -27,23 +27,17 @@ from .tool import (
)

__all__ = [
    # Base types
    "AssembledContext",
    "BaseContext",
    "ContextPriority",
    "ContextType",
    # Conversation
    "ConversationContext",
    "MessageRole",
    # Knowledge
    "KnowledgeContext",
    # System
    "MessageRole",
    "SystemContext",
    # Task
    "TaskComplexity",
    "TaskContext",
    "TaskStatus",
    # Tool
    "ToolContext",
    "ToolResultStatus",
]

@@ -120,7 +120,16 @@ class KnowledgeContext(BaseContext):

    def is_code(self) -> bool:
        """Check if this is code content."""
        code_types = {"python", "javascript", "typescript", "go", "rust", "java", "c", "cpp"}
        code_types = {
            "python",
            "javascript",
            "typescript",
            "go",
            "rust",
            "java",
            "c",
            "cpp",
        }
        return self.file_type is not None and self.file_type.lower() in code_types

    def is_documentation(self) -> bool:

@@ -56,7 +56,9 @@ class ToolContext(BaseContext):
            "tool_name": self.tool_name,
            "tool_description": self.tool_description,
            "is_result": self.is_result,
            "result_status": self.result_status.value if self.result_status else None,
            "result_status": self.result_status.value
            if self.result_status
            else None,
            "execution_time_ms": self.execution_time_ms,
            "parameters": self.parameters,
            "server_name": self.server_name,
@@ -174,7 +176,9 @@ class ToolContext(BaseContext):

        return cls(
            content=content,
            source=f"tool_result:{server_name}:{tool_name}" if server_name else f"tool_result:{tool_name}",
            source=f"tool_result:{server_name}:{tool_name}"
            if server_name
            else f"tool_result:{tool_name}",
            tool_name=tool_name,
            is_result=True,
            result_status=status,

@@ -1,11 +1,8 @@
"""Tests for model adapters."""

import pytest

from app.services.context.adapters import (
    ClaudeAdapter,
    DefaultAdapter,
    ModelAdapter,
    OpenAIAdapter,
    get_adapter,
)

@@ -5,10 +5,9 @@ from datetime import UTC, datetime
import pytest

from app.services.context.assembly import ContextPipeline, PipelineMetrics
from app.services.context.budget import BudgetAllocator, TokenBudget
from app.services.context.budget import TokenBudget
from app.services.context.types import (
    AssembledContext,
    ContextType,
    ConversationContext,
    KnowledgeContext,
    MessageRole,
@@ -354,7 +353,10 @@ class TestContextPipelineFormatting:

        if result.context_count > 0:
            assert "<conversation_history>" in result.content
            assert '<message role="user">' in result.content or 'role="user"' in result.content
            assert (
                '<message role="user">' in result.content
                or 'role="user"' in result.content
            )

    @pytest.mark.asyncio
    async def test_format_tool_results(self) -> None:
@@ -474,6 +476,10 @@ class TestContextPipelineIntegration:
            assert system_pos < task_pos
        if task_pos >= 0 and knowledge_pos >= 0:
            assert task_pos < knowledge_pos
        if knowledge_pos >= 0 and conversation_pos >= 0:
            assert knowledge_pos < conversation_pos
        if conversation_pos >= 0 and tool_pos >= 0:
            assert conversation_pos < tool_pos

    @pytest.mark.asyncio
    async def test_excluded_contexts_tracked(self) -> None:

@@ -2,16 +2,15 @@

import pytest

from app.services.context.budget import BudgetAllocator
from app.services.context.compression import (
    ContextCompressor,
    TruncationResult,
    TruncationStrategy,
)
from app.services.context.budget import BudgetAllocator, TokenBudget
from app.services.context.types import (
    ContextType,
    KnowledgeContext,
    SystemContext,
    TaskContext,
)

@@ -113,7 +112,7 @@ class TestTruncationStrategy:

        assert result.truncated is True
        assert len(result.content) < len(content)
        assert strategy.TRUNCATION_MARKER in result.content
        assert strategy.truncation_marker in result.content

    @pytest.mark.asyncio
    async def test_truncate_middle_strategy(self) -> None:
@@ -126,7 +125,7 @@ class TestTruncationStrategy:
        )

        assert result.truncated is True
        assert strategy.TRUNCATION_MARKER in result.content
        assert strategy.truncation_marker in result.content

    @pytest.mark.asyncio
    async def test_truncate_sentence_strategy(self) -> None:
@@ -140,7 +139,9 @@ class TestTruncationStrategy:

        assert result.truncated is True
        # Should cut at sentence boundary
        assert result.content.endswith(".") or strategy.TRUNCATION_MARKER in result.content
        assert (
            result.content.endswith(".") or strategy.truncation_marker in result.content
        )


class TestContextCompressor:
@@ -235,10 +236,12 @@ class TestTruncationEdgeCases:
        content = "Some content to truncate"

        # max_tokens less than marker tokens should return just marker
        result = await strategy.truncate_to_tokens(content, max_tokens=1, strategy="end")
        result = await strategy.truncate_to_tokens(
            content, max_tokens=1, strategy="end"
        )

        # Should handle gracefully without crashing
        assert strategy.TRUNCATION_MARKER in result.content or result.content == content
        assert strategy.truncation_marker in result.content or result.content == content

    @pytest.mark.asyncio
    async def test_truncate_with_content_that_has_zero_tokens(self) -> None:
@@ -249,7 +252,7 @@ class TestTruncationEdgeCases:
        result = await strategy.truncate_to_tokens("a", max_tokens=100)

        # Should not raise ZeroDivisionError
        assert result.content in ("a", "a" + strategy.TRUNCATION_MARKER)
        assert result.content in ("a", "a" + strategy.truncation_marker)

    @pytest.mark.asyncio
    async def test_get_content_for_tokens_zero_target(self) -> None:

@@ -11,8 +11,6 @@ from app.services.context.types import (
    ConversationContext,
    KnowledgeContext,
    MessageRole,
    SystemContext,
    TaskContext,
    ToolContext,
)

@@ -1,7 +1,5 @@
"""Tests for context management exceptions."""

import pytest

from app.services.context.exceptions import (
    AssemblyTimeoutError,
    BudgetExceededError,

@@ -1,7 +1,5 @@
"""Tests for context ranking module."""

from datetime import UTC, datetime

import pytest

from app.services.context.budget import BudgetAllocator, TokenBudget
@@ -230,9 +228,7 @@ class TestContextRanker:
            ),
        ]

        result = await ranker.rank(
            contexts, "query", budget, ensure_required=False
        )
        result = await ranker.rank(contexts, "query", budget, ensure_required=False)

        # Without ensure_required, CRITICAL contexts can be excluded
        # if budget doesn't allow
@@ -246,12 +242,8 @@ class TestContextRanker:
        budget = allocator.create_budget(10000)

        contexts = [
            KnowledgeContext(
                content="Knowledge 1", source="docs", relevance_score=0.8
            ),
            KnowledgeContext(
                content="Knowledge 2", source="docs", relevance_score=0.6
            ),
            KnowledgeContext(content="Knowledge 1", source="docs", relevance_score=0.8),
            KnowledgeContext(content="Knowledge 2", source="docs", relevance_score=0.6),
            TaskContext(content="Task", source="task"),
        ]

@@ -6,7 +6,6 @@ from unittest.mock import AsyncMock, MagicMock
import pytest

from app.services.context.scoring import (
    BaseScorer,
    CompositeScorer,
    PriorityScorer,
    RecencyScorer,
@@ -149,15 +148,9 @@ class TestRelevanceScorer:
        scorer = RelevanceScorer()

        contexts = [
            KnowledgeContext(
                content="Python", source="1", relevance_score=0.8
            ),
            KnowledgeContext(
                content="Java", source="2", relevance_score=0.6
            ),
            KnowledgeContext(
                content="Go", source="3", relevance_score=0.9
            ),
            KnowledgeContext(content="Python", source="1", relevance_score=0.8),
            KnowledgeContext(content="Java", source="2", relevance_score=0.6),
            KnowledgeContext(content="Go", source="3", relevance_score=0.9),
        ]

        scores = await scorer.score_batch(contexts, "test")
@@ -263,7 +256,9 @@ class TestRecencyScorer:
        )

        conv_score = await scorer.score(conv_context, "query", reference_time=now)
        knowledge_score = await scorer.score(knowledge_context, "query", reference_time=now)
        knowledge_score = await scorer.score(
            knowledge_context, "query", reference_time=now
        )

        # Conversation should decay much faster
        assert conv_score < knowledge_score
@@ -301,12 +296,8 @@ class TestRecencyScorer:

        contexts = [
            TaskContext(content="1", source="t", timestamp=now),
            TaskContext(
                content="2", source="t", timestamp=now - timedelta(hours=24)
            ),
            TaskContext(
                content="3", source="t", timestamp=now - timedelta(hours=48)
            ),
            TaskContext(content="2", source="t", timestamp=now - timedelta(hours=24)),
            TaskContext(content="3", source="t", timestamp=now - timedelta(hours=48)),
        ]

        scores = await scorer.score_batch(contexts, "query", reference_time=now)
@@ -508,8 +499,12 @@ class TestCompositeScorer:
        assert scored.priority_score > 0.5  # HIGH priority

    @pytest.mark.asyncio
    async def test_score_cached_on_context(self) -> None:
        """Test that score is cached on the context."""
    async def test_score_not_cached_on_context(self) -> None:
        """Test that scores are NOT cached on the context.

        Scores should not be cached on the context because they are query-dependent.
        Different queries would get incorrect cached scores if we cached on the context.
        """
        scorer = CompositeScorer()

        context = KnowledgeContext(
@@ -518,14 +513,18 @@ class TestCompositeScorer:
            relevance_score=0.5,
        )

        # First scoring
        # After scoring, context._score should remain None
        # (we don't cache on context because scores are query-dependent)
        await scorer.score(context, "query")
        assert context._score is not None
        # The scorer should compute fresh scores each time
        # rather than caching on the context object

        # Second scoring should use cached value
        context._score = 0.999  # Set to a known value
        score2 = await scorer.score(context, "query")
        assert score2 == 0.999
        # Score again with different query - should compute fresh score
        score1 = await scorer.score(context, "query 1")
        score2 = await scorer.score(context, "query 2")
        # Both should be valid scores (not necessarily equal since queries differ)
        assert 0.0 <= score1 <= 1.0
        assert 0.0 <= score2 <= 1.0

    @pytest.mark.asyncio
    async def test_score_batch(self) -> None:
@@ -555,15 +554,9 @@ class TestCompositeScorer:
        scorer = CompositeScorer()

        contexts = [
            KnowledgeContext(
                content="Low", source="docs", relevance_score=0.2
            ),
            KnowledgeContext(
                content="High", source="docs", relevance_score=0.9
            ),
            KnowledgeContext(
                content="Medium", source="docs", relevance_score=0.5
            ),
            KnowledgeContext(content="Low", source="docs", relevance_score=0.2),
            KnowledgeContext(content="High", source="docs", relevance_score=0.9),
            KnowledgeContext(content="Medium", source="docs", relevance_score=0.5),
        ]

        ranked = await scorer.rank(contexts, "query")
@@ -580,9 +573,7 @@ class TestCompositeScorer:
        scorer = CompositeScorer()

        contexts = [
            KnowledgeContext(
                content=str(i), source="docs", relevance_score=i / 10
            )
            KnowledgeContext(content=str(i), source="docs", relevance_score=i / 10)
            for i in range(10)
        ]

@@ -595,12 +586,8 @@ class TestCompositeScorer:
        scorer = CompositeScorer()

        contexts = [
            KnowledgeContext(
                content="Low", source="docs", relevance_score=0.1
            ),
            KnowledgeContext(
                content="High", source="docs", relevance_score=0.9
            ),
            KnowledgeContext(content="Low", source="docs", relevance_score=0.1),
            KnowledgeContext(content="High", source="docs", relevance_score=0.9),
        ]

        ranked = await scorer.rank(contexts, "query", min_score=0.5)
@@ -625,7 +612,13 @@ class TestCompositeScorer:
        """
        import asyncio

        scorer = CompositeScorer()
        # Use scorer with recency_weight=0 to eliminate time-dependent variation
        # (recency scores change as time passes between calls)
        scorer = CompositeScorer(
            relevance_weight=0.5,
            recency_weight=0.0,  # Disable recency to get deterministic results
            priority_weight=0.5,
        )

        # Create a single context that will be scored multiple times concurrently
        context = KnowledgeContext(
@@ -639,11 +632,9 @@ class TestCompositeScorer:
        tasks = [scorer.score(context, "test query") for _ in range(num_concurrent)]
        scores = await asyncio.gather(*tasks)

        # All scores should be identical (the same context scored the same way)
        # All scores should be identical (deterministic scoring without recency)
        assert all(s == scores[0] for s in scores)

        # The context should have its _score cached
        assert context._score is not None
        # Note: We don't cache _score on context because scores are query-dependent

    @pytest.mark.asyncio
    async def test_concurrent_scoring_different_contexts(self) -> None:
@@ -671,10 +662,7 @@ class TestCompositeScorer:

        # Each context should have a different score based on its relevance
        assert len(set(scores)) > 1  # Not all the same

        # All contexts should have cached scores
        for ctx in contexts:
            assert ctx._score is not None
        # Note: We don't cache _score on context because scores are query-dependent


class TestScoredContext:

@@ -1,20 +1,17 @@
"""Tests for context types."""

import json
from datetime import UTC, datetime, timedelta

import pytest

from app.services.context.types import (
    AssembledContext,
    BaseContext,
    ContextPriority,
    ContextType,
    ConversationContext,
    KnowledgeContext,
    MessageRole,
    SystemContext,
    TaskComplexity,
    TaskContext,
    TaskStatus,
    ToolContext,
@@ -181,24 +178,16 @@ class TestKnowledgeContext:

    def test_is_code(self) -> None:
        """Test is_code method."""
        code_ctx = KnowledgeContext(
            content="code", source="test", file_type="python"
        )
        doc_ctx = KnowledgeContext(
            content="docs", source="test", file_type="markdown"
        )
        code_ctx = KnowledgeContext(content="code", source="test", file_type="python")
        doc_ctx = KnowledgeContext(content="docs", source="test", file_type="markdown")

        assert code_ctx.is_code() is True
        assert doc_ctx.is_code() is False

    def test_is_documentation(self) -> None:
        """Test is_documentation method."""
        doc_ctx = KnowledgeContext(
            content="docs", source="test", file_type="markdown"
        )
        code_ctx = KnowledgeContext(
            content="code", source="test", file_type="python"
        )
        doc_ctx = KnowledgeContext(content="docs", source="test", file_type="markdown")
        code_ctx = KnowledgeContext(content="code", source="test", file_type="python")

        assert doc_ctx.is_documentation() is True
        assert code_ctx.is_documentation() is False
@@ -333,15 +322,11 @@ class TestTaskContext:

    def test_status_checks(self) -> None:
        """Test status check methods."""
        pending = TaskContext(
            content="test", source="test", status=TaskStatus.PENDING
        )
        pending = TaskContext(content="test", source="test", status=TaskStatus.PENDING)
        completed = TaskContext(
            content="test", source="test", status=TaskStatus.COMPLETED
        )
        blocked = TaskContext(
            content="test", source="test", status=TaskStatus.BLOCKED
        )
        blocked = TaskContext(content="test", source="test", status=TaskStatus.BLOCKED)

        assert pending.is_active() is True
        assert completed.is_complete() is True
@@ -395,12 +380,8 @@ class TestToolContext:

    def test_is_successful(self) -> None:
        """Test is_successful method."""
        success = ToolContext.from_tool_result(
            "test", "ok", ToolResultStatus.SUCCESS
        )
        error = ToolContext.from_tool_result(
            "test", "error", ToolResultStatus.ERROR
        )
        success = ToolContext.from_tool_result("test", "ok", ToolResultStatus.SUCCESS)
        error = ToolContext.from_tool_result("test", "error", ToolResultStatus.ERROR)

        assert success.is_successful() is True
        assert error.is_successful() is False
@@ -510,9 +491,7 @@ class TestBaseContextMethods:
    def test_get_age_seconds(self) -> None:
        """Test get_age_seconds method."""
        old_time = datetime.now(UTC) - timedelta(hours=2)
        ctx = SystemContext(
            content="test", source="test", timestamp=old_time
        )
        ctx = SystemContext(content="test", source="test", timestamp=old_time)

        age = ctx.get_age_seconds()
        # Should be approximately 2 hours in seconds
@@ -521,9 +500,7 @@ class TestBaseContextMethods:
    def test_get_age_hours(self) -> None:
        """Test get_age_hours method."""
        old_time = datetime.now(UTC) - timedelta(hours=5)
        ctx = SystemContext(
            content="test", source="test", timestamp=old_time
        )
        ctx = SystemContext(content="test", source="test", timestamp=old_time)

        age = ctx.get_age_hours()
        assert 4.9 < age < 5.1
@@ -533,12 +510,8 @@ class TestBaseContextMethods:
        old_time = datetime.now(UTC) - timedelta(days=10)
        new_time = datetime.now(UTC) - timedelta(hours=1)

        old_ctx = SystemContext(
            content="test", source="test", timestamp=old_time
        )
        new_ctx = SystemContext(
            content="test", source="test", timestamp=new_time
        )
        old_ctx = SystemContext(content="test", source="test", timestamp=old_time)
        new_ctx = SystemContext(content="test", source="test", timestamp=new_time)

        # Default max_age is 168 hours (7 days)
        assert old_ctx.is_stale() is True