forked from cardosofelipe/fast-next-template
feat(context): improve budget validation and XML safety in ranking and Claude adapter
- Added stricter budget validation in ContextRanker with explicit error handling for invalid configurations. - Introduced `_get_valid_token_count()` helper to validate and safeguard token counts. - Enhanced XML escaping in Claude adapter to prevent injection risks from scores and unhandled content.
This commit is contained in:
@@ -90,7 +90,9 @@ class ClaudeAdapter(ModelAdapter):
|
||||
elif context_type == ContextType.TOOL:
|
||||
return self._format_tool(contexts)
|
||||
|
||||
return "\n".join(c.content for c in contexts)
|
||||
# Fallback for any unhandled context types - still escape content
|
||||
# to prevent XML injection if new types are added without updating adapter
|
||||
return "\n".join(self._escape_xml_content(c.content) for c in contexts)
|
||||
|
||||
def _format_system(self, contexts: list[BaseContext]) -> str:
|
||||
"""Format system contexts."""
|
||||
@@ -119,7 +121,9 @@ class ClaudeAdapter(ModelAdapter):
|
||||
score = ctx.metadata.get("score", ctx.metadata.get("relevance_score", ""))
|
||||
|
||||
if score:
|
||||
parts.append(f'<document source="{source}" relevance="{score}">')
|
||||
# Escape score to prevent XML injection via metadata
|
||||
escaped_score = self._escape_xml(str(score))
|
||||
parts.append(f'<document source="{source}" relevance="{escaped_score}">')
|
||||
else:
|
||||
parts.append(f'<document source="{source}">')
|
||||
|
||||
|
||||
@@ -131,9 +131,22 @@ class ContextRanker:
|
||||
# Calculate the usable budget (total minus reserved portions)
|
||||
usable_budget = budget.total - budget.response_reserve - budget.buffer
|
||||
|
||||
# Guard against invalid budget configuration
|
||||
if usable_budget <= 0:
|
||||
raise BudgetExceededError(
|
||||
message=(
|
||||
f"Invalid budget configuration: no usable tokens available. "
|
||||
f"total={budget.total}, response_reserve={budget.response_reserve}, "
|
||||
f"buffer={budget.buffer}"
|
||||
),
|
||||
allocated=budget.total,
|
||||
requested=0,
|
||||
context_type="CONFIGURATION_ERROR",
|
||||
)
|
||||
|
||||
# First, try to fit required contexts
|
||||
for sc in required:
|
||||
token_count = sc.context.token_count or 0
|
||||
token_count = self._get_valid_token_count(sc.context)
|
||||
context_type = sc.context.get_type()
|
||||
|
||||
if budget.can_fit(context_type, token_count):
|
||||
@@ -165,7 +178,7 @@ class ContextRanker:
|
||||
|
||||
# Then, greedily add optional contexts
|
||||
for sc in optional:
|
||||
token_count = sc.context.token_count or 0
|
||||
token_count = self._get_valid_token_count(sc.context)
|
||||
context_type = sc.context.get_type()
|
||||
|
||||
if budget.can_fit(context_type, token_count):
|
||||
@@ -232,13 +245,43 @@ class ContextRanker:
|
||||
total_tokens = 0
|
||||
|
||||
for sc in scored_contexts:
|
||||
token_count = sc.context.token_count or 0
|
||||
token_count = self._get_valid_token_count(sc.context)
|
||||
if total_tokens + token_count <= max_tokens:
|
||||
selected.append(sc.context)
|
||||
total_tokens += token_count
|
||||
|
||||
return selected
|
||||
|
||||
def _get_valid_token_count(self, context: BaseContext) -> int:
|
||||
"""
|
||||
Get validated token count from a context.
|
||||
|
||||
Ensures token_count is set (not None) and non-negative to prevent
|
||||
budget bypass attacks where:
|
||||
- None would be treated as 0 (allowing huge contexts to slip through)
|
||||
- Negative values would corrupt budget tracking
|
||||
|
||||
Args:
|
||||
context: Context to get token count from
|
||||
|
||||
Returns:
|
||||
Valid non-negative token count
|
||||
|
||||
Raises:
|
||||
ValueError: If token_count is None or negative
|
||||
"""
|
||||
if context.token_count is None:
|
||||
raise ValueError(
|
||||
f"Context '{context.source}' has no token count. "
|
||||
"Ensure _ensure_token_counts() is called before ranking."
|
||||
)
|
||||
if context.token_count < 0:
|
||||
raise ValueError(
|
||||
f"Context '{context.source}' has invalid negative token count: "
|
||||
f"{context.token_count}"
|
||||
)
|
||||
return context.token_count
|
||||
|
||||
async def _ensure_token_counts(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
@@ -283,6 +326,7 @@ class ContextRanker:
|
||||
if type_name not in by_type:
|
||||
by_type[type_name] = {"count": 0, "tokens": 0}
|
||||
by_type[type_name]["count"] += 1
|
||||
# Use validated token count (already validated during ranking)
|
||||
by_type[type_name]["tokens"] += sc.context.token_count or 0
|
||||
|
||||
return by_type
|
||||
|
||||
Reference in New Issue
Block a user