feat(context): implement token budget management (Phase 2)

Add TokenCalculator with LLM Gateway integration for accurate token
counting, backed by an in-memory cache and a character-based
estimation fallback. Implement TokenBudget for tracking per-context-type
allocations with budget enforcement, and BudgetAllocator for creating
budgets sized to each model's context window.

- TokenCalculator: MCP integration, caching, model-specific ratios
- TokenBudget: allocation tracking, can_fit/allocate/deallocate/reset
- BudgetAllocator: model context sizes, budget creation and adjustment
- 35 comprehensive tests covering all budget functionality

Part of #61 - Context Management Engine

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-04 02:13:23 +01:00
parent 22ecb5e989
commit dfa75e682e
5 changed files with 1277 additions and 0 deletions


@@ -14,11 +14,18 @@ Usage:
ConversationContext,
TaskContext,
ToolContext,
TokenBudget,
BudgetAllocator,
TokenCalculator,
)
# Get settings
settings = get_context_settings()
# Create budget for a model
allocator = BudgetAllocator(settings)
budget = allocator.create_budget_for_model("claude-3-sonnet")
# Create context instances
system_ctx = SystemContext.create_persona(
name="Code Assistant",
@@ -27,6 +34,13 @@ Usage:
)
"""
# Budget Management
from .budget import (
BudgetAllocator,
TokenBudget,
TokenCalculator,
)
# Configuration
from .config import (
ContextSettings,
@@ -67,6 +81,10 @@ from .types import (
)
__all__ = [
# Budget Management
"BudgetAllocator",
"TokenBudget",
"TokenCalculator",
# Configuration
"ContextSettings",
"get_context_settings",


@@ -3,3 +3,12 @@ Token Budget Management Module.
Provides token counting and budget allocation.
"""
from .allocator import BudgetAllocator, TokenBudget
from .calculator import TokenCalculator
__all__ = [
"BudgetAllocator",
"TokenBudget",
"TokenCalculator",
]


@@ -0,0 +1,433 @@
"""
Token Budget Allocator for Context Management.
Manages token budget allocation across context types.
"""
from dataclasses import dataclass, field
from typing import Any
from ..config import ContextSettings, get_context_settings
from ..exceptions import BudgetExceededError
from ..types import ContextType
@dataclass
class TokenBudget:
"""
Token budget allocation and tracking.
Tracks allocated tokens per context type and
monitors usage to prevent overflows.
"""
# Total budget
total: int
# Allocated per type
system: int = 0
task: int = 0
knowledge: int = 0
conversation: int = 0
tools: int = 0
response_reserve: int = 0
buffer: int = 0
# Usage tracking
used: dict[str, int] = field(default_factory=dict)
def __post_init__(self) -> None:
"""Initialize usage tracking."""
if not self.used:
self.used = {ct.value: 0 for ct in ContextType}
def get_allocation(self, context_type: ContextType | str) -> int:
"""
Get allocated tokens for a context type.
Args:
context_type: Context type to get allocation for
Returns:
Allocated token count
"""
if isinstance(context_type, ContextType):
context_type = context_type.value
allocation_map = {
"system": self.system,
"task": self.task,
"knowledge": self.knowledge,
"conversation": self.conversation,
"tool": self.tools,
}
return allocation_map.get(context_type, 0)
def get_used(self, context_type: ContextType | str) -> int:
"""
Get used tokens for a context type.
Args:
context_type: Context type to check
Returns:
Used token count
"""
if isinstance(context_type, ContextType):
context_type = context_type.value
return self.used.get(context_type, 0)
def remaining(self, context_type: ContextType | str) -> int:
"""
Get remaining tokens for a context type.
Args:
context_type: Context type to check
Returns:
Remaining token count
"""
allocated = self.get_allocation(context_type)
used = self.get_used(context_type)
return max(0, allocated - used)
def total_remaining(self) -> int:
"""
Get total remaining tokens across all types.
Returns:
Total remaining tokens
"""
total_used = sum(self.used.values())
usable = self.total - self.response_reserve - self.buffer
return max(0, usable - total_used)
def total_used(self) -> int:
"""
Get total used tokens.
Returns:
Total used tokens
"""
return sum(self.used.values())
def can_fit(self, context_type: ContextType | str, tokens: int) -> bool:
"""
Check if tokens fit within budget for a type.
Args:
context_type: Context type to check
tokens: Number of tokens to fit
Returns:
True if tokens fit within remaining budget
"""
return tokens <= self.remaining(context_type)
def allocate(
self,
context_type: ContextType | str,
tokens: int,
force: bool = False,
) -> bool:
"""
Allocate (use) tokens from a context type's budget.
Args:
context_type: Context type to allocate from
tokens: Number of tokens to allocate
force: If True, allow exceeding budget
Returns:
True if allocation succeeded
Raises:
BudgetExceededError: If tokens exceed budget and force=False
"""
if isinstance(context_type, ContextType):
context_type = context_type.value
if not force and not self.can_fit(context_type, tokens):
raise BudgetExceededError(
message=f"Token budget exceeded for {context_type}",
allocated=self.get_allocation(context_type),
requested=self.get_used(context_type) + tokens,
context_type=context_type,
)
self.used[context_type] = self.used.get(context_type, 0) + tokens
return True
def deallocate(
self,
context_type: ContextType | str,
tokens: int,
) -> None:
"""
Deallocate (return) tokens to a context type's budget.
Args:
context_type: Context type to return to
tokens: Number of tokens to return
"""
if isinstance(context_type, ContextType):
context_type = context_type.value
current = self.used.get(context_type, 0)
self.used[context_type] = max(0, current - tokens)
def reset(self) -> None:
"""Reset all usage tracking."""
self.used = {ct.value: 0 for ct in ContextType}
def utilization(self, context_type: ContextType | str | None = None) -> float:
"""
Get budget utilization percentage.
Args:
context_type: Specific type or None for total
Returns:
Utilization as a fraction (0.0 to 1.0+)
"""
if context_type is None:
usable = self.total - self.response_reserve - self.buffer
if usable <= 0:
return 0.0
return self.total_used() / usable
allocated = self.get_allocation(context_type)
if allocated <= 0:
return 0.0
return self.get_used(context_type) / allocated
def to_dict(self) -> dict[str, Any]:
"""Convert budget to dictionary."""
return {
"total": self.total,
"allocations": {
"system": self.system,
"task": self.task,
"knowledge": self.knowledge,
"conversation": self.conversation,
"tools": self.tools,
"response_reserve": self.response_reserve,
"buffer": self.buffer,
},
"used": dict(self.used),
"remaining": {
ct.value: self.remaining(ct) for ct in ContextType
},
"total_used": self.total_used(),
"total_remaining": self.total_remaining(),
"utilization": round(self.utilization(), 3),
}
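A minimal usage sketch of the tracking methods above (illustrative numbers; assumes ContextType has a KNOWLEDGE member whose value is "knowledge", as the allocation map implies):

budget = TokenBudget(total=1000, knowledge=300, response_reserve=150, buffer=50)
budget.can_fit("knowledge", 200)    # True: 300 allocated, 0 used
budget.allocate("knowledge", 200)   # used["knowledge"] == 200
budget.remaining("knowledge")       # 100
budget.deallocate("knowledge", 50)  # used["knowledge"] == 150
budget.allocate("knowledge", 200)   # raises BudgetExceededError (150 + 200 > 300)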
class BudgetAllocator:
"""
Budget allocator for context management.
Creates token budgets based on configuration and
model context window sizes.
"""
def __init__(self, settings: ContextSettings | None = None) -> None:
"""
Initialize budget allocator.
Args:
settings: Context settings (uses default if None)
"""
self._settings = settings or get_context_settings()
def create_budget(
self,
total_tokens: int,
custom_allocations: dict[str, float] | None = None,
) -> TokenBudget:
"""
Create a token budget with allocations.
Args:
total_tokens: Total available tokens
custom_allocations: Optional custom allocation percentages
Returns:
TokenBudget with allocations set
"""
# Use custom or default allocations
if custom_allocations:
alloc = custom_allocations
else:
alloc = self._settings.get_budget_allocation()
return TokenBudget(
total=total_tokens,
system=int(total_tokens * alloc.get("system", 0.05)),
task=int(total_tokens * alloc.get("task", 0.10)),
knowledge=int(total_tokens * alloc.get("knowledge", 0.40)),
conversation=int(total_tokens * alloc.get("conversation", 0.20)),
tools=int(total_tokens * alloc.get("tools", 0.05)),
response_reserve=int(total_tokens * alloc.get("response", 0.15)),
buffer=int(total_tokens * alloc.get("buffer", 0.05)),
)
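As a worked example, assuming get_context_settings() yields the same ratios as the fallbacks above, a 100,000-token window splits as follows:

allocator = BudgetAllocator()
budget = allocator.create_budget(100_000)
# system=5_000, task=10_000, knowledge=40_000, conversation=20_000,
# tools=5_000, response_reserve=15_000, buffer=5_000 (ratios sum to 1.0)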
def adjust_budget(
self,
budget: TokenBudget,
context_type: ContextType | str,
adjustment: int,
) -> TokenBudget:
"""
Adjust a specific allocation in a budget.
Takes tokens from buffer and adds to specified type.
Args:
budget: Budget to adjust
context_type: Type to adjust
adjustment: Positive to increase, negative to decrease
Returns:
Adjusted budget
"""
if isinstance(context_type, ContextType):
context_type = context_type.value
# Calculate adjustment (limited by buffer)
if adjustment > 0:
# Taking from buffer
actual_adjustment = min(adjustment, budget.buffer)
budget.buffer -= actual_adjustment
else:
# Returning tokens to the buffer
actual_adjustment = adjustment
budget.buffer -= actual_adjustment  # adjustment is negative, so the buffer grows
# Apply to target type
if context_type == "system":
budget.system = max(0, budget.system + actual_adjustment)
elif context_type == "task":
budget.task = max(0, budget.task + actual_adjustment)
elif context_type == "knowledge":
budget.knowledge = max(0, budget.knowledge + actual_adjustment)
elif context_type == "conversation":
budget.conversation = max(0, budget.conversation + actual_adjustment)
elif context_type == "tool":
budget.tools = max(0, budget.tools + actual_adjustment)
return budget
def rebalance_budget(
self,
budget: TokenBudget,
prioritize: list[ContextType] | None = None,
) -> TokenBudget:
"""
Rebalance budget based on actual usage.
Moves unused allocations to prioritized types.
Args:
budget: Budget to rebalance
prioritize: Types to prioritize (in order)
Returns:
Rebalanced budget
"""
if prioritize is None:
prioritize = [ContextType.KNOWLEDGE, ContextType.TASK, ContextType.SYSTEM]
# Calculate unused tokens per type
unused: dict[str, int] = {}
for ct in ContextType:
remaining = budget.remaining(ct)
if remaining > 0:
unused[ct.value] = remaining
# Reclaim unused tokens from non-prioritized types into the buffer
prioritize_values = {ct.value for ct in prioritize}
for ct_name, tokens in unused.items():
if ct_name not in prioritize_values:
self.adjust_budget(budget, ct_name, -tokens)
# Redistribute the reclaimed buffer to prioritized types near capacity
for ct in prioritize:
utilization = budget.utilization(ct)
if utilization > 0.8:  # Near capacity
# Give more tokens from the reclaimed buffer
bonus = min(budget.buffer, budget.get_allocation(ct) // 2)
self.adjust_budget(budget, ct, bonus)
if budget.buffer <= 0:
break
return budget
def get_model_context_size(self, model: str) -> int:
"""
Get context window size for a model.
Args:
model: Model name
Returns:
Context window size in tokens
"""
# Common model context sizes
context_sizes = {
"claude-3-opus": 200000,
"claude-3-sonnet": 200000,
"claude-3-haiku": 200000,
"claude-3-5-sonnet": 200000,
"claude-3-5-haiku": 200000,
"claude-opus-4": 200000,
"gpt-4-turbo": 128000,
"gpt-4": 8192,
"gpt-4-32k": 32768,
"gpt-4o": 128000,
"gpt-4o-mini": 128000,
"gpt-3.5-turbo": 16385,
"gemini-1.5-pro": 2000000,
"gemini-1.5-flash": 1000000,
"gemini-2.0-flash": 1000000,
"qwen-plus": 32000,
"qwen-turbo": 8000,
"deepseek-chat": 64000,
"deepseek-reasoner": 64000,
}
# Check exact match first
model_lower = model.lower()
if model_lower in context_sizes:
return context_sizes[model_lower]
# Check prefix match (longest prefixes first, so "gpt-4o" is not caught by "gpt-4")
for model_name in sorted(context_sizes, key=len, reverse=True):
if model_lower.startswith(model_name):
return context_sizes[model_name]
# Default fallback
return 8192
def create_budget_for_model(
self,
model: str,
custom_allocations: dict[str, float] | None = None,
) -> TokenBudget:
"""
Create a budget based on model's context window.
Args:
model: Model name
custom_allocations: Optional custom allocation percentages
Returns:
TokenBudget sized for the model
"""
context_size = self.get_model_context_size(model)
return self.create_budget(context_size, custom_allocations)
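Putting the allocator together, a hedged end-to-end sketch (model names from the table above; ContextType.KNOWLEDGE as used in rebalance_budget):

allocator = BudgetAllocator()
budget = allocator.create_budget_for_model("gpt-4o")  # 128,000-token window
# Grow the knowledge allocation by up to 2,000 tokens drawn from the buffer
budget = allocator.adjust_budget(budget, ContextType.KNOWLEDGE, 2_000)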


@@ -0,0 +1,284 @@
"""
Token Calculator for Context Management.
Provides token counting with caching and fallback estimation.
Integrates with LLM Gateway for accurate counts.
"""
import hashlib
import json
import logging
from typing import TYPE_CHECKING, Any, Protocol
if TYPE_CHECKING:
from app.services.mcp.client_manager import MCPClientManager
logger = logging.getLogger(__name__)
class TokenCounterProtocol(Protocol):
"""Protocol for token counting implementations."""
async def count_tokens(
self,
text: str,
model: str | None = None,
) -> int:
"""Count tokens in text."""
...
class TokenCalculator:
"""
Token calculator with LLM Gateway integration.
Features:
- In-memory caching for repeated text
- Fallback to character-based estimation
- Model-specific counting when possible
The calculator uses the LLM Gateway's count_tokens tool
for accurate counting, with a local cache to avoid
repeated calls for the same content.
"""
# Default characters per token ratio for estimation
DEFAULT_CHARS_PER_TOKEN = 4.0
# Model-specific ratios (more accurate estimation)
MODEL_CHAR_RATIOS: dict[str, float] = {
"claude": 3.5,
"gpt-4": 4.0,
"gpt-3.5": 4.0,
"gemini": 4.0,
}
def __init__(
self,
mcp_manager: "MCPClientManager | None" = None,
project_id: str = "system",
agent_id: str = "context-engine",
cache_enabled: bool = True,
cache_max_size: int = 10000,
) -> None:
"""
Initialize token calculator.
Args:
mcp_manager: MCP client manager for LLM Gateway calls
project_id: Project ID for LLM Gateway calls
agent_id: Agent ID for LLM Gateway calls
cache_enabled: Whether to enable in-memory caching
cache_max_size: Maximum cache entries
"""
self._mcp = mcp_manager
self._project_id = project_id
self._agent_id = agent_id
self._cache_enabled = cache_enabled
self._cache_max_size = cache_max_size
# In-memory cache: hash(model:text) -> token_count
self._cache: dict[str, int] = {}
self._cache_hits = 0
self._cache_misses = 0
def _get_cache_key(self, text: str, model: str | None) -> str:
"""Generate cache key from text and model."""
# Use hash for efficient storage
content = f"{model or 'default'}:{text}"
return hashlib.sha256(content.encode()).hexdigest()[:32]
def _check_cache(self, cache_key: str) -> int | None:
"""Check cache for existing count."""
if not self._cache_enabled:
return None
if cache_key in self._cache:
self._cache_hits += 1
return self._cache[cache_key]
self._cache_misses += 1
return None
def _store_cache(self, cache_key: str, count: int) -> None:
"""Store count in cache."""
if not self._cache_enabled:
return
# FIFO-style eviction: dicts preserve insertion order, so drop the oldest entries when full
if len(self._cache) >= self._cache_max_size:
# Remove first 10% of entries
entries_to_remove = self._cache_max_size // 10
keys_to_remove = list(self._cache.keys())[:entries_to_remove]
for key in keys_to_remove:
del self._cache[key]
self._cache[cache_key] = count
def estimate_tokens(self, text: str, model: str | None = None) -> int:
"""
Estimate token count based on character count.
This is a fast fallback when LLM Gateway is unavailable.
Args:
text: Text to count
model: Optional model for more accurate ratio
Returns:
Estimated token count
"""
if not text:
return 0
# Get model-specific ratio
ratio = self.DEFAULT_CHARS_PER_TOKEN
if model:
model_lower = model.lower()
for model_prefix, model_ratio in self.MODEL_CHAR_RATIOS.items():
if model_prefix in model_lower:
ratio = model_ratio
break
return max(1, int(len(text) / ratio))
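For instance, with the ratios above, a 1,400-character string estimates to 400 tokens for a Claude model (1400 / 3.5) and 350 for a GPT-4-family model (1400 / 4.0):

calc = TokenCalculator()
calc.estimate_tokens("x" * 1400, model="claude-3-5-sonnet")  # 400
calc.estimate_tokens("x" * 1400, model="gpt-4o")             # 350 ("gpt-4" substring match)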
async def count_tokens(
self,
text: str,
model: str | None = None,
) -> int:
"""
Count tokens in text.
Uses LLM Gateway for accurate counts with fallback to estimation.
Args:
text: Text to count
model: Optional model for accurate counting
Returns:
Token count
"""
if not text:
return 0
# Check cache first
cache_key = self._get_cache_key(text, model)
cached = self._check_cache(cache_key)
if cached is not None:
return cached
# Try LLM Gateway
if self._mcp is not None:
try:
result = await self._mcp.call_tool(
server="llm-gateway",
tool="count_tokens",
args={
"project_id": self._project_id,
"agent_id": self._agent_id,
"text": text,
"model": model,
},
)
# Parse result
if result.success and result.data:
count = self._parse_token_count(result.data)
if count is not None:
self._store_cache(cache_key, count)
return count
except Exception as e:
logger.warning(f"LLM Gateway token count failed, using estimation: {e}")
# Fallback to estimation
count = self.estimate_tokens(text, model)
self._store_cache(cache_key, count)
return count
def _parse_token_count(self, data: Any) -> int | None:
"""Parse token count from LLM Gateway response."""
if isinstance(data, dict):
if "token_count" in data:
return int(data["token_count"])
if "tokens" in data:
return int(data["tokens"])
if "count" in data:
return int(data["count"])
if isinstance(data, int):
return data
if isinstance(data, str):
# Try to parse from text content
try:
# Handle {"token_count": 123} or just "123"
parsed = json.loads(data)
if isinstance(parsed, dict) and "token_count" in parsed:
return int(parsed["token_count"])
if isinstance(parsed, int):
return parsed
except (json.JSONDecodeError, ValueError):
# Try direct int conversion
try:
return int(data)
except ValueError:
pass
return None
async def count_tokens_batch(
self,
texts: list[str],
model: str | None = None,
) -> list[int]:
"""
Count tokens for multiple texts.
Counts texts one by one, reusing the cache for repeated content.
Args:
texts: List of texts to count
model: Optional model for accurate counting
Returns:
List of token counts (same order as input)
"""
results: list[int] = []
for text in texts:
count = await self.count_tokens(text, model)
results.append(count)
return results
def clear_cache(self) -> None:
"""Clear the token count cache."""
self._cache.clear()
self._cache_hits = 0
self._cache_misses = 0
def get_cache_stats(self) -> dict[str, Any]:
"""Get cache statistics."""
total = self._cache_hits + self._cache_misses
hit_rate = self._cache_hits / total if total > 0 else 0.0
return {
"enabled": self._cache_enabled,
"size": len(self._cache),
"max_size": self._cache_max_size,
"hits": self._cache_hits,
"misses": self._cache_misses,
"hit_rate": round(hit_rate, 3),
}
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
"""
Set the MCP manager (for lazy initialization).
Args:
mcp_manager: MCP client manager instance
"""
self._mcp = mcp_manager
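
A usage sketch (runs inside async code; no MCP manager wired in yet, so counts fall back to estimation):

calc = TokenCalculator()
await calc.count_tokens("Hello, world!", model="claude-3-5-sonnet")  # miss -> estimated
await calc.count_tokens("Hello, world!", model="claude-3-5-sonnet")  # cache hit
calc.get_cache_stats()  # hits=1, misses=1, hit_rate=0.5
# calc.set_mcp_manager(mcp_manager)  # later: accurate counts via llm-gateway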