feat(context): improve budget validation and XML safety in ranking and Claude adapter

- Added stricter budget validation in ContextRanker with explicit error handling for invalid configurations.
- Introduced `_get_valid_token_count()` helper to validate and safeguard token counts.
- Enhanced XML escaping in Claude adapter to prevent injection risks from scores and unhandled content.
This commit is contained in:
2026-01-04 16:02:18 +01:00
parent 1628eacf2b
commit 758052dcff
2 changed files with 53 additions and 5 deletions

View File

@@ -90,7 +90,9 @@ class ClaudeAdapter(ModelAdapter):
elif context_type == ContextType.TOOL:
return self._format_tool(contexts)
return "\n".join(c.content for c in contexts)
# Fallback for any unhandled context types - still escape content
# to prevent XML injection if new types are added without updating adapter
return "\n".join(self._escape_xml_content(c.content) for c in contexts)
def _format_system(self, contexts: list[BaseContext]) -> str:
"""Format system contexts."""
@@ -119,7 +121,9 @@ class ClaudeAdapter(ModelAdapter):
score = ctx.metadata.get("score", ctx.metadata.get("relevance_score", ""))
if score:
parts.append(f'<document source="{source}" relevance="{score}">')
# Escape score to prevent XML injection via metadata
escaped_score = self._escape_xml(str(score))
parts.append(f'<document source="{source}" relevance="{escaped_score}">')
else:
parts.append(f'<document source="{source}">')

View File

@@ -131,9 +131,22 @@ class ContextRanker:
# Calculate the usable budget (total minus reserved portions)
usable_budget = budget.total - budget.response_reserve - budget.buffer
# Guard against invalid budget configuration
if usable_budget <= 0:
raise BudgetExceededError(
message=(
f"Invalid budget configuration: no usable tokens available. "
f"total={budget.total}, response_reserve={budget.response_reserve}, "
f"buffer={budget.buffer}"
),
allocated=budget.total,
requested=0,
context_type="CONFIGURATION_ERROR",
)
# First, try to fit required contexts
for sc in required:
token_count = sc.context.token_count or 0
token_count = self._get_valid_token_count(sc.context)
context_type = sc.context.get_type()
if budget.can_fit(context_type, token_count):
@@ -165,7 +178,7 @@ class ContextRanker:
# Then, greedily add optional contexts
for sc in optional:
token_count = sc.context.token_count or 0
token_count = self._get_valid_token_count(sc.context)
context_type = sc.context.get_type()
if budget.can_fit(context_type, token_count):
@@ -232,13 +245,43 @@ class ContextRanker:
total_tokens = 0
for sc in scored_contexts:
token_count = sc.context.token_count or 0
token_count = self._get_valid_token_count(sc.context)
if total_tokens + token_count <= max_tokens:
selected.append(sc.context)
total_tokens += token_count
return selected
def _get_valid_token_count(self, context: BaseContext) -> int:
"""
Get validated token count from a context.
Ensures token_count is set (not None) and non-negative to prevent
budget bypass attacks where:
- None would be treated as 0 (allowing huge contexts to slip through)
- Negative values would corrupt budget tracking
Args:
context: Context to get token count from
Returns:
Valid non-negative token count
Raises:
ValueError: If token_count is None or negative
"""
if context.token_count is None:
raise ValueError(
f"Context '{context.source}' has no token count. "
"Ensure _ensure_token_counts() is called before ranking."
)
if context.token_count < 0:
raise ValueError(
f"Context '{context.source}' has invalid negative token count: "
f"{context.token_count}"
)
return context.token_count
async def _ensure_token_counts(
self,
contexts: list[BaseContext],
@@ -283,6 +326,7 @@ class ContextRanker:
if type_name not in by_type:
by_type[type_name] = {"count": 0, "tokens": 0}
by_type[type_name]["count"] += 1
# Use validated token count (already validated during ranking)
by_type[type_name]["tokens"] += sc.context.token_count or 0
return by_type