- Cleaned up unnecessary comments in `__all__` definitions for better readability.
- Adjusted indentation and formatting across modules for improved clarity (e.g., long lines, logical grouping).
- Simplified conditional expressions and inline comments for context scoring and ranking.
- Replaced some hard-coded values with type-safe annotations (e.g., `ClassVar`).
- Removed unused imports and ensured consistent usage across test files.
- Updated `test_score_not_cached_on_context` to clarify caching behavior.
- Improved truncation strategy logic and marker handling.
"""
|
|
Token Calculator for Context Management.
|
|
|
|
Provides token counting with caching and fallback estimation.
|
|
Integrates with LLM Gateway for accurate counts.
|
|
"""
|
|
|
|
import hashlib
|
|
import logging
|
|
from typing import TYPE_CHECKING, Any, ClassVar, Protocol
|
|
|
|
if TYPE_CHECKING:
|
|
from app.services.mcp.client_manager import MCPClientManager
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TokenCounterProtocol(Protocol):
|
|
"""Protocol for token counting implementations."""
|
|
|
|
async def count_tokens(
|
|
self,
|
|
text: str,
|
|
model: str | None = None,
|
|
) -> int:
|
|
"""Count tokens in text."""
|
|
...
|
|
|
|
|
|
class TokenCalculator:
|
|
"""
|
|
Token calculator with LLM Gateway integration.
|
|
|
|
Features:
|
|
- In-memory caching for repeated text
|
|
- Fallback to character-based estimation
|
|
- Model-specific counting when possible
|
|
|
|
The calculator uses the LLM Gateway's count_tokens tool
|
|
for accurate counting, with a local cache to avoid
|
|
repeated calls for the same content.
|
|
"""
|
|
|
|
# Default characters per token ratio for estimation
|
|
DEFAULT_CHARS_PER_TOKEN: ClassVar[float] = 4.0
|
|
|
|
# Model-specific ratios (more accurate estimation)
|
|
MODEL_CHAR_RATIOS: ClassVar[dict[str, float]] = {
|
|
"claude": 3.5,
|
|
"gpt-4": 4.0,
|
|
"gpt-3.5": 4.0,
|
|
"gemini": 4.0,
|
|
}
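
    # Keys above are matched as case-insensitive substrings of the model name
    # in estimate_tokens(), so, for example, a model string like
    # "claude-3-opus" (illustrative name) would pick up the "claude" ratio.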

    def __init__(
        self,
        mcp_manager: "MCPClientManager | None" = None,
        project_id: str = "system",
        agent_id: str = "context-engine",
        cache_enabled: bool = True,
        cache_max_size: int = 10000,
    ) -> None:
        """
        Initialize token calculator.

        Args:
            mcp_manager: MCP client manager for LLM Gateway calls
            project_id: Project ID for LLM Gateway calls
            agent_id: Agent ID for LLM Gateway calls
            cache_enabled: Whether to enable in-memory caching
            cache_max_size: Maximum cache entries
        """
        self._mcp = mcp_manager
        self._project_id = project_id
        self._agent_id = agent_id
        self._cache_enabled = cache_enabled
        self._cache_max_size = cache_max_size

        # In-memory cache: hash(model:text) -> token_count
        self._cache: dict[str, int] = {}
        self._cache_hits = 0
        self._cache_misses = 0

    def _get_cache_key(self, text: str, model: str | None) -> str:
        """Generate cache key from text and model."""
        # Use hash for efficient storage
        content = f"{model or 'default'}:{text}"
        return hashlib.sha256(content.encode()).hexdigest()[:32]

    def _check_cache(self, cache_key: str) -> int | None:
        """Check cache for existing count."""
        if not self._cache_enabled:
            return None

        if cache_key in self._cache:
            self._cache_hits += 1
            return self._cache[cache_key]

        self._cache_misses += 1
        return None

    def _store_cache(self, cache_key: str, count: int) -> None:
        """Store count in cache."""
        if not self._cache_enabled:
            return

        # Simple FIFO-style eviction: dicts preserve insertion order, so the
        # first keys are the oldest entries; drop a batch of them when full.
        if len(self._cache) >= self._cache_max_size:
            # Remove the oldest 10% of entries
            entries_to_remove = self._cache_max_size // 10
            keys_to_remove = list(self._cache.keys())[:entries_to_remove]
            for key in keys_to_remove:
                del self._cache[key]

        self._cache[cache_key] = count

    def estimate_tokens(self, text: str, model: str | None = None) -> int:
        """
        Estimate token count based on character count.

        This is a fast fallback when LLM Gateway is unavailable.

        Args:
            text: Text to count
            model: Optional model for more accurate ratio

        Returns:
            Estimated token count
        """
        if not text:
            return 0

        # Get model-specific ratio
        ratio = self.DEFAULT_CHARS_PER_TOKEN
        if model:
            model_lower = model.lower()
            for model_prefix, model_ratio in self.MODEL_CHAR_RATIOS.items():
                if model_prefix in model_lower:
                    ratio = model_ratio
                    break

        return max(1, int(len(text) / ratio))
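
    # Example: a 100-character string with model="claude" uses the 3.5 ratio,
    # giving max(1, int(100 / 3.5)) == 28 estimated tokens; with no matching
    # model it falls back to int(100 / 4.0) == 25.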

    async def count_tokens(
        self,
        text: str,
        model: str | None = None,
    ) -> int:
        """
        Count tokens in text.

        Uses LLM Gateway for accurate counts with fallback to estimation.

        Args:
            text: Text to count
            model: Optional model for accurate counting

        Returns:
            Token count
        """
        if not text:
            return 0

        # Check cache first
        cache_key = self._get_cache_key(text, model)
        cached = self._check_cache(cache_key)
        if cached is not None:
            return cached

        # Try LLM Gateway
        if self._mcp is not None:
            try:
                result = await self._mcp.call_tool(
                    server="llm-gateway",
                    tool="count_tokens",
                    args={
                        "project_id": self._project_id,
                        "agent_id": self._agent_id,
                        "text": text,
                        "model": model,
                    },
                )

                # Parse result
                if result.success and result.data:
                    count = self._parse_token_count(result.data)
                    if count is not None:
                        self._store_cache(cache_key, count)
                        return count

            except Exception as e:
                logger.warning(f"LLM Gateway token count failed, using estimation: {e}")

        # Fallback to estimation
        count = self.estimate_tokens(text, model)
        self._store_cache(cache_key, count)
        return count
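
    # Note: estimated fallback counts are cached under the same key as gateway
    # counts, so they are reused until clear_cache() is called, even if the
    # gateway becomes available again later.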

    def _parse_token_count(self, data: Any) -> int | None:
        """Parse token count from LLM Gateway response."""
        if isinstance(data, dict):
            if "token_count" in data:
                return int(data["token_count"])
            if "tokens" in data:
                return int(data["tokens"])
            if "count" in data:
                return int(data["count"])

        if isinstance(data, int):
            return data

        if isinstance(data, str):
            # Try to parse from text content
            try:
                # Handle {"token_count": 123} or just "123"
                import json

                parsed = json.loads(data)
                if isinstance(parsed, dict) and "token_count" in parsed:
                    return int(parsed["token_count"])
                if isinstance(parsed, int):
                    return parsed
            except (json.JSONDecodeError, ValueError):
                # Try direct int conversion
                try:
                    return int(data)
                except ValueError:
                    pass

        return None

    async def count_tokens_batch(
        self,
        texts: list[str],
        model: str | None = None,
    ) -> list[int]:
        """
        Count tokens for multiple texts.

        Efficient batch counting with caching and parallel execution.

        Args:
            texts: List of texts to count
            model: Optional model for accurate counting

        Returns:
            List of token counts (same order as input)
        """
        import asyncio

        if not texts:
            return []

        # Execute all token counts in parallel for better performance
        tasks = [self.count_tokens(text, model) for text in texts]
        return await asyncio.gather(*tasks)

    def clear_cache(self) -> None:
        """Clear the token count cache."""
        self._cache.clear()
        self._cache_hits = 0
        self._cache_misses = 0

    def get_cache_stats(self) -> dict[str, Any]:
        """Get cache statistics."""
        total = self._cache_hits + self._cache_misses
        hit_rate = self._cache_hits / total if total > 0 else 0.0

        return {
            "enabled": self._cache_enabled,
            "size": len(self._cache),
            "max_size": self._cache_max_size,
            "hits": self._cache_hits,
            "misses": self._cache_misses,
            "hit_rate": round(hit_rate, 3),
        }

    def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
        """
        Set the MCP manager (for lazy initialization).

        Args:
            mcp_manager: MCP client manager instance
        """
        self._mcp = mcp_manager
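
A minimal usage sketch (illustrative only; it assumes the module is importable as `token_calculator`, and the import path and sample strings are placeholders). Without an MCP manager configured, the calculator falls back to character-based estimation:

import asyncio

from token_calculator import TokenCalculator  # hypothetical import path

async def main() -> None:
    # No MCP manager is passed, so counts come from estimate_tokens()
    calc = TokenCalculator()
    single = await calc.count_tokens("Hello, context engine!", model="claude")
    batch = await calc.count_tokens_batch(["first text", "second text"])
    print(single, batch, calc.get_cache_stats())

asyncio.run(main())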