Files
fast-next-template/backend/app/services/context/budget/calculator.py
Felipe Cardoso 2bea057fb1 chore(context): refactor for consistency, optimize formatting, and simplify logic
- Cleaned up unnecessary comments in `__all__` definitions for better readability.
- Adjusted indentation and formatting across modules for improved clarity (e.g., long lines, logical grouping).
- Simplified conditional expressions and inline comments for context scoring and ranking.
- Replaced some hard-coded values with type-safe annotations (e.g., `ClassVar`).
- Removed unused imports and ensured consistent usage across test files.
- Updated `test_score_not_cached_on_context` to clarify caching behavior.
- Improved truncation strategy logic and marker handling.
2026-01-04 15:23:14 +01:00

286 lines
8.4 KiB
Python

"""
Token Calculator for Context Management.
Provides token counting with caching and fallback estimation.
Integrates with LLM Gateway for accurate counts.
"""
import hashlib
import logging
from typing import TYPE_CHECKING, Any, ClassVar, Protocol
if TYPE_CHECKING:
from app.services.mcp.client_manager import MCPClientManager
logger = logging.getLogger(__name__)
class TokenCounterProtocol(Protocol):
"""Protocol for token counting implementations."""
async def count_tokens(
self,
text: str,
model: str | None = None,
) -> int:
"""Count tokens in text."""
...
class TokenCalculator:
"""
Token calculator with LLM Gateway integration.
Features:
- In-memory caching for repeated text
- Fallback to character-based estimation
- Model-specific counting when possible
The calculator uses the LLM Gateway's count_tokens tool
for accurate counting, with a local cache to avoid
repeated calls for the same content.
"""
# Default characters per token ratio for estimation
DEFAULT_CHARS_PER_TOKEN: ClassVar[float] = 4.0
# Model-specific ratios (more accurate estimation)
MODEL_CHAR_RATIOS: ClassVar[dict[str, float]] = {
"claude": 3.5,
"gpt-4": 4.0,
"gpt-3.5": 4.0,
"gemini": 4.0,
}
def __init__(
self,
mcp_manager: "MCPClientManager | None" = None,
project_id: str = "system",
agent_id: str = "context-engine",
cache_enabled: bool = True,
cache_max_size: int = 10000,
) -> None:
"""
Initialize token calculator.
Args:
mcp_manager: MCP client manager for LLM Gateway calls
project_id: Project ID for LLM Gateway calls
agent_id: Agent ID for LLM Gateway calls
cache_enabled: Whether to enable in-memory caching
cache_max_size: Maximum cache entries
"""
self._mcp = mcp_manager
self._project_id = project_id
self._agent_id = agent_id
self._cache_enabled = cache_enabled
self._cache_max_size = cache_max_size
# In-memory cache: hash(model:text) -> token_count
self._cache: dict[str, int] = {}
self._cache_hits = 0
self._cache_misses = 0
def _get_cache_key(self, text: str, model: str | None) -> str:
"""Generate cache key from text and model."""
# Use hash for efficient storage
content = f"{model or 'default'}:{text}"
return hashlib.sha256(content.encode()).hexdigest()[:32]
def _check_cache(self, cache_key: str) -> int | None:
"""Check cache for existing count."""
if not self._cache_enabled:
return None
if cache_key in self._cache:
self._cache_hits += 1
return self._cache[cache_key]
self._cache_misses += 1
return None
def _store_cache(self, cache_key: str, count: int) -> None:
"""Store count in cache."""
if not self._cache_enabled:
return
# Simple LRU-like eviction: remove oldest entries when full
if len(self._cache) >= self._cache_max_size:
# Remove first 10% of entries
entries_to_remove = self._cache_max_size // 10
keys_to_remove = list(self._cache.keys())[:entries_to_remove]
for key in keys_to_remove:
del self._cache[key]
self._cache[cache_key] = count
def estimate_tokens(self, text: str, model: str | None = None) -> int:
"""
Estimate token count based on character count.
This is a fast fallback when LLM Gateway is unavailable.
Args:
text: Text to count
model: Optional model for more accurate ratio
Returns:
Estimated token count
"""
if not text:
return 0
# Get model-specific ratio
ratio = self.DEFAULT_CHARS_PER_TOKEN
if model:
model_lower = model.lower()
for model_prefix, model_ratio in self.MODEL_CHAR_RATIOS.items():
if model_prefix in model_lower:
ratio = model_ratio
break
return max(1, int(len(text) / ratio))
async def count_tokens(
self,
text: str,
model: str | None = None,
) -> int:
"""
Count tokens in text.
Uses LLM Gateway for accurate counts with fallback to estimation.
Args:
text: Text to count
model: Optional model for accurate counting
Returns:
Token count
"""
if not text:
return 0
# Check cache first
cache_key = self._get_cache_key(text, model)
cached = self._check_cache(cache_key)
if cached is not None:
return cached
# Try LLM Gateway
if self._mcp is not None:
try:
result = await self._mcp.call_tool(
server="llm-gateway",
tool="count_tokens",
args={
"project_id": self._project_id,
"agent_id": self._agent_id,
"text": text,
"model": model,
},
)
# Parse result
if result.success and result.data:
count = self._parse_token_count(result.data)
if count is not None:
self._store_cache(cache_key, count)
return count
except Exception as e:
logger.warning(f"LLM Gateway token count failed, using estimation: {e}")
# Fallback to estimation
count = self.estimate_tokens(text, model)
self._store_cache(cache_key, count)
return count
def _parse_token_count(self, data: Any) -> int | None:
"""Parse token count from LLM Gateway response."""
if isinstance(data, dict):
if "token_count" in data:
return int(data["token_count"])
if "tokens" in data:
return int(data["tokens"])
if "count" in data:
return int(data["count"])
if isinstance(data, int):
return data
if isinstance(data, str):
# Try to parse from text content
try:
# Handle {"token_count": 123} or just "123"
import json
parsed = json.loads(data)
if isinstance(parsed, dict) and "token_count" in parsed:
return int(parsed["token_count"])
if isinstance(parsed, int):
return parsed
except (json.JSONDecodeError, ValueError):
# Try direct int conversion
try:
return int(data)
except ValueError:
pass
return None
async def count_tokens_batch(
self,
texts: list[str],
model: str | None = None,
) -> list[int]:
"""
Count tokens for multiple texts.
Efficient batch counting with caching and parallel execution.
Args:
texts: List of texts to count
model: Optional model for accurate counting
Returns:
List of token counts (same order as input)
"""
import asyncio
if not texts:
return []
# Execute all token counts in parallel for better performance
tasks = [self.count_tokens(text, model) for text in texts]
return await asyncio.gather(*tasks)
def clear_cache(self) -> None:
"""Clear the token count cache."""
self._cache.clear()
self._cache_hits = 0
self._cache_misses = 0
def get_cache_stats(self) -> dict[str, Any]:
"""Get cache statistics."""
total = self._cache_hits + self._cache_misses
hit_rate = self._cache_hits / total if total > 0 else 0.0
return {
"enabled": self._cache_enabled,
"size": len(self._cache),
"max_size": self._cache_max_size,
"hits": self._cache_hits,
"misses": self._cache_misses,
"hit_rate": round(hit_rate, 3),
}
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
"""
Set the MCP manager (for lazy initialization).
Args:
mcp_manager: MCP client manager instance
"""
self._mcp = mcp_manager