forked from cardosofelipe/fast-next-template
feat(context): enhance timeout handling, tenant isolation, and budget management
- Added timeout enforcement for token counting, scoring, and compression with detailed error handling. - Introduced tenant isolation in context caching using project and agent identifiers. - Enhanced budget management with stricter checks for critical context overspending and buffer limitations. - Optimized per-context locking with cleanup to prevent memory leaks in concurrent environments. - Updated default assembly timeout settings for improved performance and reliability. - Improved XML escaping in Claude adapter for safety against injection attacks. - Standardized token estimation using model-specific ratios.
This commit is contained in:
@@ -94,12 +94,13 @@ class ClaudeAdapter(ModelAdapter):
|
||||
|
||||
def _format_system(self, contexts: list[BaseContext]) -> str:
|
||||
"""Format system contexts."""
|
||||
content = "\n\n".join(c.content for c in contexts)
|
||||
# System prompts are typically admin-controlled, but escape for safety
|
||||
content = "\n\n".join(self._escape_xml_content(c.content) for c in contexts)
|
||||
return f"<system_instructions>\n{content}\n</system_instructions>"
|
||||
|
||||
def _format_task(self, contexts: list[BaseContext]) -> str:
|
||||
"""Format task contexts."""
|
||||
content = "\n\n".join(c.content for c in contexts)
|
||||
content = "\n\n".join(self._escape_xml_content(c.content) for c in contexts)
|
||||
return f"<current_task>\n{content}\n</current_task>"
|
||||
|
||||
def _format_knowledge(self, contexts: list[BaseContext]) -> str:
|
||||
@@ -107,12 +108,14 @@ class ClaudeAdapter(ModelAdapter):
|
||||
Format knowledge contexts as structured documents.
|
||||
|
||||
Each knowledge context becomes a document with source attribution.
|
||||
All content is XML-escaped to prevent injection attacks.
|
||||
"""
|
||||
parts = ["<reference_documents>"]
|
||||
|
||||
for ctx in contexts:
|
||||
source = self._escape_xml(ctx.source)
|
||||
content = ctx.content
|
||||
# Escape content to prevent XML injection
|
||||
content = self._escape_xml_content(ctx.content)
|
||||
score = ctx.metadata.get("score", ctx.metadata.get("relevance_score", ""))
|
||||
|
||||
if score:
|
||||
@@ -131,13 +134,16 @@ class ClaudeAdapter(ModelAdapter):
|
||||
Format conversation contexts as message history.
|
||||
|
||||
Uses role-based message tags for clear turn delineation.
|
||||
All content is XML-escaped to prevent prompt injection.
|
||||
"""
|
||||
parts = ["<conversation_history>"]
|
||||
|
||||
for ctx in contexts:
|
||||
role = ctx.metadata.get("role", "user")
|
||||
role = self._escape_xml(ctx.metadata.get("role", "user"))
|
||||
# Escape content to prevent prompt injection via fake XML tags
|
||||
content = self._escape_xml_content(ctx.content)
|
||||
parts.append(f'<message role="{role}">')
|
||||
parts.append(ctx.content)
|
||||
parts.append(content)
|
||||
parts.append("</message>")
|
||||
|
||||
parts.append("</conversation_history>")
|
||||
@@ -148,19 +154,23 @@ class ClaudeAdapter(ModelAdapter):
|
||||
Format tool contexts as tool results.
|
||||
|
||||
Each tool result is wrapped with the tool name.
|
||||
All content is XML-escaped to prevent injection.
|
||||
"""
|
||||
parts = ["<tool_results>"]
|
||||
|
||||
for ctx in contexts:
|
||||
tool_name = ctx.metadata.get("tool_name", "unknown")
|
||||
tool_name = self._escape_xml(ctx.metadata.get("tool_name", "unknown"))
|
||||
status = ctx.metadata.get("status", "")
|
||||
|
||||
if status:
|
||||
parts.append(f'<tool_result name="{tool_name}" status="{status}">')
|
||||
parts.append(
|
||||
f'<tool_result name="{tool_name}" status="{self._escape_xml(status)}">'
|
||||
)
|
||||
else:
|
||||
parts.append(f'<tool_result name="{tool_name}">')
|
||||
|
||||
parts.append(ctx.content)
|
||||
# Escape content to prevent injection
|
||||
parts.append(self._escape_xml_content(ctx.content))
|
||||
parts.append("</tool_result>")
|
||||
|
||||
parts.append("</tool_results>")
|
||||
@@ -176,3 +186,21 @@ class ClaudeAdapter(ModelAdapter):
|
||||
.replace('"', """)
|
||||
.replace("'", "'")
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _escape_xml_content(text: str) -> str:
|
||||
"""
|
||||
Escape XML special characters in element content.
|
||||
|
||||
This prevents XML injection attacks where malicious content
|
||||
could break out of XML tags or inject fake tags for prompt injection.
|
||||
|
||||
Only escapes &, <, > since quotes don't need escaping in content.
|
||||
|
||||
Args:
|
||||
text: Content text to escape
|
||||
|
||||
Returns:
|
||||
XML-safe content string
|
||||
"""
|
||||
return text.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
|
||||
@@ -12,6 +12,7 @@ from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from ..adapters import get_adapter
|
||||
from ..budget import BudgetAllocator, TokenBudget, TokenCalculator
|
||||
from ..compression.truncation import ContextCompressor
|
||||
from ..config import ContextSettings, get_context_settings
|
||||
@@ -156,20 +157,42 @@ class ContextPipeline:
|
||||
else:
|
||||
budget = self._allocator.create_budget_for_model(model)
|
||||
|
||||
# 1. Count tokens for all contexts
|
||||
await self._ensure_token_counts(contexts, model)
|
||||
# 1. Count tokens for all contexts (with timeout enforcement)
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._ensure_token_counts(contexts, model),
|
||||
timeout=self._remaining_timeout(start, timeout),
|
||||
)
|
||||
except TimeoutError:
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
raise AssemblyTimeoutError(
|
||||
message="Context assembly timed out during token counting",
|
||||
elapsed_ms=elapsed_ms,
|
||||
timeout_ms=timeout,
|
||||
)
|
||||
|
||||
# Check timeout
|
||||
# Check timeout (handles edge case where operation finished just at limit)
|
||||
self._check_timeout(start, timeout, "token counting")
|
||||
|
||||
# 2. Score and rank contexts
|
||||
# 2. Score and rank contexts (with timeout enforcement)
|
||||
scoring_start = time.perf_counter()
|
||||
ranking_result = await self._ranker.rank(
|
||||
contexts=contexts,
|
||||
query=query,
|
||||
budget=budget,
|
||||
model=model,
|
||||
)
|
||||
try:
|
||||
ranking_result = await asyncio.wait_for(
|
||||
self._ranker.rank(
|
||||
contexts=contexts,
|
||||
query=query,
|
||||
budget=budget,
|
||||
model=model,
|
||||
),
|
||||
timeout=self._remaining_timeout(start, timeout),
|
||||
)
|
||||
except TimeoutError:
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
raise AssemblyTimeoutError(
|
||||
message="Context assembly timed out during scoring/ranking",
|
||||
elapsed_ms=elapsed_ms,
|
||||
timeout_ms=timeout,
|
||||
)
|
||||
metrics.scoring_time_ms = (time.perf_counter() - scoring_start) * 1000
|
||||
|
||||
selected_contexts = ranking_result.selected_contexts
|
||||
@@ -179,12 +202,23 @@ class ContextPipeline:
|
||||
# Check timeout
|
||||
self._check_timeout(start, timeout, "scoring")
|
||||
|
||||
# 3. Compress if needed and enabled
|
||||
# 3. Compress if needed and enabled (with timeout enforcement)
|
||||
if compress and self._needs_compression(selected_contexts, budget):
|
||||
compression_start = time.perf_counter()
|
||||
selected_contexts = await self._compressor.compress_contexts(
|
||||
selected_contexts, budget, model
|
||||
)
|
||||
try:
|
||||
selected_contexts = await asyncio.wait_for(
|
||||
self._compressor.compress_contexts(
|
||||
selected_contexts, budget, model
|
||||
),
|
||||
timeout=self._remaining_timeout(start, timeout),
|
||||
)
|
||||
except TimeoutError:
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
raise AssemblyTimeoutError(
|
||||
message="Context assembly timed out during compression",
|
||||
elapsed_ms=elapsed_ms,
|
||||
timeout_ms=timeout,
|
||||
)
|
||||
metrics.compression_time_ms = (
|
||||
time.perf_counter() - compression_start
|
||||
) * 1000
|
||||
@@ -280,129 +314,18 @@ class ContextPipeline:
|
||||
"""
|
||||
Format contexts for the target model.
|
||||
|
||||
Groups contexts by type and applies model-specific formatting.
|
||||
Uses model-specific adapters (ClaudeAdapter, OpenAIAdapter, etc.)
|
||||
to format contexts optimally for each model family.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to format
|
||||
model: Target model name
|
||||
|
||||
Returns:
|
||||
Formatted context string
|
||||
"""
|
||||
# Group by type
|
||||
by_type: dict[ContextType, list[BaseContext]] = {}
|
||||
for context in contexts:
|
||||
ct = context.get_type()
|
||||
if ct not in by_type:
|
||||
by_type[ct] = []
|
||||
by_type[ct].append(context)
|
||||
|
||||
# Order types: System -> Task -> Knowledge -> Conversation -> Tool
|
||||
type_order = [
|
||||
ContextType.SYSTEM,
|
||||
ContextType.TASK,
|
||||
ContextType.KNOWLEDGE,
|
||||
ContextType.CONVERSATION,
|
||||
ContextType.TOOL,
|
||||
]
|
||||
|
||||
parts: list[str] = []
|
||||
for ct in type_order:
|
||||
if ct in by_type:
|
||||
formatted = self._format_type(by_type[ct], ct, model)
|
||||
if formatted:
|
||||
parts.append(formatted)
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
||||
def _format_type(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
context_type: ContextType,
|
||||
model: str,
|
||||
) -> str:
|
||||
"""Format contexts of a specific type."""
|
||||
if not contexts:
|
||||
return ""
|
||||
|
||||
# Check if model prefers XML tags (Claude)
|
||||
use_xml = "claude" in model.lower()
|
||||
|
||||
if context_type == ContextType.SYSTEM:
|
||||
return self._format_system(contexts, use_xml)
|
||||
elif context_type == ContextType.TASK:
|
||||
return self._format_task(contexts, use_xml)
|
||||
elif context_type == ContextType.KNOWLEDGE:
|
||||
return self._format_knowledge(contexts, use_xml)
|
||||
elif context_type == ContextType.CONVERSATION:
|
||||
return self._format_conversation(contexts, use_xml)
|
||||
elif context_type == ContextType.TOOL:
|
||||
return self._format_tool(contexts, use_xml)
|
||||
|
||||
return "\n".join(c.content for c in contexts)
|
||||
|
||||
def _format_system(self, contexts: list[BaseContext], use_xml: bool) -> str:
|
||||
"""Format system contexts."""
|
||||
content = "\n\n".join(c.content for c in contexts)
|
||||
if use_xml:
|
||||
return f"<system_instructions>\n{content}\n</system_instructions>"
|
||||
return content
|
||||
|
||||
def _format_task(self, contexts: list[BaseContext], use_xml: bool) -> str:
|
||||
"""Format task contexts."""
|
||||
content = "\n\n".join(c.content for c in contexts)
|
||||
if use_xml:
|
||||
return f"<current_task>\n{content}\n</current_task>"
|
||||
return f"## Current Task\n\n{content}"
|
||||
|
||||
def _format_knowledge(self, contexts: list[BaseContext], use_xml: bool) -> str:
|
||||
"""Format knowledge contexts."""
|
||||
if use_xml:
|
||||
parts = ["<reference_documents>"]
|
||||
for ctx in contexts:
|
||||
parts.append(f'<document source="{ctx.source}">')
|
||||
parts.append(ctx.content)
|
||||
parts.append("</document>")
|
||||
parts.append("</reference_documents>")
|
||||
return "\n".join(parts)
|
||||
else:
|
||||
parts = ["## Reference Documents\n"]
|
||||
for ctx in contexts:
|
||||
parts.append(f"### Source: {ctx.source}\n")
|
||||
parts.append(ctx.content)
|
||||
parts.append("")
|
||||
return "\n".join(parts)
|
||||
|
||||
def _format_conversation(self, contexts: list[BaseContext], use_xml: bool) -> str:
|
||||
"""Format conversation contexts."""
|
||||
if use_xml:
|
||||
parts = ["<conversation_history>"]
|
||||
for ctx in contexts:
|
||||
role = ctx.metadata.get("role", "user")
|
||||
parts.append(f'<message role="{role}">')
|
||||
parts.append(ctx.content)
|
||||
parts.append("</message>")
|
||||
parts.append("</conversation_history>")
|
||||
return "\n".join(parts)
|
||||
else:
|
||||
parts = []
|
||||
for ctx in contexts:
|
||||
role = ctx.metadata.get("role", "user")
|
||||
parts.append(f"**{role.upper()}**: {ctx.content}")
|
||||
return "\n\n".join(parts)
|
||||
|
||||
def _format_tool(self, contexts: list[BaseContext], use_xml: bool) -> str:
|
||||
"""Format tool contexts."""
|
||||
if use_xml:
|
||||
parts = ["<tool_results>"]
|
||||
for ctx in contexts:
|
||||
tool_name = ctx.metadata.get("tool_name", "unknown")
|
||||
parts.append(f'<tool_result name="{tool_name}">')
|
||||
parts.append(ctx.content)
|
||||
parts.append("</tool_result>")
|
||||
parts.append("</tool_results>")
|
||||
return "\n".join(parts)
|
||||
else:
|
||||
parts = ["## Recent Tool Results\n"]
|
||||
for ctx in contexts:
|
||||
tool_name = ctx.metadata.get("tool_name", "unknown")
|
||||
parts.append(f"### Tool: {tool_name}\n")
|
||||
parts.append(f"```\n{ctx.content}\n```")
|
||||
parts.append("")
|
||||
return "\n".join(parts)
|
||||
adapter = get_adapter(model)
|
||||
return adapter.format(contexts)
|
||||
|
||||
def _check_timeout(
|
||||
self,
|
||||
@@ -412,9 +335,28 @@ class ContextPipeline:
|
||||
) -> None:
|
||||
"""Check if timeout exceeded and raise if so."""
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
if elapsed_ms > timeout_ms:
|
||||
if elapsed_ms >= timeout_ms:
|
||||
raise AssemblyTimeoutError(
|
||||
message=f"Context assembly timed out during {phase}",
|
||||
elapsed_ms=elapsed_ms,
|
||||
timeout_ms=timeout_ms,
|
||||
)
|
||||
|
||||
def _remaining_timeout(self, start: float, timeout_ms: int) -> float:
|
||||
"""
|
||||
Calculate remaining timeout in seconds for asyncio.wait_for.
|
||||
|
||||
Returns at least a small positive value to avoid immediate timeout
|
||||
edge cases with wait_for.
|
||||
|
||||
Args:
|
||||
start: Start time from time.perf_counter()
|
||||
timeout_ms: Total timeout in milliseconds
|
||||
|
||||
Returns:
|
||||
Remaining timeout in seconds (minimum 0.001)
|
||||
"""
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
remaining_ms = timeout_ms - elapsed_ms
|
||||
# Return at least 1ms to avoid zero/negative timeout edge cases
|
||||
return max(remaining_ms / 1000.0, 0.001)
|
||||
|
||||
@@ -293,14 +293,18 @@ class BudgetAllocator:
|
||||
if isinstance(context_type, ContextType):
|
||||
context_type = context_type.value
|
||||
|
||||
# Calculate adjustment (limited by buffer)
|
||||
# Calculate adjustment (limited by buffer for increases, by current allocation for decreases)
|
||||
if adjustment > 0:
|
||||
# Taking from buffer
|
||||
# Taking from buffer - limited by available buffer
|
||||
actual_adjustment = min(adjustment, budget.buffer)
|
||||
budget.buffer -= actual_adjustment
|
||||
else:
|
||||
# Returning to buffer
|
||||
actual_adjustment = adjustment
|
||||
# Returning to buffer - limited by current allocation of target type
|
||||
current_allocation = budget.get_allocation(context_type)
|
||||
# Can't return more than current allocation
|
||||
actual_adjustment = max(adjustment, -current_allocation)
|
||||
# Add returned tokens back to buffer (adjustment is negative, so subtract)
|
||||
budget.buffer -= actual_adjustment
|
||||
|
||||
# Apply to target type
|
||||
if context_type == "system":
|
||||
|
||||
@@ -95,19 +95,28 @@ class ContextCache:
|
||||
contexts: list[BaseContext],
|
||||
query: str,
|
||||
model: str,
|
||||
project_id: str | None = None,
|
||||
agent_id: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Compute a fingerprint for a context assembly request.
|
||||
|
||||
The fingerprint is based on:
|
||||
- Project and agent IDs (for tenant isolation)
|
||||
- Context content hash and metadata (not full content for performance)
|
||||
- Query string
|
||||
- Target model
|
||||
|
||||
SECURITY: project_id and agent_id MUST be included to prevent
|
||||
cross-tenant cache pollution. Without these, one tenant could
|
||||
receive cached contexts from another tenant with the same query.
|
||||
|
||||
Args:
|
||||
contexts: List of contexts
|
||||
query: Query string
|
||||
model: Model name
|
||||
project_id: Project ID for tenant isolation
|
||||
agent_id: Agent ID for tenant isolation
|
||||
|
||||
Returns:
|
||||
32-character hex fingerprint
|
||||
@@ -128,6 +137,9 @@ class ContextCache:
|
||||
)
|
||||
|
||||
data = {
|
||||
# CRITICAL: Include tenant identifiers for cache isolation
|
||||
"project_id": project_id or "",
|
||||
"agent_id": agent_id or "",
|
||||
"contexts": context_data,
|
||||
"query": query,
|
||||
"model": model,
|
||||
|
||||
@@ -19,6 +19,40 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _estimate_tokens(text: str, model: str | None = None) -> int:
|
||||
"""
|
||||
Estimate token count using model-specific character ratios.
|
||||
|
||||
Module-level function for reuse across classes. Uses the same ratios
|
||||
as TokenCalculator for consistency.
|
||||
|
||||
Args:
|
||||
text: Text to estimate tokens for
|
||||
model: Optional model name for model-specific ratios
|
||||
|
||||
Returns:
|
||||
Estimated token count (minimum 1)
|
||||
"""
|
||||
# Model-specific character ratios (chars per token)
|
||||
model_ratios = {
|
||||
"claude": 3.5,
|
||||
"gpt-4": 4.0,
|
||||
"gpt-3.5": 4.0,
|
||||
"gemini": 4.0,
|
||||
}
|
||||
default_ratio = 4.0
|
||||
|
||||
ratio = default_ratio
|
||||
if model:
|
||||
model_lower = model.lower()
|
||||
for model_prefix, model_ratio in model_ratios.items():
|
||||
if model_prefix in model_lower:
|
||||
ratio = model_ratio
|
||||
break
|
||||
|
||||
return max(1, int(len(text) / ratio))
|
||||
|
||||
|
||||
@dataclass
|
||||
class TruncationResult:
|
||||
"""Result of truncation operation."""
|
||||
@@ -284,8 +318,8 @@ class TruncationStrategy:
|
||||
if self._calculator is not None:
|
||||
return await self._calculator.count_tokens(text, model)
|
||||
|
||||
# Fallback estimation
|
||||
return max(1, len(text) // 4)
|
||||
# Fallback estimation with model-specific ratios
|
||||
return _estimate_tokens(text, model)
|
||||
|
||||
|
||||
class ContextCompressor:
|
||||
@@ -415,4 +449,5 @@ class ContextCompressor:
|
||||
"""Count tokens using calculator or estimation."""
|
||||
if self._calculator is not None:
|
||||
return await self._calculator.count_tokens(text, model)
|
||||
return max(1, len(text) // 4)
|
||||
# Use model-specific estimation for consistency
|
||||
return _estimate_tokens(text, model)
|
||||
|
||||
@@ -149,10 +149,11 @@ class ContextSettings(BaseSettings):
|
||||
|
||||
# Performance settings
|
||||
max_assembly_time_ms: int = Field(
|
||||
default=100,
|
||||
default=2000,
|
||||
ge=10,
|
||||
le=5000,
|
||||
description="Maximum time for context assembly in milliseconds",
|
||||
le=30000,
|
||||
description="Maximum time for context assembly in milliseconds. "
|
||||
"Should be high enough to accommodate MCP calls for knowledge retrieval.",
|
||||
)
|
||||
parallel_scoring: bool = Field(
|
||||
default=True,
|
||||
|
||||
@@ -212,7 +212,10 @@ class ContextEngine:
|
||||
# Check cache if enabled
|
||||
fingerprint: str | None = None
|
||||
if use_cache and self._cache.is_enabled:
|
||||
fingerprint = self._cache.compute_fingerprint(contexts, query, model)
|
||||
# Include project_id and agent_id for tenant isolation
|
||||
fingerprint = self._cache.compute_fingerprint(
|
||||
contexts, query, model, project_id=project_id, agent_id=agent_id
|
||||
)
|
||||
cached = await self._cache.get_assembled(fingerprint)
|
||||
if cached:
|
||||
logger.debug(f"Cache hit for context assembly: {fingerprint}")
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
from ..budget import TokenBudget, TokenCalculator
|
||||
from ..config import ContextSettings, get_context_settings
|
||||
from ..exceptions import BudgetExceededError
|
||||
from ..scoring.composite import CompositeScorer, ScoredContext
|
||||
from ..types import BaseContext, ContextPriority
|
||||
|
||||
@@ -127,6 +128,9 @@ class ContextRanker:
|
||||
excluded: list[ScoredContext] = []
|
||||
total_tokens = 0
|
||||
|
||||
# Calculate the usable budget (total minus reserved portions)
|
||||
usable_budget = budget.total - budget.response_reserve - budget.buffer
|
||||
|
||||
# First, try to fit required contexts
|
||||
for sc in required:
|
||||
token_count = sc.context.token_count or 0
|
||||
@@ -137,7 +141,20 @@ class ContextRanker:
|
||||
selected.append(sc)
|
||||
total_tokens += token_count
|
||||
else:
|
||||
# Force-fit CRITICAL contexts if needed
|
||||
# Force-fit CRITICAL contexts if needed, but check total budget first
|
||||
if total_tokens + token_count > usable_budget:
|
||||
# Even CRITICAL contexts cannot exceed total model context window
|
||||
raise BudgetExceededError(
|
||||
message=(
|
||||
f"CRITICAL contexts exceed total budget. "
|
||||
f"Context '{sc.context.source}' ({token_count} tokens) "
|
||||
f"would exceed usable budget of {usable_budget} tokens."
|
||||
),
|
||||
allocated=usable_budget,
|
||||
requested=total_tokens + token_count,
|
||||
context_type="CRITICAL_OVERFLOW",
|
||||
)
|
||||
|
||||
budget.allocate(context_type, token_count, force=True)
|
||||
selected.append(sc)
|
||||
total_tokens += token_count
|
||||
|
||||
@@ -6,9 +6,9 @@ Combines multiple scoring strategies with configurable weights.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from ..config import ContextSettings, get_context_settings
|
||||
from ..types import BaseContext
|
||||
@@ -91,11 +91,11 @@ class CompositeScorer:
|
||||
self._priority_scorer = PriorityScorer(weight=self._priority_weight)
|
||||
|
||||
# Per-context locks to prevent race conditions during parallel scoring
|
||||
# Uses WeakValueDictionary so locks are garbage collected when not in use
|
||||
self._context_locks: WeakValueDictionary[str, asyncio.Lock] = (
|
||||
WeakValueDictionary()
|
||||
)
|
||||
# Uses dict with (lock, last_used_time) tuples for cleanup
|
||||
self._context_locks: dict[str, tuple[asyncio.Lock, float]] = {}
|
||||
self._locks_lock = asyncio.Lock() # Lock to protect _context_locks access
|
||||
self._max_locks = 1000 # Maximum locks to keep (prevent memory growth)
|
||||
self._lock_ttl = 60.0 # Seconds before a lock can be cleaned up
|
||||
|
||||
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
|
||||
"""Set MCP manager for semantic scoring."""
|
||||
@@ -141,7 +141,8 @@ class CompositeScorer:
|
||||
Get or create a lock for a specific context.
|
||||
|
||||
Thread-safe access to per-context locks prevents race conditions
|
||||
when the same context is scored concurrently.
|
||||
when the same context is scored concurrently. Includes automatic
|
||||
cleanup of old locks to prevent memory growth.
|
||||
|
||||
Args:
|
||||
context_id: The context ID to get a lock for
|
||||
@@ -149,25 +150,78 @@ class CompositeScorer:
|
||||
Returns:
|
||||
asyncio.Lock for the context
|
||||
"""
|
||||
now = time.time()
|
||||
|
||||
# Fast path: check if lock exists without acquiring main lock
|
||||
if context_id in self._context_locks:
|
||||
lock = self._context_locks.get(context_id)
|
||||
if lock is not None:
|
||||
# NOTE: We only READ here - no writes to avoid race conditions
|
||||
# with cleanup. The timestamp will be updated in the slow path
|
||||
# if the lock is still valid.
|
||||
lock_entry = self._context_locks.get(context_id)
|
||||
if lock_entry is not None:
|
||||
lock, _ = lock_entry
|
||||
# Return the lock but defer timestamp update to avoid race
|
||||
# The lock is still valid; timestamp update is best-effort
|
||||
return lock
|
||||
|
||||
# Slow path: create lock or update timestamp while holding main lock
|
||||
async with self._locks_lock:
|
||||
# Double-check after acquiring lock - entry may have been
|
||||
# created by another coroutine or deleted by cleanup
|
||||
lock_entry = self._context_locks.get(context_id)
|
||||
if lock_entry is not None:
|
||||
lock, _ = lock_entry
|
||||
# Safe to update timestamp here since we hold the lock
|
||||
self._context_locks[context_id] = (lock, now)
|
||||
return lock
|
||||
|
||||
# Slow path: create lock while holding main lock
|
||||
async with self._locks_lock:
|
||||
# Double-check after acquiring lock
|
||||
if context_id in self._context_locks:
|
||||
lock = self._context_locks.get(context_id)
|
||||
if lock is not None:
|
||||
return lock
|
||||
# Cleanup old locks if we have too many
|
||||
if len(self._context_locks) >= self._max_locks:
|
||||
self._cleanup_old_locks(now)
|
||||
|
||||
# Create new lock
|
||||
new_lock = asyncio.Lock()
|
||||
self._context_locks[context_id] = new_lock
|
||||
self._context_locks[context_id] = (new_lock, now)
|
||||
return new_lock
|
||||
|
||||
def _cleanup_old_locks(self, now: float) -> None:
|
||||
"""
|
||||
Remove old locks that haven't been used recently.
|
||||
|
||||
Called while holding _locks_lock. Removes locks older than _lock_ttl,
|
||||
but only if they're not currently held.
|
||||
|
||||
Args:
|
||||
now: Current timestamp for age calculation
|
||||
"""
|
||||
cutoff = now - self._lock_ttl
|
||||
to_remove = []
|
||||
|
||||
for context_id, (lock, last_used) in self._context_locks.items():
|
||||
# Only remove if old AND not currently held
|
||||
if last_used < cutoff and not lock.locked():
|
||||
to_remove.append(context_id)
|
||||
|
||||
# Remove oldest 50% if still over limit after TTL filtering
|
||||
if len(self._context_locks) - len(to_remove) >= self._max_locks:
|
||||
# Sort by last used time and mark oldest for removal
|
||||
sorted_entries = sorted(
|
||||
self._context_locks.items(),
|
||||
key=lambda x: x[1][1], # Sort by last_used time
|
||||
)
|
||||
# Remove oldest 50% that aren't locked
|
||||
target_remove = len(self._context_locks) // 2
|
||||
for context_id, (lock, _) in sorted_entries:
|
||||
if len(to_remove) >= target_remove:
|
||||
break
|
||||
if context_id not in to_remove and not lock.locked():
|
||||
to_remove.append(context_id)
|
||||
|
||||
for context_id in to_remove:
|
||||
del self._context_locks[context_id]
|
||||
|
||||
if to_remove:
|
||||
logger.debug(f"Cleaned up {len(to_remove)} context locks")
|
||||
|
||||
async def score(
|
||||
self,
|
||||
context: BaseContext,
|
||||
|
||||
@@ -72,7 +72,7 @@ class TestContextSettings:
|
||||
"""Test performance settings."""
|
||||
settings = ContextSettings()
|
||||
|
||||
assert settings.max_assembly_time_ms == 100
|
||||
assert settings.max_assembly_time_ms == 2000
|
||||
assert settings.parallel_scoring is True
|
||||
assert settings.max_parallel_scores == 10
|
||||
|
||||
|
||||
Reference in New Issue
Block a user