feat(context): enhance timeout handling, tenant isolation, and budget management

- Added timeout enforcement for token counting, scoring, and compression with detailed error handling.
- Introduced tenant isolation in context caching using project and agent identifiers.
- Enhanced budget management with stricter checks for critical context overspending and buffer limitations.
- Optimized per-context locking with cleanup to prevent memory leaks in concurrent environments.
- Updated default assembly timeout settings for improved performance and reliability.
- Improved XML escaping in Claude adapter for safety against injection attacks.
- Standardized token estimation using model-specific ratios.
This commit is contained in:
2026-01-04 15:52:50 +01:00
parent 2bea057fb1
commit 1628eacf2b
10 changed files with 271 additions and 175 deletions

View File

@@ -94,12 +94,13 @@ class ClaudeAdapter(ModelAdapter):
def _format_system(self, contexts: list[BaseContext]) -> str: def _format_system(self, contexts: list[BaseContext]) -> str:
"""Format system contexts.""" """Format system contexts."""
content = "\n\n".join(c.content for c in contexts) # System prompts are typically admin-controlled, but escape for safety
content = "\n\n".join(self._escape_xml_content(c.content) for c in contexts)
return f"<system_instructions>\n{content}\n</system_instructions>" return f"<system_instructions>\n{content}\n</system_instructions>"
def _format_task(self, contexts: list[BaseContext]) -> str: def _format_task(self, contexts: list[BaseContext]) -> str:
"""Format task contexts.""" """Format task contexts."""
content = "\n\n".join(c.content for c in contexts) content = "\n\n".join(self._escape_xml_content(c.content) for c in contexts)
return f"<current_task>\n{content}\n</current_task>" return f"<current_task>\n{content}\n</current_task>"
def _format_knowledge(self, contexts: list[BaseContext]) -> str: def _format_knowledge(self, contexts: list[BaseContext]) -> str:
@@ -107,12 +108,14 @@ class ClaudeAdapter(ModelAdapter):
Format knowledge contexts as structured documents. Format knowledge contexts as structured documents.
Each knowledge context becomes a document with source attribution. Each knowledge context becomes a document with source attribution.
All content is XML-escaped to prevent injection attacks.
""" """
parts = ["<reference_documents>"] parts = ["<reference_documents>"]
for ctx in contexts: for ctx in contexts:
source = self._escape_xml(ctx.source) source = self._escape_xml(ctx.source)
content = ctx.content # Escape content to prevent XML injection
content = self._escape_xml_content(ctx.content)
score = ctx.metadata.get("score", ctx.metadata.get("relevance_score", "")) score = ctx.metadata.get("score", ctx.metadata.get("relevance_score", ""))
if score: if score:
@@ -131,13 +134,16 @@ class ClaudeAdapter(ModelAdapter):
Format conversation contexts as message history. Format conversation contexts as message history.
Uses role-based message tags for clear turn delineation. Uses role-based message tags for clear turn delineation.
All content is XML-escaped to prevent prompt injection.
""" """
parts = ["<conversation_history>"] parts = ["<conversation_history>"]
for ctx in contexts: for ctx in contexts:
role = ctx.metadata.get("role", "user") role = self._escape_xml(ctx.metadata.get("role", "user"))
# Escape content to prevent prompt injection via fake XML tags
content = self._escape_xml_content(ctx.content)
parts.append(f'<message role="{role}">') parts.append(f'<message role="{role}">')
parts.append(ctx.content) parts.append(content)
parts.append("</message>") parts.append("</message>")
parts.append("</conversation_history>") parts.append("</conversation_history>")
@@ -148,19 +154,23 @@ class ClaudeAdapter(ModelAdapter):
Format tool contexts as tool results. Format tool contexts as tool results.
Each tool result is wrapped with the tool name. Each tool result is wrapped with the tool name.
All content is XML-escaped to prevent injection.
""" """
parts = ["<tool_results>"] parts = ["<tool_results>"]
for ctx in contexts: for ctx in contexts:
tool_name = ctx.metadata.get("tool_name", "unknown") tool_name = self._escape_xml(ctx.metadata.get("tool_name", "unknown"))
status = ctx.metadata.get("status", "") status = ctx.metadata.get("status", "")
if status: if status:
parts.append(f'<tool_result name="{tool_name}" status="{status}">') parts.append(
f'<tool_result name="{tool_name}" status="{self._escape_xml(status)}">'
)
else: else:
parts.append(f'<tool_result name="{tool_name}">') parts.append(f'<tool_result name="{tool_name}">')
parts.append(ctx.content) # Escape content to prevent injection
parts.append(self._escape_xml_content(ctx.content))
parts.append("</tool_result>") parts.append("</tool_result>")
parts.append("</tool_results>") parts.append("</tool_results>")
@@ -176,3 +186,21 @@ class ClaudeAdapter(ModelAdapter):
.replace('"', "&quot;") .replace('"', "&quot;")
.replace("'", "&apos;") .replace("'", "&apos;")
) )
@staticmethod
def _escape_xml_content(text: str) -> str:
"""
Escape XML special characters in element content.
This prevents XML injection attacks where malicious content
could break out of XML tags or inject fake tags for prompt injection.
Only escapes &, <, > since quotes don't need escaping in content.
Args:
text: Content text to escape
Returns:
XML-safe content string
"""
return text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")

