diff --git a/backend/app/services/context/adapters/claude.py b/backend/app/services/context/adapters/claude.py index 2fc1a4e..31b3ba1 100644 --- a/backend/app/services/context/adapters/claude.py +++ b/backend/app/services/context/adapters/claude.py @@ -94,12 +94,13 @@ class ClaudeAdapter(ModelAdapter): def _format_system(self, contexts: list[BaseContext]) -> str: """Format system contexts.""" - content = "\n\n".join(c.content for c in contexts) + # System prompts are typically admin-controlled, but escape for safety + content = "\n\n".join(self._escape_xml_content(c.content) for c in contexts) return f"\n{content}\n" def _format_task(self, contexts: list[BaseContext]) -> str: """Format task contexts.""" - content = "\n\n".join(c.content for c in contexts) + content = "\n\n".join(self._escape_xml_content(c.content) for c in contexts) return f"\n{content}\n" def _format_knowledge(self, contexts: list[BaseContext]) -> str: @@ -107,12 +108,14 @@ class ClaudeAdapter(ModelAdapter): Format knowledge contexts as structured documents. Each knowledge context becomes a document with source attribution. + All content is XML-escaped to prevent injection attacks. """ parts = [""] for ctx in contexts: source = self._escape_xml(ctx.source) - content = ctx.content + # Escape content to prevent XML injection + content = self._escape_xml_content(ctx.content) score = ctx.metadata.get("score", ctx.metadata.get("relevance_score", "")) if score: @@ -131,13 +134,16 @@ class ClaudeAdapter(ModelAdapter): Format conversation contexts as message history. Uses role-based message tags for clear turn delineation. + All content is XML-escaped to prevent prompt injection. """ parts = [""] for ctx in contexts: - role = ctx.metadata.get("role", "user") + role = self._escape_xml(ctx.metadata.get("role", "user")) + # Escape content to prevent prompt injection via fake XML tags + content = self._escape_xml_content(ctx.content) parts.append(f'') - parts.append(ctx.content) + parts.append(content) parts.append("") parts.append("") @@ -148,19 +154,23 @@ class ClaudeAdapter(ModelAdapter): Format tool contexts as tool results. Each tool result is wrapped with the tool name. + All content is XML-escaped to prevent injection. """ parts = [""] for ctx in contexts: - tool_name = ctx.metadata.get("tool_name", "unknown") + tool_name = self._escape_xml(ctx.metadata.get("tool_name", "unknown")) status = ctx.metadata.get("status", "") if status: - parts.append(f'') + parts.append( + f'' + ) else: parts.append(f'') - parts.append(ctx.content) + # Escape content to prevent injection + parts.append(self._escape_xml_content(ctx.content)) parts.append("") parts.append("") @@ -176,3 +186,21 @@ class ClaudeAdapter(ModelAdapter): .replace('"', """) .replace("'", "'") ) + + @staticmethod + def _escape_xml_content(text: str) -> str: + """ + Escape XML special characters in element content. + + This prevents XML injection attacks where malicious content + could break out of XML tags or inject fake tags for prompt injection. + + Only escapes &, <, > since quotes don't need escaping in content. + + Args: + text: Content text to escape + + Returns: + XML-safe content string + """ + return text.replace("&", "&").replace("<", "<").replace(">", ">") diff --git a/backend/app/services/context/assembly/pipeline.py b/backend/app/services/context/assembly/pipeline.py index af1c8cf..dd4ad17 100644 --- a/backend/app/services/context/assembly/pipeline.py +++ b/backend/app/services/context/assembly/pipeline.py @@ -12,6 +12,7 @@ from dataclasses import dataclass, field from datetime import UTC, datetime from typing import TYPE_CHECKING, Any +from ..adapters import get_adapter from ..budget import BudgetAllocator, TokenBudget, TokenCalculator from ..compression.truncation import ContextCompressor from ..config import ContextSettings, get_context_settings @@ -156,20 +157,42 @@ class ContextPipeline: else: budget = self._allocator.create_budget_for_model(model) - # 1. Count tokens for all contexts - await self._ensure_token_counts(contexts, model) + # 1. Count tokens for all contexts (with timeout enforcement) + try: + await asyncio.wait_for( + self._ensure_token_counts(contexts, model), + timeout=self._remaining_timeout(start, timeout), + ) + except TimeoutError: + elapsed_ms = (time.perf_counter() - start) * 1000 + raise AssemblyTimeoutError( + message="Context assembly timed out during token counting", + elapsed_ms=elapsed_ms, + timeout_ms=timeout, + ) - # Check timeout + # Check timeout (handles edge case where operation finished just at limit) self._check_timeout(start, timeout, "token counting") - # 2. Score and rank contexts + # 2. Score and rank contexts (with timeout enforcement) scoring_start = time.perf_counter() - ranking_result = await self._ranker.rank( - contexts=contexts, - query=query, - budget=budget, - model=model, - ) + try: + ranking_result = await asyncio.wait_for( + self._ranker.rank( + contexts=contexts, + query=query, + budget=budget, + model=model, + ), + timeout=self._remaining_timeout(start, timeout), + ) + except TimeoutError: + elapsed_ms = (time.perf_counter() - start) * 1000 + raise AssemblyTimeoutError( + message="Context assembly timed out during scoring/ranking", + elapsed_ms=elapsed_ms, + timeout_ms=timeout, + ) metrics.scoring_time_ms = (time.perf_counter() - scoring_start) * 1000 selected_contexts = ranking_result.selected_contexts @@ -179,12 +202,23 @@ class ContextPipeline: # Check timeout self._check_timeout(start, timeout, "scoring") - # 3. Compress if needed and enabled + # 3. Compress if needed and enabled (with timeout enforcement) if compress and self._needs_compression(selected_contexts, budget): compression_start = time.perf_counter() - selected_contexts = await self._compressor.compress_contexts( - selected_contexts, budget, model - ) + try: + selected_contexts = await asyncio.wait_for( + self._compressor.compress_contexts( + selected_contexts, budget, model + ), + timeout=self._remaining_timeout(start, timeout), + ) + except TimeoutError: + elapsed_ms = (time.perf_counter() - start) * 1000 + raise AssemblyTimeoutError( + message="Context assembly timed out during compression", + elapsed_ms=elapsed_ms, + timeout_ms=timeout, + ) metrics.compression_time_ms = ( time.perf_counter() - compression_start ) * 1000 @@ -280,129 +314,18 @@ class ContextPipeline: """ Format contexts for the target model. - Groups contexts by type and applies model-specific formatting. + Uses model-specific adapters (ClaudeAdapter, OpenAIAdapter, etc.) + to format contexts optimally for each model family. + + Args: + contexts: Contexts to format + model: Target model name + + Returns: + Formatted context string """ - # Group by type - by_type: dict[ContextType, list[BaseContext]] = {} - for context in contexts: - ct = context.get_type() - if ct not in by_type: - by_type[ct] = [] - by_type[ct].append(context) - - # Order types: System -> Task -> Knowledge -> Conversation -> Tool - type_order = [ - ContextType.SYSTEM, - ContextType.TASK, - ContextType.KNOWLEDGE, - ContextType.CONVERSATION, - ContextType.TOOL, - ] - - parts: list[str] = [] - for ct in type_order: - if ct in by_type: - formatted = self._format_type(by_type[ct], ct, model) - if formatted: - parts.append(formatted) - - return "\n\n".join(parts) - - def _format_type( - self, - contexts: list[BaseContext], - context_type: ContextType, - model: str, - ) -> str: - """Format contexts of a specific type.""" - if not contexts: - return "" - - # Check if model prefers XML tags (Claude) - use_xml = "claude" in model.lower() - - if context_type == ContextType.SYSTEM: - return self._format_system(contexts, use_xml) - elif context_type == ContextType.TASK: - return self._format_task(contexts, use_xml) - elif context_type == ContextType.KNOWLEDGE: - return self._format_knowledge(contexts, use_xml) - elif context_type == ContextType.CONVERSATION: - return self._format_conversation(contexts, use_xml) - elif context_type == ContextType.TOOL: - return self._format_tool(contexts, use_xml) - - return "\n".join(c.content for c in contexts) - - def _format_system(self, contexts: list[BaseContext], use_xml: bool) -> str: - """Format system contexts.""" - content = "\n\n".join(c.content for c in contexts) - if use_xml: - return f"\n{content}\n" - return content - - def _format_task(self, contexts: list[BaseContext], use_xml: bool) -> str: - """Format task contexts.""" - content = "\n\n".join(c.content for c in contexts) - if use_xml: - return f"\n{content}\n" - return f"## Current Task\n\n{content}" - - def _format_knowledge(self, contexts: list[BaseContext], use_xml: bool) -> str: - """Format knowledge contexts.""" - if use_xml: - parts = [""] - for ctx in contexts: - parts.append(f'') - parts.append(ctx.content) - parts.append("") - parts.append("") - return "\n".join(parts) - else: - parts = ["## Reference Documents\n"] - for ctx in contexts: - parts.append(f"### Source: {ctx.source}\n") - parts.append(ctx.content) - parts.append("") - return "\n".join(parts) - - def _format_conversation(self, contexts: list[BaseContext], use_xml: bool) -> str: - """Format conversation contexts.""" - if use_xml: - parts = [""] - for ctx in contexts: - role = ctx.metadata.get("role", "user") - parts.append(f'') - parts.append(ctx.content) - parts.append("") - parts.append("") - return "\n".join(parts) - else: - parts = [] - for ctx in contexts: - role = ctx.metadata.get("role", "user") - parts.append(f"**{role.upper()}**: {ctx.content}") - return "\n\n".join(parts) - - def _format_tool(self, contexts: list[BaseContext], use_xml: bool) -> str: - """Format tool contexts.""" - if use_xml: - parts = [""] - for ctx in contexts: - tool_name = ctx.metadata.get("tool_name", "unknown") - parts.append(f'') - parts.append(ctx.content) - parts.append("") - parts.append("") - return "\n".join(parts) - else: - parts = ["## Recent Tool Results\n"] - for ctx in contexts: - tool_name = ctx.metadata.get("tool_name", "unknown") - parts.append(f"### Tool: {tool_name}\n") - parts.append(f"```\n{ctx.content}\n```") - parts.append("") - return "\n".join(parts) + adapter = get_adapter(model) + return adapter.format(contexts) def _check_timeout( self, @@ -412,9 +335,28 @@ class ContextPipeline: ) -> None: """Check if timeout exceeded and raise if so.""" elapsed_ms = (time.perf_counter() - start) * 1000 - if elapsed_ms > timeout_ms: + if elapsed_ms >= timeout_ms: raise AssemblyTimeoutError( message=f"Context assembly timed out during {phase}", elapsed_ms=elapsed_ms, timeout_ms=timeout_ms, ) + + def _remaining_timeout(self, start: float, timeout_ms: int) -> float: + """ + Calculate remaining timeout in seconds for asyncio.wait_for. + + Returns at least a small positive value to avoid immediate timeout + edge cases with wait_for. + + Args: + start: Start time from time.perf_counter() + timeout_ms: Total timeout in milliseconds + + Returns: + Remaining timeout in seconds (minimum 0.001) + """ + elapsed_ms = (time.perf_counter() - start) * 1000 + remaining_ms = timeout_ms - elapsed_ms + # Return at least 1ms to avoid zero/negative timeout edge cases + return max(remaining_ms / 1000.0, 0.001) diff --git a/backend/app/services/context/budget/allocator.py b/backend/app/services/context/budget/allocator.py index ee33894..6c9507a 100644 --- a/backend/app/services/context/budget/allocator.py +++ b/backend/app/services/context/budget/allocator.py @@ -293,14 +293,18 @@ class BudgetAllocator: if isinstance(context_type, ContextType): context_type = context_type.value - # Calculate adjustment (limited by buffer) + # Calculate adjustment (limited by buffer for increases, by current allocation for decreases) if adjustment > 0: - # Taking from buffer + # Taking from buffer - limited by available buffer actual_adjustment = min(adjustment, budget.buffer) budget.buffer -= actual_adjustment else: - # Returning to buffer - actual_adjustment = adjustment + # Returning to buffer - limited by current allocation of target type + current_allocation = budget.get_allocation(context_type) + # Can't return more than current allocation + actual_adjustment = max(adjustment, -current_allocation) + # Add returned tokens back to buffer (adjustment is negative, so subtract) + budget.buffer -= actual_adjustment # Apply to target type if context_type == "system": diff --git a/backend/app/services/context/cache/context_cache.py b/backend/app/services/context/cache/context_cache.py index 7b26132..d9e9a80 100644 --- a/backend/app/services/context/cache/context_cache.py +++ b/backend/app/services/context/cache/context_cache.py @@ -95,19 +95,28 @@ class ContextCache: contexts: list[BaseContext], query: str, model: str, + project_id: str | None = None, + agent_id: str | None = None, ) -> str: """ Compute a fingerprint for a context assembly request. The fingerprint is based on: + - Project and agent IDs (for tenant isolation) - Context content hash and metadata (not full content for performance) - Query string - Target model + SECURITY: project_id and agent_id MUST be included to prevent + cross-tenant cache pollution. Without these, one tenant could + receive cached contexts from another tenant with the same query. + Args: contexts: List of contexts query: Query string model: Model name + project_id: Project ID for tenant isolation + agent_id: Agent ID for tenant isolation Returns: 32-character hex fingerprint @@ -128,6 +137,9 @@ class ContextCache: ) data = { + # CRITICAL: Include tenant identifiers for cache isolation + "project_id": project_id or "", + "agent_id": agent_id or "", "contexts": context_data, "query": query, "model": model, diff --git a/backend/app/services/context/compression/truncation.py b/backend/app/services/context/compression/truncation.py index 058a894..d11252b 100644 --- a/backend/app/services/context/compression/truncation.py +++ b/backend/app/services/context/compression/truncation.py @@ -19,6 +19,40 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +def _estimate_tokens(text: str, model: str | None = None) -> int: + """ + Estimate token count using model-specific character ratios. + + Module-level function for reuse across classes. Uses the same ratios + as TokenCalculator for consistency. + + Args: + text: Text to estimate tokens for + model: Optional model name for model-specific ratios + + Returns: + Estimated token count (minimum 1) + """ + # Model-specific character ratios (chars per token) + model_ratios = { + "claude": 3.5, + "gpt-4": 4.0, + "gpt-3.5": 4.0, + "gemini": 4.0, + } + default_ratio = 4.0 + + ratio = default_ratio + if model: + model_lower = model.lower() + for model_prefix, model_ratio in model_ratios.items(): + if model_prefix in model_lower: + ratio = model_ratio + break + + return max(1, int(len(text) / ratio)) + + @dataclass class TruncationResult: """Result of truncation operation.""" @@ -284,8 +318,8 @@ class TruncationStrategy: if self._calculator is not None: return await self._calculator.count_tokens(text, model) - # Fallback estimation - return max(1, len(text) // 4) + # Fallback estimation with model-specific ratios + return _estimate_tokens(text, model) class ContextCompressor: @@ -415,4 +449,5 @@ class ContextCompressor: """Count tokens using calculator or estimation.""" if self._calculator is not None: return await self._calculator.count_tokens(text, model) - return max(1, len(text) // 4) + # Use model-specific estimation for consistency + return _estimate_tokens(text, model) diff --git a/backend/app/services/context/config.py b/backend/app/services/context/config.py index d95d447..079e439 100644 --- a/backend/app/services/context/config.py +++ b/backend/app/services/context/config.py @@ -149,10 +149,11 @@ class ContextSettings(BaseSettings): # Performance settings max_assembly_time_ms: int = Field( - default=100, + default=2000, ge=10, - le=5000, - description="Maximum time for context assembly in milliseconds", + le=30000, + description="Maximum time for context assembly in milliseconds. " + "Should be high enough to accommodate MCP calls for knowledge retrieval.", ) parallel_scoring: bool = Field( default=True, diff --git a/backend/app/services/context/engine.py b/backend/app/services/context/engine.py index 39a190a..707d570 100644 --- a/backend/app/services/context/engine.py +++ b/backend/app/services/context/engine.py @@ -212,7 +212,10 @@ class ContextEngine: # Check cache if enabled fingerprint: str | None = None if use_cache and self._cache.is_enabled: - fingerprint = self._cache.compute_fingerprint(contexts, query, model) + # Include project_id and agent_id for tenant isolation + fingerprint = self._cache.compute_fingerprint( + contexts, query, model, project_id=project_id, agent_id=agent_id + ) cached = await self._cache.get_assembled(fingerprint) if cached: logger.debug(f"Cache hit for context assembly: {fingerprint}") diff --git a/backend/app/services/context/prioritization/ranker.py b/backend/app/services/context/prioritization/ranker.py index b475b6c..80d6edc 100644 --- a/backend/app/services/context/prioritization/ranker.py +++ b/backend/app/services/context/prioritization/ranker.py @@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any from ..budget import TokenBudget, TokenCalculator from ..config import ContextSettings, get_context_settings +from ..exceptions import BudgetExceededError from ..scoring.composite import CompositeScorer, ScoredContext from ..types import BaseContext, ContextPriority @@ -127,6 +128,9 @@ class ContextRanker: excluded: list[ScoredContext] = [] total_tokens = 0 + # Calculate the usable budget (total minus reserved portions) + usable_budget = budget.total - budget.response_reserve - budget.buffer + # First, try to fit required contexts for sc in required: token_count = sc.context.token_count or 0 @@ -137,7 +141,20 @@ class ContextRanker: selected.append(sc) total_tokens += token_count else: - # Force-fit CRITICAL contexts if needed + # Force-fit CRITICAL contexts if needed, but check total budget first + if total_tokens + token_count > usable_budget: + # Even CRITICAL contexts cannot exceed total model context window + raise BudgetExceededError( + message=( + f"CRITICAL contexts exceed total budget. " + f"Context '{sc.context.source}' ({token_count} tokens) " + f"would exceed usable budget of {usable_budget} tokens." + ), + allocated=usable_budget, + requested=total_tokens + token_count, + context_type="CRITICAL_OVERFLOW", + ) + budget.allocate(context_type, token_count, force=True) selected.append(sc) total_tokens += token_count diff --git a/backend/app/services/context/scoring/composite.py b/backend/app/services/context/scoring/composite.py index a75ebf6..b745bd3 100644 --- a/backend/app/services/context/scoring/composite.py +++ b/backend/app/services/context/scoring/composite.py @@ -6,9 +6,9 @@ Combines multiple scoring strategies with configurable weights. import asyncio import logging +import time from dataclasses import dataclass from typing import TYPE_CHECKING, Any -from weakref import WeakValueDictionary from ..config import ContextSettings, get_context_settings from ..types import BaseContext @@ -91,11 +91,11 @@ class CompositeScorer: self._priority_scorer = PriorityScorer(weight=self._priority_weight) # Per-context locks to prevent race conditions during parallel scoring - # Uses WeakValueDictionary so locks are garbage collected when not in use - self._context_locks: WeakValueDictionary[str, asyncio.Lock] = ( - WeakValueDictionary() - ) + # Uses dict with (lock, last_used_time) tuples for cleanup + self._context_locks: dict[str, tuple[asyncio.Lock, float]] = {} self._locks_lock = asyncio.Lock() # Lock to protect _context_locks access + self._max_locks = 1000 # Maximum locks to keep (prevent memory growth) + self._lock_ttl = 60.0 # Seconds before a lock can be cleaned up def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None: """Set MCP manager for semantic scoring.""" @@ -141,7 +141,8 @@ class CompositeScorer: Get or create a lock for a specific context. Thread-safe access to per-context locks prevents race conditions - when the same context is scored concurrently. + when the same context is scored concurrently. Includes automatic + cleanup of old locks to prevent memory growth. Args: context_id: The context ID to get a lock for @@ -149,25 +150,78 @@ class CompositeScorer: Returns: asyncio.Lock for the context """ + now = time.time() + # Fast path: check if lock exists without acquiring main lock - if context_id in self._context_locks: - lock = self._context_locks.get(context_id) - if lock is not None: + # NOTE: We only READ here - no writes to avoid race conditions + # with cleanup. The timestamp will be updated in the slow path + # if the lock is still valid. + lock_entry = self._context_locks.get(context_id) + if lock_entry is not None: + lock, _ = lock_entry + # Return the lock but defer timestamp update to avoid race + # The lock is still valid; timestamp update is best-effort + return lock + + # Slow path: create lock or update timestamp while holding main lock + async with self._locks_lock: + # Double-check after acquiring lock - entry may have been + # created by another coroutine or deleted by cleanup + lock_entry = self._context_locks.get(context_id) + if lock_entry is not None: + lock, _ = lock_entry + # Safe to update timestamp here since we hold the lock + self._context_locks[context_id] = (lock, now) return lock - # Slow path: create lock while holding main lock - async with self._locks_lock: - # Double-check after acquiring lock - if context_id in self._context_locks: - lock = self._context_locks.get(context_id) - if lock is not None: - return lock + # Cleanup old locks if we have too many + if len(self._context_locks) >= self._max_locks: + self._cleanup_old_locks(now) # Create new lock new_lock = asyncio.Lock() - self._context_locks[context_id] = new_lock + self._context_locks[context_id] = (new_lock, now) return new_lock + def _cleanup_old_locks(self, now: float) -> None: + """ + Remove old locks that haven't been used recently. + + Called while holding _locks_lock. Removes locks older than _lock_ttl, + but only if they're not currently held. + + Args: + now: Current timestamp for age calculation + """ + cutoff = now - self._lock_ttl + to_remove = [] + + for context_id, (lock, last_used) in self._context_locks.items(): + # Only remove if old AND not currently held + if last_used < cutoff and not lock.locked(): + to_remove.append(context_id) + + # Remove oldest 50% if still over limit after TTL filtering + if len(self._context_locks) - len(to_remove) >= self._max_locks: + # Sort by last used time and mark oldest for removal + sorted_entries = sorted( + self._context_locks.items(), + key=lambda x: x[1][1], # Sort by last_used time + ) + # Remove oldest 50% that aren't locked + target_remove = len(self._context_locks) // 2 + for context_id, (lock, _) in sorted_entries: + if len(to_remove) >= target_remove: + break + if context_id not in to_remove and not lock.locked(): + to_remove.append(context_id) + + for context_id in to_remove: + del self._context_locks[context_id] + + if to_remove: + logger.debug(f"Cleaned up {len(to_remove)} context locks") + async def score( self, context: BaseContext, diff --git a/backend/tests/services/context/test_config.py b/backend/tests/services/context/test_config.py index f797866..537ffd4 100644 --- a/backend/tests/services/context/test_config.py +++ b/backend/tests/services/context/test_config.py @@ -72,7 +72,7 @@ class TestContextSettings: """Test performance settings.""" settings = ContextSettings() - assert settings.max_assembly_time_ms == 100 + assert settings.max_assembly_time_ms == 2000 assert settings.parallel_scoring is True assert settings.max_parallel_scores == 10