From 758052dcff6d53ca0f6db73a253e6b1afc963c5e Mon Sep 17 00:00:00 2001 From: Felipe Cardoso Date: Sun, 4 Jan 2026 16:02:18 +0100 Subject: [PATCH] feat(context): improve budget validation and XML safety in ranking and Claude adapter - Added stricter budget validation in ContextRanker with explicit error handling for invalid configurations. - Introduced `_get_valid_token_count()` helper to validate and safeguard token counts. - Enhanced XML escaping in Claude adapter to prevent injection risks from scores and unhandled content. --- .../app/services/context/adapters/claude.py | 8 ++- .../services/context/prioritization/ranker.py | 50 +++++++++++++++++-- 2 files changed, 53 insertions(+), 5 deletions(-) diff --git a/backend/app/services/context/adapters/claude.py b/backend/app/services/context/adapters/claude.py index 31b3ba1..76b5cf7 100644 --- a/backend/app/services/context/adapters/claude.py +++ b/backend/app/services/context/adapters/claude.py @@ -90,7 +90,9 @@ class ClaudeAdapter(ModelAdapter): elif context_type == ContextType.TOOL: return self._format_tool(contexts) - return "\n".join(c.content for c in contexts) + # Fallback for any unhandled context types - still escape content + # to prevent XML injection if new types are added without updating adapter + return "\n".join(self._escape_xml_content(c.content) for c in contexts) def _format_system(self, contexts: list[BaseContext]) -> str: """Format system contexts.""" @@ -119,7 +121,9 @@ class ClaudeAdapter(ModelAdapter): score = ctx.metadata.get("score", ctx.metadata.get("relevance_score", "")) if score: - parts.append(f'') + # Escape score to prevent XML injection via metadata + escaped_score = self._escape_xml(str(score)) + parts.append(f'') else: parts.append(f'') diff --git a/backend/app/services/context/prioritization/ranker.py b/backend/app/services/context/prioritization/ranker.py index 80d6edc..80930cf 100644 --- a/backend/app/services/context/prioritization/ranker.py +++ b/backend/app/services/context/prioritization/ranker.py @@ -131,9 +131,22 @@ class ContextRanker: # Calculate the usable budget (total minus reserved portions) usable_budget = budget.total - budget.response_reserve - budget.buffer + # Guard against invalid budget configuration + if usable_budget <= 0: + raise BudgetExceededError( + message=( + f"Invalid budget configuration: no usable tokens available. " + f"total={budget.total}, response_reserve={budget.response_reserve}, " + f"buffer={budget.buffer}" + ), + allocated=budget.total, + requested=0, + context_type="CONFIGURATION_ERROR", + ) + # First, try to fit required contexts for sc in required: - token_count = sc.context.token_count or 0 + token_count = self._get_valid_token_count(sc.context) context_type = sc.context.get_type() if budget.can_fit(context_type, token_count): @@ -165,7 +178,7 @@ class ContextRanker: # Then, greedily add optional contexts for sc in optional: - token_count = sc.context.token_count or 0 + token_count = self._get_valid_token_count(sc.context) context_type = sc.context.get_type() if budget.can_fit(context_type, token_count): @@ -232,13 +245,43 @@ class ContextRanker: total_tokens = 0 for sc in scored_contexts: - token_count = sc.context.token_count or 0 + token_count = self._get_valid_token_count(sc.context) if total_tokens + token_count <= max_tokens: selected.append(sc.context) total_tokens += token_count return selected + def _get_valid_token_count(self, context: BaseContext) -> int: + """ + Get validated token count from a context. + + Ensures token_count is set (not None) and non-negative to prevent + budget bypass attacks where: + - None would be treated as 0 (allowing huge contexts to slip through) + - Negative values would corrupt budget tracking + + Args: + context: Context to get token count from + + Returns: + Valid non-negative token count + + Raises: + ValueError: If token_count is None or negative + """ + if context.token_count is None: + raise ValueError( + f"Context '{context.source}' has no token count. " + "Ensure _ensure_token_counts() is called before ranking." + ) + if context.token_count < 0: + raise ValueError( + f"Context '{context.source}' has invalid negative token count: " + f"{context.token_count}" + ) + return context.token_count + async def _ensure_token_counts( self, contexts: list[BaseContext], @@ -283,6 +326,7 @@ class ContextRanker: if type_name not in by_type: by_type[type_name] = {"count": 0, "tokens": 0} by_type[type_name]["count"] += 1 + # Use validated token count (already validated during ranking) by_type[type_name]["tokens"] += sc.context.token_count or 0 return by_type