View File

@@ -12,6 +12,7 @@ from dataclasses import dataclass, field
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from ..adapters import get_adapter
from ..budget import BudgetAllocator, TokenBudget, TokenCalculator from ..budget import BudgetAllocator, TokenBudget, TokenCalculator
from ..compression.truncation import ContextCompressor from ..compression.truncation import ContextCompressor
from ..config import ContextSettings, get_context_settings from ..config import ContextSettings, get_context_settings
@@ -156,20 +157,42 @@ class ContextPipeline:
else: else:
budget = self._allocator.create_budget_for_model(model) budget = self._allocator.create_budget_for_model(model)
# 1. Count tokens for all contexts # 1. Count tokens for all contexts (with timeout enforcement)
await self._ensure_token_counts(contexts, model) try:
await asyncio.wait_for(
self._ensure_token_counts(contexts, model),
timeout=self._remaining_timeout(start, timeout),
)
except TimeoutError:
elapsed_ms = (time.perf_counter() - start) * 1000
raise AssemblyTimeoutError(
message="Context assembly timed out during token counting",
elapsed_ms=elapsed_ms,
timeout_ms=timeout,
)
# Check timeout # Check timeout (handles edge case where operation finished just at limit)
self._check_timeout(start, timeout, "token counting") self._check_timeout(start, timeout, "token counting")
# 2. Score and rank contexts # 2. Score and rank contexts (with timeout enforcement)
scoring_start = time.perf_counter() scoring_start = time.perf_counter()
ranking_result = await self._ranker.rank( try:
contexts=contexts, ranking_result = await asyncio.wait_for(
query=query, self._ranker.rank(
budget=budget, contexts=contexts,
model=model, query=query,
) budget=budget,
model=model,
),
timeout=self._remaining_timeout(start, timeout),
)
except TimeoutError:
elapsed_ms = (time.perf_counter() - start) * 1000
raise AssemblyTimeoutError(
message="Context assembly timed out during scoring/ranking",
elapsed_ms=elapsed_ms,
timeout_ms=timeout,
)
metrics.scoring_time_ms = (time.perf_counter() - scoring_start) * 1000 metrics.scoring_time_ms = (time.perf_counter() - scoring_start) * 1000
selected_contexts = ranking_result.selected_contexts selected_contexts = ranking_result.selected_contexts
@@ -179,12 +202,23 @@ class ContextPipeline:
# Check timeout # Check timeout
self._check_timeout(start, timeout, "scoring") self._check_timeout(start, timeout, "scoring")
# 3. Compress if needed and enabled # 3. Compress if needed and enabled (with timeout enforcement)
if compress and self._needs_compression(selected_contexts, budget): if compress and self._needs_compression(selected_contexts, budget):
compression_start = time.perf_counter() compression_start = time.perf_counter()
selected_contexts = await self._compressor.compress_contexts( try:
selected_contexts, budget, model selected_contexts = await asyncio.wait_for(
) self._compressor.compress_contexts(
selected_contexts, budget, model
),
timeout=self._remaining_timeout(start, timeout),
)
except TimeoutError:
elapsed_ms = (time.perf_counter() - start) * 1000
raise AssemblyTimeoutError(
message="Context assembly timed out during compression",
elapsed_ms=elapsed_ms,
timeout_ms=timeout,
)
metrics.compression_time_ms = ( metrics.compression_time_ms = (
time.perf_counter() - compression_start time.perf_counter() - compression_start
) * 1000 ) * 1000
@@ -280,129 +314,18 @@ class ContextPipeline:
""" """
Format contexts for the target model. Format contexts for the target model.
Groups contexts by type and applies model-specific formatting. Uses model-specific adapters (ClaudeAdapter, OpenAIAdapter, etc.)
to format contexts optimally for each model family.
Args:
contexts: Contexts to format
model: Target model name
Returns:
Formatted context string
""" """
# Group by type adapter = get_adapter(model)
by_type: dict[ContextType, list[BaseContext]] = {} return adapter.format(contexts)
for context in contexts:
ct = context.get_type()
if ct not in by_type:
by_type[ct] = []
by_type[ct].append(context)
# Order types: System -> Task -> Knowledge -> Conversation -> Tool
type_order = [
ContextType.SYSTEM,
ContextType.TASK,
ContextType.KNOWLEDGE,
ContextType.CONVERSATION,
ContextType.TOOL,
]
parts: list[str] = []
for ct in type_order:
if ct in by_type:
formatted = self._format_type(by_type[ct], ct, model)
if formatted:
parts.append(formatted)
return "\n\n".join(parts)
def _format_type(
self,
contexts: list[BaseContext],
context_type: ContextType,
model: str,
) -> str:
"""Format contexts of a specific type."""
if not contexts:
return ""
# Check if model prefers XML tags (Claude)
use_xml = "claude" in model.lower()
if context_type == ContextType.SYSTEM:
return self._format_system(contexts, use_xml)
elif context_type == ContextType.TASK:
return self._format_task(contexts, use_xml)
elif context_type == ContextType.KNOWLEDGE:
return self._format_knowledge(contexts, use_xml)
elif context_type == ContextType.CONVERSATION:
return self._format_conversation(contexts, use_xml)
elif context_type == ContextType.TOOL:
return self._format_tool(contexts, use_xml)
return "\n".join(c.content for c in contexts)
def _format_system(self, contexts: list[BaseContext], use_xml: bool) -> str:
"""Format system contexts."""
content = "\n\n".join(c.content for c in contexts)
if use_xml:
return f"<system_instructions>\n{content}\n</system_instructions>"
return content
def _format_task(self, contexts: list[BaseContext], use_xml: bool) -> str:
"""Format task contexts."""
content = "\n\n".join(c.content for c in contexts)
if use_xml:
return f"<current_task>\n{content}\n</current_task>"
return f"## Current Task\n\n{content}"
def _format_knowledge(self, contexts: list[BaseContext], use_xml: bool) -> str:
"""Format knowledge contexts."""
if use_xml:
parts = ["<reference_documents>"]
for ctx in contexts:
parts.append(f'<document source="{ctx.source}">')
parts.append(ctx.content)
parts.append("</document>")
parts.append("</reference_documents>")
return "\n".join(parts)
else:
parts = ["## Reference Documents\n"]
for ctx in contexts:
parts.append(f"### Source: {ctx.source}\n")
parts.append(ctx.content)
parts.append("")
return "\n".join(parts)
def _format_conversation(self, contexts: list[BaseContext], use_xml: bool) -> str:
"""Format conversation contexts."""
if use_xml:
parts = ["<conversation_history>"]
for ctx in contexts:
role = ctx.metadata.get("role", "user")
parts.append(f'<message role="{role}">')
parts.append(ctx.content)
parts.append("</message>")
parts.append("</conversation_history>")
return "\n".join(parts)
else:
parts = []
for ctx in contexts:
role = ctx.metadata.get("role", "user")
parts.append(f"**{role.upper()}**: {ctx.content}")
return "\n\n".join(parts)
def _format_tool(self, contexts: list[BaseContext], use_xml: bool) -> str:
"""Format tool contexts."""
if use_xml:
parts = ["<tool_results>"]
for ctx in contexts:
tool_name = ctx.metadata.get("tool_name", "unknown")
parts.append(f'<tool_result name="{tool_name}">')
parts.append(ctx.content)
parts.append("</tool_result>")
parts.append("</tool_results>")
return "\n".join(parts)
else:
parts = ["## Recent Tool Results\n"]
for ctx in contexts:
tool_name = ctx.metadata.get("tool_name", "unknown")
parts.append(f"### Tool: {tool_name}\n")
parts.append(f"```\n{ctx.content}\n```")
parts.append("")
return "\n".join(parts)
def _check_timeout( def _check_timeout(
self, self,
@@ -412,9 +335,28 @@ class ContextPipeline:
) -> None: ) -> None:
"""Check if timeout exceeded and raise if so.""" """Check if timeout exceeded and raise if so."""
elapsed_ms = (time.perf_counter() - start) * 1000 elapsed_ms = (time.perf_counter() - start) * 1000
if elapsed_ms > timeout_ms: if elapsed_ms >= timeout_ms:
raise AssemblyTimeoutError( raise AssemblyTimeoutError(
message=f"Context assembly timed out during {phase}", message=f"Context assembly timed out during {phase}",
elapsed_ms=elapsed_ms, elapsed_ms=elapsed_ms,
timeout_ms=timeout_ms, timeout_ms=timeout_ms,
) )
def _remaining_timeout(self, start: float, timeout_ms: int) -> float:
"""
Calculate remaining timeout in seconds for asyncio.wait_for.
Returns at least a small positive value to avoid immediate timeout
edge cases with wait_for.
Args:
start: Start time from time.perf_counter()
timeout_ms: Total timeout in milliseconds
Returns:
Remaining timeout in seconds (minimum 0.001)
"""
elapsed_ms = (time.perf_counter() - start) * 1000
remaining_ms = timeout_ms - elapsed_ms
# Return at least 1ms to avoid zero/negative timeout edge cases
return max(remaining_ms / 1000.0, 0.001)

