""" Token Calculator for Context Management. Provides token counting with caching and fallback estimation. Integrates with LLM Gateway for accurate counts. """ import hashlib import logging from typing import TYPE_CHECKING, Any, ClassVar, Protocol if TYPE_CHECKING: from app.services.mcp.client_manager import MCPClientManager logger = logging.getLogger(__name__) class TokenCounterProtocol(Protocol): """Protocol for token counting implementations.""" async def count_tokens( self, text: str, model: str | None = None, ) -> int: """Count tokens in text.""" ... class TokenCalculator: """ Token calculator with LLM Gateway integration. Features: - In-memory caching for repeated text - Fallback to character-based estimation - Model-specific counting when possible The calculator uses the LLM Gateway's count_tokens tool for accurate counting, with a local cache to avoid repeated calls for the same content. """ # Default characters per token ratio for estimation DEFAULT_CHARS_PER_TOKEN: ClassVar[float] = 4.0 # Model-specific ratios (more accurate estimation) MODEL_CHAR_RATIOS: ClassVar[dict[str, float]] = { "claude": 3.5, "gpt-4": 4.0, "gpt-3.5": 4.0, "gemini": 4.0, } def __init__( self, mcp_manager: "MCPClientManager | None" = None, project_id: str = "system", agent_id: str = "context-engine", cache_enabled: bool = True, cache_max_size: int = 10000, ) -> None: """ Initialize token calculator. Args: mcp_manager: MCP client manager for LLM Gateway calls project_id: Project ID for LLM Gateway calls agent_id: Agent ID for LLM Gateway calls cache_enabled: Whether to enable in-memory caching cache_max_size: Maximum cache entries """ self._mcp = mcp_manager self._project_id = project_id self._agent_id = agent_id self._cache_enabled = cache_enabled self._cache_max_size = cache_max_size # In-memory cache: hash(model:text) -> token_count self._cache: dict[str, int] = {} self._cache_hits = 0 self._cache_misses = 0 def _get_cache_key(self, text: str, model: str | None) -> str: """Generate cache key from text and model.""" # Use hash for efficient storage content = f"{model or 'default'}:{text}" return hashlib.sha256(content.encode()).hexdigest()[:32] def _check_cache(self, cache_key: str) -> int | None: """Check cache for existing count.""" if not self._cache_enabled: return None if cache_key in self._cache: self._cache_hits += 1 return self._cache[cache_key] self._cache_misses += 1 return None def _store_cache(self, cache_key: str, count: int) -> None: """Store count in cache.""" if not self._cache_enabled: return # Simple LRU-like eviction: remove oldest entries when full if len(self._cache) >= self._cache_max_size: # Remove first 10% of entries entries_to_remove = self._cache_max_size // 10 keys_to_remove = list(self._cache.keys())[:entries_to_remove] for key in keys_to_remove: del self._cache[key] self._cache[cache_key] = count def estimate_tokens(self, text: str, model: str | None = None) -> int: """ Estimate token count based on character count. This is a fast fallback when LLM Gateway is unavailable. Args: text: Text to count model: Optional model for more accurate ratio Returns: Estimated token count """ if not text: return 0 # Get model-specific ratio ratio = self.DEFAULT_CHARS_PER_TOKEN if model: model_lower = model.lower() for model_prefix, model_ratio in self.MODEL_CHAR_RATIOS.items(): if model_prefix in model_lower: ratio = model_ratio break return max(1, int(len(text) / ratio)) async def count_tokens( self, text: str, model: str | None = None, ) -> int: """ Count tokens in text. 

        Uses the LLM Gateway for accurate counts, with fallback to estimation.

        Args:
            text: Text to count
            model: Optional model for accurate counting

        Returns:
            Token count
        """
        if not text:
            return 0

        # Check the cache first
        cache_key = self._get_cache_key(text, model)
        cached = self._check_cache(cache_key)
        if cached is not None:
            return cached

        # Try the LLM Gateway
        if self._mcp is not None:
            try:
                result = await self._mcp.call_tool(
                    server="llm-gateway",
                    tool="count_tokens",
                    args={
                        "project_id": self._project_id,
                        "agent_id": self._agent_id,
                        "text": text,
                        "model": model,
                    },
                )

                # Parse the result
                if result.success and result.data:
                    count = self._parse_token_count(result.data)
                    if count is not None:
                        self._store_cache(cache_key, count)
                        return count
            except Exception as e:
                logger.warning(f"LLM Gateway token count failed, using estimation: {e}")

        # Fall back to estimation
        count = self.estimate_tokens(text, model)
        self._store_cache(cache_key, count)
        return count

    def _parse_token_count(self, data: Any) -> int | None:
        """Parse a token count from an LLM Gateway response."""
        if isinstance(data, dict):
            if "token_count" in data:
                return int(data["token_count"])
            if "tokens" in data:
                return int(data["tokens"])
            if "count" in data:
                return int(data["count"])

        if isinstance(data, int):
            return data

        if isinstance(data, str):
            # Try to parse from text content:
            # handles {"token_count": 123} or just "123"
            try:
                parsed = json.loads(data)
                if isinstance(parsed, dict) and "token_count" in parsed:
                    return int(parsed["token_count"])
                if isinstance(parsed, int):
                    return parsed
            except (json.JSONDecodeError, ValueError):
                # Fall back to a direct int conversion
                try:
                    return int(data)
                except ValueError:
                    pass

        return None

    async def count_tokens_batch(
        self,
        texts: list[str],
        model: str | None = None,
    ) -> list[int]:
        """
        Count tokens for multiple texts.

        Efficient batch counting with caching and concurrent execution.

        Args:
            texts: List of texts to count
            model: Optional model for accurate counting

        Returns:
            List of token counts (same order as input)
        """
        if not texts:
            return []

        # Run all token counts concurrently for better throughput
        tasks = [self.count_tokens(text, model) for text in texts]
        return await asyncio.gather(*tasks)

    def clear_cache(self) -> None:
        """Clear the token count cache."""
        self._cache.clear()
        self._cache_hits = 0
        self._cache_misses = 0

    def get_cache_stats(self) -> dict[str, Any]:
        """Get cache statistics."""
        total = self._cache_hits + self._cache_misses
        hit_rate = self._cache_hits / total if total > 0 else 0.0

        return {
            "enabled": self._cache_enabled,
            "size": len(self._cache),
            "max_size": self._cache_max_size,
            "hits": self._cache_hits,
            "misses": self._cache_misses,
            "hit_rate": round(hit_rate, 3),
        }

    def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
        """
        Set the MCP manager (for lazy initialization).

        Args:
            mcp_manager: MCP client manager instance
        """
        self._mcp = mcp_manager
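

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the module's API):
    # exercises the estimation and caching path by constructing the calculator
    # without an MCP manager, so count_tokens() falls back to character-based
    # estimation. The sample text and the "claude" model name are arbitrary
    # example values chosen for this demo, not values the module requires.
    async def _demo() -> None:
        calc = TokenCalculator()  # no mcp_manager -> estimation fallback
        text = "Hello, context engine!"
        first = await calc.count_tokens(text, model="claude")
        second = await calc.count_tokens(text, model="claude")  # cache hit
        print(f"count={first} (repeat={second})")
        print(calc.get_cache_stats())

    asyncio.run(_demo())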