forked from cardosofelipe/fast-next-template
feat(context): implement token budget management (Phase 2)
Add TokenCalculator with LLM Gateway integration for accurate token counting with in-memory caching and fallback character-based estimation. Implement TokenBudget for tracking allocations per context type with budget enforcement, and BudgetAllocator for creating budgets based on model context window sizes.

- TokenCalculator: MCP integration, caching, model-specific ratios
- TokenBudget: allocation tracking, can_fit/allocate/deallocate/reset
- BudgetAllocator: model context sizes, budget creation and adjustment
- 35 comprehensive tests covering all budget functionality

Part of #61 - Context Management Engine

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
backend/app/services/context/budget/calculator.py (new file, 284 lines added)
@@ -0,0 +1,284 @@
"""
Token Calculator for Context Management.

Provides token counting with caching and fallback estimation.
Integrates with LLM Gateway for accurate counts.
"""

import hashlib
import logging
from typing import TYPE_CHECKING, Any, Protocol

if TYPE_CHECKING:
    from app.services.mcp.client_manager import MCPClientManager

logger = logging.getLogger(__name__)


class TokenCounterProtocol(Protocol):
    """Protocol for token counting implementations."""

    async def count_tokens(
        self,
        text: str,
        model: str | None = None,
    ) -> int:
        """Count tokens in text."""
        ...


class TokenCalculator:
    """
    Token calculator with LLM Gateway integration.

    Features:
    - In-memory caching for repeated text
    - Fallback to character-based estimation
    - Model-specific counting when possible

    The calculator uses the LLM Gateway's count_tokens tool
    for accurate counting, with a local cache to avoid
    repeated calls for the same content.
    """

    # Default characters per token ratio for estimation
    DEFAULT_CHARS_PER_TOKEN = 4.0

    # Model-specific ratios (more accurate estimation)
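    # Keys are matched by substring in estimate_tokens(), so e.g. a model id
    # like "claude-3-5-sonnet" picks up the "claude" ratio.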
    MODEL_CHAR_RATIOS: dict[str, float] = {
        "claude": 3.5,
        "gpt-4": 4.0,
        "gpt-3.5": 4.0,
        "gemini": 4.0,
    }

    def __init__(
        self,
        mcp_manager: "MCPClientManager | None" = None,
        project_id: str = "system",
        agent_id: str = "context-engine",
        cache_enabled: bool = True,
        cache_max_size: int = 10000,
    ) -> None:
        """
        Initialize token calculator.

        Args:
            mcp_manager: MCP client manager for LLM Gateway calls
            project_id: Project ID for LLM Gateway calls
            agent_id: Agent ID for LLM Gateway calls
            cache_enabled: Whether to enable in-memory caching
            cache_max_size: Maximum cache entries
        """
        self._mcp = mcp_manager
        self._project_id = project_id
        self._agent_id = agent_id
        self._cache_enabled = cache_enabled
        self._cache_max_size = cache_max_size

        # In-memory cache: hash(model:text) -> token_count
        self._cache: dict[str, int] = {}
        self._cache_hits = 0
        self._cache_misses = 0

    def _get_cache_key(self, text: str, model: str | None) -> str:
        """Generate cache key from text and model."""
        # Use hash for efficient storage
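        # Truncating the SHA-256 hex digest to 32 chars (128 bits) keeps keys
        # compact while collisions remain negligible at this cache size.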
        content = f"{model or 'default'}:{text}"
        return hashlib.sha256(content.encode()).hexdigest()[:32]

    def _check_cache(self, cache_key: str) -> int | None:
        """Check cache for existing count."""
        if not self._cache_enabled:
            return None

        if cache_key in self._cache:
            self._cache_hits += 1
            return self._cache[cache_key]

        self._cache_misses += 1
        return None

    def _store_cache(self, cache_key: str, count: int) -> None:
        """Store count in cache."""
        if not self._cache_enabled:
            return

        # Simple LRU-like eviction: remove oldest entries when full
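        # Dicts preserve insertion order, so the "first" keys are the oldest
        # inserts; hits do not refresh entries, which makes this closer to
        # FIFO eviction than true LRU.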
        if len(self._cache) >= self._cache_max_size:
            # Remove first 10% of entries
            entries_to_remove = self._cache_max_size // 10
            keys_to_remove = list(self._cache.keys())[:entries_to_remove]
            for key in keys_to_remove:
                del self._cache[key]

        self._cache[cache_key] = count

    def estimate_tokens(self, text: str, model: str | None = None) -> int:
        """
        Estimate token count based on character count.

        This is a fast fallback when LLM Gateway is unavailable.

        Args:
            text: Text to count
            model: Optional model for more accurate ratio

        Returns:
            Estimated token count
        """
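        # Example: a 700-character prompt for a Claude-family model uses the
        # 3.5 chars/token ratio, giving int(700 / 3.5) = 200 estimated tokens.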
        if not text:
            return 0

        # Get model-specific ratio
        ratio = self.DEFAULT_CHARS_PER_TOKEN
        if model:
            model_lower = model.lower()
            for model_prefix, model_ratio in self.MODEL_CHAR_RATIOS.items():
                if model_prefix in model_lower:
                    ratio = model_ratio
                    break

        return max(1, int(len(text) / ratio))

    async def count_tokens(
        self,
        text: str,
        model: str | None = None,
    ) -> int:
        """
        Count tokens in text.

        Uses LLM Gateway for accurate counts with fallback to estimation.

        Args:
            text: Text to count
            model: Optional model for accurate counting

        Returns:
            Token count
        """
        if not text:
            return 0

        # Check cache first
        cache_key = self._get_cache_key(text, model)
        cached = self._check_cache(cache_key)
        if cached is not None:
            return cached

        # Try LLM Gateway
        if self._mcp is not None:
            try:
                result = await self._mcp.call_tool(
                    server="llm-gateway",
                    tool="count_tokens",
                    args={
                        "project_id": self._project_id,
                        "agent_id": self._agent_id,
                        "text": text,
                        "model": model,
                    },
                )

                # Parse result
                if result.success and result.data:
                    count = self._parse_token_count(result.data)
                    if count is not None:
                        self._store_cache(cache_key, count)
                        return count

            except Exception as e:
                logger.warning(f"LLM Gateway token count failed, using estimation: {e}")

        # Fallback to estimation
        count = self.estimate_tokens(text, model)
        self._store_cache(cache_key, count)
        return count

    def _parse_token_count(self, data: Any) -> int | None:
        """Parse token count from LLM Gateway response."""
        if isinstance(data, dict):
            if "token_count" in data:
                return int(data["token_count"])
            if "tokens" in data:
                return int(data["tokens"])
            if "count" in data:
                return int(data["count"])

        if isinstance(data, int):
            return data

        if isinstance(data, str):
            # Try to parse from text content
            try:
                # Handle {"token_count": 123} or just "123"
                import json

                parsed = json.loads(data)
                if isinstance(parsed, dict) and "token_count" in parsed:
                    return int(parsed["token_count"])
                if isinstance(parsed, int):
                    return parsed
            except (json.JSONDecodeError, ValueError):
                # Try direct int conversion
                try:
                    return int(data)
                except ValueError:
                    pass

        return None

    async def count_tokens_batch(
        self,
        texts: list[str],
        model: str | None = None,
    ) -> list[int]:
        """
        Count tokens for multiple texts.

        Efficient batch counting with caching.

        Args:
            texts: List of texts to count
            model: Optional model for accurate counting

        Returns:
            List of token counts (same order as input)
        """
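        # Counts are awaited one at a time; repeated texts are still served
        # from the in-memory cache. Callers needing concurrency could gather
        # the count_tokens() coroutines instead.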
        results: list[int] = []

        for text in texts:
            count = await self.count_tokens(text, model)
            results.append(count)

        return results

    def clear_cache(self) -> None:
        """Clear the token count cache."""
        self._cache.clear()
        self._cache_hits = 0
        self._cache_misses = 0

    def get_cache_stats(self) -> dict[str, Any]:
        """Get cache statistics."""
        total = self._cache_hits + self._cache_misses
        hit_rate = self._cache_hits / total if total > 0 else 0.0

        return {
            "enabled": self._cache_enabled,
            "size": len(self._cache),
            "max_size": self._cache_max_size,
            "hits": self._cache_hits,
            "misses": self._cache_misses,
            "hit_rate": round(hit_rate, 3),
        }

    def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
        """
        Set the MCP manager (for lazy initialization).

        Args:
            mcp_manager: MCP client manager instance
        """
        self._mcp = mcp_manager
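For illustration, a minimal usage sketch of the TokenCalculator added above. The import path is assumed from the file location, no MCP manager is wired up (so counts fall back to character-based estimation), and the asyncio scaffolding is only there to make the snippet runnable on its own:

    import asyncio

    from app.services.context.budget.calculator import TokenCalculator


    async def main() -> None:
        # Without an MCP manager the calculator never reaches the LLM Gateway
        # and estimates from character counts instead.
        calc = TokenCalculator(mcp_manager=None)

        prompt = "Summarize the latest build logs for the deploy pipeline."
        count = await calc.count_tokens(prompt, model="claude")
        print(f"~{count} tokens")

        # A repeated call for the same text is served from the in-memory cache.
        await calc.count_tokens(prompt, model="claude")
        print(calc.get_cache_stats())


    asyncio.run(main())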