forked from cardosofelipe/pragma-stack
feat(context): improve budget validation and XML safety in ranking and Claude adapter
- Added stricter budget validation in ContextRanker with explicit error handling for invalid configurations. - Introduced `_get_valid_token_count()` helper to validate and safeguard token counts. - Enhanced XML escaping in Claude adapter to prevent injection risks from scores and unhandled content.
This commit is contained in:
@@ -90,7 +90,9 @@ class ClaudeAdapter(ModelAdapter):
|
|||||||
elif context_type == ContextType.TOOL:
|
elif context_type == ContextType.TOOL:
|
||||||
return self._format_tool(contexts)
|
return self._format_tool(contexts)
|
||||||
|
|
||||||
return "\n".join(c.content for c in contexts)
|
# Fallback for any unhandled context types - still escape content
|
||||||
|
# to prevent XML injection if new types are added without updating adapter
|
||||||
|
return "\n".join(self._escape_xml_content(c.content) for c in contexts)
|
||||||
|
|
||||||
def _format_system(self, contexts: list[BaseContext]) -> str:
|
def _format_system(self, contexts: list[BaseContext]) -> str:
|
||||||
"""Format system contexts."""
|
"""Format system contexts."""
|
||||||
@@ -119,7 +121,9 @@ class ClaudeAdapter(ModelAdapter):
|
|||||||
score = ctx.metadata.get("score", ctx.metadata.get("relevance_score", ""))
|
score = ctx.metadata.get("score", ctx.metadata.get("relevance_score", ""))
|
||||||
|
|
||||||
if score:
|
if score:
|
||||||
parts.append(f'<document source="{source}" relevance="{score}">')
|
# Escape score to prevent XML injection via metadata
|
||||||
|
escaped_score = self._escape_xml(str(score))
|
||||||
|
parts.append(f'<document source="{source}" relevance="{escaped_score}">')
|
||||||
else:
|
else:
|
||||||
parts.append(f'<document source="{source}">')
|
parts.append(f'<document source="{source}">')
|
||||||
|
|
||||||
|
|||||||
@@ -131,9 +131,22 @@ class ContextRanker:
|
|||||||
# Calculate the usable budget (total minus reserved portions)
|
# Calculate the usable budget (total minus reserved portions)
|
||||||
usable_budget = budget.total - budget.response_reserve - budget.buffer
|
usable_budget = budget.total - budget.response_reserve - budget.buffer
|
||||||
|
|
||||||
|
# Guard against invalid budget configuration
|
||||||
|
if usable_budget <= 0:
|
||||||
|
raise BudgetExceededError(
|
||||||
|
message=(
|
||||||
|
f"Invalid budget configuration: no usable tokens available. "
|
||||||
|
f"total={budget.total}, response_reserve={budget.response_reserve}, "
|
||||||
|
f"buffer={budget.buffer}"
|
||||||
|
),
|
||||||
|
allocated=budget.total,
|
||||||
|
requested=0,
|
||||||
|
context_type="CONFIGURATION_ERROR",
|
||||||
|
)
|
||||||
|
|
||||||
# First, try to fit required contexts
|
# First, try to fit required contexts
|
||||||
for sc in required:
|
for sc in required:
|
||||||
token_count = sc.context.token_count or 0
|
token_count = self._get_valid_token_count(sc.context)
|
||||||
context_type = sc.context.get_type()
|
context_type = sc.context.get_type()
|
||||||
|
|
||||||
if budget.can_fit(context_type, token_count):
|
if budget.can_fit(context_type, token_count):
|
||||||
@@ -165,7 +178,7 @@ class ContextRanker:
|
|||||||
|
|
||||||
# Then, greedily add optional contexts
|
# Then, greedily add optional contexts
|
||||||
for sc in optional:
|
for sc in optional:
|
||||||
token_count = sc.context.token_count or 0
|
token_count = self._get_valid_token_count(sc.context)
|
||||||
context_type = sc.context.get_type()
|
context_type = sc.context.get_type()
|
||||||
|
|
||||||
if budget.can_fit(context_type, token_count):
|
if budget.can_fit(context_type, token_count):
|
||||||
@@ -232,13 +245,43 @@ class ContextRanker:
|
|||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
|
|
||||||
for sc in scored_contexts:
|
for sc in scored_contexts:
|
||||||
token_count = sc.context.token_count or 0
|
token_count = self._get_valid_token_count(sc.context)
|
||||||
if total_tokens + token_count <= max_tokens:
|
if total_tokens + token_count <= max_tokens:
|
||||||
selected.append(sc.context)
|
selected.append(sc.context)
|
||||||
total_tokens += token_count
|
total_tokens += token_count
|
||||||
|
|
||||||
return selected
|
return selected
|
||||||
|
|
||||||
|
def _get_valid_token_count(self, context: BaseContext) -> int:
|
||||||
|
"""
|
||||||
|
Get validated token count from a context.
|
||||||
|
|
||||||
|
Ensures token_count is set (not None) and non-negative to prevent
|
||||||
|
budget bypass attacks where:
|
||||||
|
- None would be treated as 0 (allowing huge contexts to slip through)
|
||||||
|
- Negative values would corrupt budget tracking
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: Context to get token count from
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Valid non-negative token count
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If token_count is None or negative
|
||||||
|
"""
|
||||||
|
if context.token_count is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Context '{context.source}' has no token count. "
|
||||||
|
"Ensure _ensure_token_counts() is called before ranking."
|
||||||
|
)
|
||||||
|
if context.token_count < 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Context '{context.source}' has invalid negative token count: "
|
||||||
|
f"{context.token_count}"
|
||||||
|
)
|
||||||
|
return context.token_count
|
||||||
|
|
||||||
async def _ensure_token_counts(
|
async def _ensure_token_counts(
|
||||||
self,
|
self,
|
||||||
contexts: list[BaseContext],
|
contexts: list[BaseContext],
|
||||||
@@ -283,6 +326,7 @@ class ContextRanker:
|
|||||||
if type_name not in by_type:
|
if type_name not in by_type:
|
||||||
by_type[type_name] = {"count": 0, "tokens": 0}
|
by_type[type_name] = {"count": 0, "tokens": 0}
|
||||||
by_type[type_name]["count"] += 1
|
by_type[type_name]["count"] += 1
|
||||||
|
# Use validated token count (already validated during ranking)
|
||||||
by_type[type_name]["tokens"] += sc.context.token_count or 0
|
by_type[type_name]["tokens"] += sc.context.token_count or 0
|
||||||
|
|
||||||
return by_type
|
return by_type
|
||||||
|
|||||||
Reference in New Issue
Block a user