View File

@@ -293,14 +293,18 @@ class BudgetAllocator:
if isinstance(context_type, ContextType): if isinstance(context_type, ContextType):
context_type = context_type.value context_type = context_type.value
# Calculate adjustment (limited by buffer) # Calculate adjustment (limited by buffer for increases, by current allocation for decreases)
if adjustment > 0: if adjustment > 0:
# Taking from buffer # Taking from buffer - limited by available buffer
actual_adjustment = min(adjustment, budget.buffer) actual_adjustment = min(adjustment, budget.buffer)
budget.buffer -= actual_adjustment budget.buffer -= actual_adjustment
else: else:
# Returning to buffer # Returning to buffer - limited by current allocation of target type
actual_adjustment = adjustment current_allocation = budget.get_allocation(context_type)
# Can't return more than current allocation
actual_adjustment = max(adjustment, -current_allocation)
# Add returned tokens back to buffer (adjustment is negative, so subtract)
budget.buffer -= actual_adjustment
# Apply to target type # Apply to target type
if context_type == "system": if context_type == "system":

View File

@@ -95,19 +95,28 @@ class ContextCache:
contexts: list[BaseContext], contexts: list[BaseContext],
query: str, query: str,
model: str, model: str,
project_id: str | None = None,
agent_id: str | None = None,
) -> str: ) -> str:
""" """
Compute a fingerprint for a context assembly request. Compute a fingerprint for a context assembly request.
The fingerprint is based on: The fingerprint is based on:
- Project and agent IDs (for tenant isolation)
- Context content hash and metadata (not full content for performance) - Context content hash and metadata (not full content for performance)
- Query string - Query string
- Target model - Target model
SECURITY: project_id and agent_id MUST be included to prevent
cross-tenant cache pollution. Without these, one tenant could
receive cached contexts from another tenant with the same query.
Args: Args:
contexts: List of contexts contexts: List of contexts
query: Query string query: Query string
model: Model name model: Model name
project_id: Project ID for tenant isolation
agent_id: Agent ID for tenant isolation
Returns: Returns:
32-character hex fingerprint 32-character hex fingerprint
@@ -128,6 +137,9 @@ class ContextCache:
) )
data = { data = {
# CRITICAL: Include tenant identifiers for cache isolation
"project_id": project_id or "",
"agent_id": agent_id or "",
"contexts": context_data, "contexts": context_data,
"query": query, "query": query,
"model": model, "model": model,

View File

@@ -19,6 +19,40 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _estimate_tokens(text: str, model: str | None = None) -> int:
"""
Estimate token count using model-specific character ratios.
Module-level function for reuse across classes. Uses the same ratios
as TokenCalculator for consistency.
Args:
text: Text to estimate tokens for
model: Optional model name for model-specific ratios
Returns:
Estimated token count (minimum 1)
"""
# Model-specific character ratios (chars per token)
model_ratios = {
"claude": 3.5,
"gpt-4": 4.0,
"gpt-3.5": 4.0,
"gemini": 4.0,
}
default_ratio = 4.0
ratio = default_ratio
if model:
model_lower = model.lower()
for model_prefix, model_ratio in model_ratios.items():
if model_prefix in model_lower:
ratio = model_ratio
break
return max(1, int(len(text) / ratio))
@dataclass @dataclass
class TruncationResult: class TruncationResult:
"""Result of truncation operation.""" """Result of truncation operation."""
@@ -284,8 +318,8 @@ class TruncationStrategy:
if self._calculator is not None: if self._calculator is not None:
return await self._calculator.count_tokens(text, model) return await self._calculator.count_tokens(text, model)
# Fallback estimation # Fallback estimation with model-specific ratios
return max(1, len(text) // 4) return _estimate_tokens(text, model)
class ContextCompressor: class ContextCompressor:
@@ -415,4 +449,5 @@ class ContextCompressor:
"""Count tokens using calculator or estimation.""" """Count tokens using calculator or estimation."""
if self._calculator is not None: if self._calculator is not None:
return await self._calculator.count_tokens(text, model) return await self._calculator.count_tokens(text, model)
return max(1, len(text) // 4) # Use model-specific estimation for consistency
return _estimate_tokens(text, model)

View File

@@ -149,10 +149,11 @@ class ContextSettings(BaseSettings):
# Performance settings # Performance settings
max_assembly_time_ms: int = Field( max_assembly_time_ms: int = Field(
default=100, default=2000,
ge=10, ge=10,
le=5000, le=30000,
description="Maximum time for context assembly in milliseconds", description="Maximum time for context assembly in milliseconds. "
"Should be high enough to accommodate MCP calls for knowledge retrieval.",
) )
parallel_scoring: bool = Field( parallel_scoring: bool = Field(
default=True, default=True,

View File

@@ -212,7 +212,10 @@ class ContextEngine:
# Check cache if enabled # Check cache if enabled
fingerprint: str | None = None fingerprint: str | None = None
if use_cache and self._cache.is_enabled: if use_cache and self._cache.is_enabled:
fingerprint = self._cache.compute_fingerprint(contexts, query, model) # Include project_id and agent_id for tenant isolation
fingerprint = self._cache.compute_fingerprint(
contexts, query, model, project_id=project_id, agent_id=agent_id
)
cached = await self._cache.get_assembled(fingerprint) cached = await self._cache.get_assembled(fingerprint)
if cached: if cached:
logger.debug(f"Cache hit for context assembly: {fingerprint}") logger.debug(f"Cache hit for context assembly: {fingerprint}")

View File

@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any
from ..budget import TokenBudget, TokenCalculator from ..budget import TokenBudget, TokenCalculator
from ..config import ContextSettings, get_context_settings from ..config import ContextSettings, get_context_settings
from ..exceptions import BudgetExceededError
from ..scoring.composite import CompositeScorer, ScoredContext from ..scoring.composite import CompositeScorer, ScoredContext
from ..types import BaseContext, ContextPriority from ..types import BaseContext, ContextPriority
@@ -127,6 +128,9 @@ class ContextRanker:
excluded: list[ScoredContext] = [] excluded: list[ScoredContext] = []
total_tokens = 0 total_tokens = 0
# Calculate the usable budget (total minus reserved portions)
usable_budget = budget.total - budget.response_reserve - budget.buffer
# First, try to fit required contexts # First, try to fit required contexts
for sc in required: for sc in required:
token_count = sc.context.token_count or 0 token_count = sc.context.token_count or 0
@@ -137,7 +141,20 @@ class ContextRanker:
selected.append(sc) selected.append(sc)
total_tokens += token_count total_tokens += token_count
else: else:
# Force-fit CRITICAL contexts if needed # Force-fit CRITICAL contexts if needed, but check total budget first
if total_tokens + token_count > usable_budget:
# Even CRITICAL contexts cannot exceed total model context window
raise BudgetExceededError(
message=(
f"CRITICAL contexts exceed total budget. "
f"Context '{sc.context.source}' ({token_count} tokens) "
f"would exceed usable budget of {usable_budget} tokens."
),
allocated=usable_budget,
requested=total_tokens + token_count,
context_type="CRITICAL_OVERFLOW",
)
budget.allocate(context_type, token_count, force=True) budget.allocate(context_type, token_count, force=True)
selected.append(sc) selected.append(sc)
total_tokens += token_count total_tokens += token_count

View File

@@ -6,9 +6,9 @@ Combines multiple scoring strategies with configurable weights.
import asyncio import asyncio
import logging import logging
import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from weakref import WeakValueDictionary
from ..config import ContextSettings, get_context_settings from ..config import ContextSettings, get_context_settings
from ..types import BaseContext from ..types import BaseContext
@@ -91,11 +91,11 @@ class CompositeScorer:
self._priority_scorer = PriorityScorer(weight=self._priority_weight) self._priority_scorer = PriorityScorer(weight=self._priority_weight)
# Per-context locks to prevent race conditions during parallel scoring # Per-context locks to prevent race conditions during parallel scoring
# Uses WeakValueDictionary so locks are garbage collected when not in use # Uses dict with (lock, last_used_time) tuples for cleanup
self._context_locks: WeakValueDictionary[str, asyncio.Lock] = ( self._context_locks: dict[str, tuple[asyncio.Lock, float]] = {}
WeakValueDictionary()
)
self._locks_lock = asyncio.Lock() # Lock to protect _context_locks access self._locks_lock = asyncio.Lock() # Lock to protect _context_locks access
self._max_locks = 1000 # Maximum locks to keep (prevent memory growth)
self._lock_ttl = 60.0 # Seconds before a lock can be cleaned up
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None: def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
"""Set MCP manager for semantic scoring.""" """Set MCP manager for semantic scoring."""
@@ -141,7 +141,8 @@ class CompositeScorer:
Get or create a lock for a specific context. Get or create a lock for a specific context.
Thread-safe access to per-context locks prevents race conditions Thread-safe access to per-context locks prevents race conditions
when the same context is scored concurrently. when the same context is scored concurrently. Includes automatic
cleanup of old locks to prevent memory growth.
Args: Args:
context_id: The context ID to get a lock for context_id: The context ID to get a lock for
@@ -149,25 +150,78 @@ class CompositeScorer:
Returns: Returns:
asyncio.Lock for the context asyncio.Lock for the context
""" """
now = time.time()
# Fast path: check if lock exists without acquiring main lock # Fast path: check if lock exists without acquiring main lock
if context_id in self._context_locks: # NOTE: We only READ here - no writes to avoid race conditions
lock = self._context_locks.get(context_id) # with cleanup. The timestamp will be updated in the slow path
if lock is not None: # if the lock is still valid.
lock_entry = self._context_locks.get(context_id)
if lock_entry is not None:
lock, _ = lock_entry
# Return the lock but defer timestamp update to avoid race
# The lock is still valid; timestamp update is best-effort
return lock
# Slow path: create lock or update timestamp while holding main lock
async with self._locks_lock:
# Double-check after acquiring lock - entry may have been
# created by another coroutine or deleted by cleanup
lock_entry = self._context_locks.get(context_id)
if lock_entry is not None:
lock, _ = lock_entry
# Safe to update timestamp here since we hold the lock
self._context_locks[context_id] = (lock, now)
return lock return lock
# Slow path: create lock while holding main lock # Cleanup old locks if we have too many
async with self._locks_lock: if len(self._context_locks) >= self._max_locks:
# Double-check after acquiring lock self._cleanup_old_locks(now)
if context_id in self._context_locks:
lock = self._context_locks.get(context_id)
if lock is not None:
return lock
# Create new lock # Create new lock
new_lock = asyncio.Lock() new_lock = asyncio.Lock()
self._context_locks[context_id] = new_lock self._context_locks[context_id] = (new_lock, now)
return new_lock return new_lock
def _cleanup_old_locks(self, now: float) -> None:
"""
Remove old locks that haven't been used recently.
Called while holding _locks_lock. Removes locks older than _lock_ttl,
but only if they're not currently held.
Args:
now: Current timestamp for age calculation
"""
cutoff = now - self._lock_ttl
to_remove = []
for context_id, (lock, last_used) in self._context_locks.items():
# Only remove if old AND not currently held
if last_used < cutoff and not lock.locked():
to_remove.append(context_id)
# Remove oldest 50% if still over limit after TTL filtering
if len(self._context_locks) - len(to_remove) >= self._max_locks:
# Sort by last used time and mark oldest for removal
sorted_entries = sorted(
self._context_locks.items(),
key=lambda x: x[1][1], # Sort by last_used time
)
# Remove oldest 50% that aren't locked
target_remove = len(self._context_locks) // 2
for context_id, (lock, _) in sorted_entries:
if len(to_remove) >= target_remove:
break
if context_id not in to_remove and not lock.locked():
to_remove.append(context_id)
for context_id in to_remove:
del self._context_locks[context_id]
if to_remove:
logger.debug(f"Cleaned up {len(to_remove)} context locks")
async def score( async def score(
self, self,
context: BaseContext, context: BaseContext,

View File

@@ -72,7 +72,7 @@ class TestContextSettings:
"""Test performance settings.""" """Test performance settings."""
settings = ContextSettings() settings = ContextSettings()
assert settings.max_assembly_time_ms == 100 assert settings.max_assembly_time_ms == 2000
assert settings.parallel_scoring is True assert settings.parallel_scoring is True
assert settings.max_parallel_scores == 10 assert settings.max_parallel_scores == 10