forked from cardosofelipe/fast-next-template
feat(context): implement token budget management (Phase 2)
Add TokenCalculator with LLM Gateway integration for accurate token counting with in-memory caching and fallback character-based estimation. Implement TokenBudget for tracking allocations per context type with budget enforcement, and BudgetAllocator for creating budgets based on model context window sizes. - TokenCalculator: MCP integration, caching, model-specific ratios - TokenBudget: allocation tracking, can_fit/allocate/deallocate/reset - BudgetAllocator: model context sizes, budget creation and adjustment - 35 comprehensive tests covering all budget functionality Part of #61 - Context Management Engine 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -14,11 +14,18 @@ Usage:
|
|||||||
ConversationContext,
|
ConversationContext,
|
||||||
TaskContext,
|
TaskContext,
|
||||||
ToolContext,
|
ToolContext,
|
||||||
|
TokenBudget,
|
||||||
|
BudgetAllocator,
|
||||||
|
TokenCalculator,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get settings
|
# Get settings
|
||||||
settings = get_context_settings()
|
settings = get_context_settings()
|
||||||
|
|
||||||
|
# Create budget for a model
|
||||||
|
allocator = BudgetAllocator(settings)
|
||||||
|
budget = allocator.create_budget_for_model("claude-3-sonnet")
|
||||||
|
|
||||||
# Create context instances
|
# Create context instances
|
||||||
system_ctx = SystemContext.create_persona(
|
system_ctx = SystemContext.create_persona(
|
||||||
name="Code Assistant",
|
name="Code Assistant",
|
||||||
@@ -27,6 +34,13 @@ Usage:
|
|||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Budget Management
|
||||||
|
from .budget import (
|
||||||
|
BudgetAllocator,
|
||||||
|
TokenBudget,
|
||||||
|
TokenCalculator,
|
||||||
|
)
|
||||||
|
|
||||||
# Configuration
|
# Configuration
|
||||||
from .config import (
|
from .config import (
|
||||||
ContextSettings,
|
ContextSettings,
|
||||||
@@ -67,6 +81,10 @@ from .types import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
# Budget Management
|
||||||
|
"BudgetAllocator",
|
||||||
|
"TokenBudget",
|
||||||
|
"TokenCalculator",
|
||||||
# Configuration
|
# Configuration
|
||||||
"ContextSettings",
|
"ContextSettings",
|
||||||
"get_context_settings",
|
"get_context_settings",
|
||||||
|
|||||||
@@ -3,3 +3,12 @@ Token Budget Management Module.
|
|||||||
|
|
||||||
Provides token counting and budget allocation.
|
Provides token counting and budget allocation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from .allocator import BudgetAllocator, TokenBudget
|
||||||
|
from .calculator import TokenCalculator
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BudgetAllocator",
|
||||||
|
"TokenBudget",
|
||||||
|
"TokenCalculator",
|
||||||
|
]
|
||||||
|
|||||||
433
backend/app/services/context/budget/allocator.py
Normal file
433
backend/app/services/context/budget/allocator.py
Normal file
@@ -0,0 +1,433 @@
|
|||||||
|
"""
|
||||||
|
Token Budget Allocator for Context Management.
|
||||||
|
|
||||||
|
Manages token budget allocation across context types.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from ..config import ContextSettings, get_context_settings
|
||||||
|
from ..exceptions import BudgetExceededError
|
||||||
|
from ..types import ContextType
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TokenBudget:
|
||||||
|
"""
|
||||||
|
Token budget allocation and tracking.
|
||||||
|
|
||||||
|
Tracks allocated tokens per context type and
|
||||||
|
monitors usage to prevent overflows.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Total budget
|
||||||
|
total: int
|
||||||
|
|
||||||
|
# Allocated per type
|
||||||
|
system: int = 0
|
||||||
|
task: int = 0
|
||||||
|
knowledge: int = 0
|
||||||
|
conversation: int = 0
|
||||||
|
tools: int = 0
|
||||||
|
response_reserve: int = 0
|
||||||
|
buffer: int = 0
|
||||||
|
|
||||||
|
# Usage tracking
|
||||||
|
used: dict[str, int] = field(default_factory=dict)
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
"""Initialize usage tracking."""
|
||||||
|
if not self.used:
|
||||||
|
self.used = {ct.value: 0 for ct in ContextType}
|
||||||
|
|
||||||
|
def get_allocation(self, context_type: ContextType | str) -> int:
|
||||||
|
"""
|
||||||
|
Get allocated tokens for a context type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_type: Context type to get allocation for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Allocated token count
|
||||||
|
"""
|
||||||
|
if isinstance(context_type, ContextType):
|
||||||
|
context_type = context_type.value
|
||||||
|
|
||||||
|
allocation_map = {
|
||||||
|
"system": self.system,
|
||||||
|
"task": self.task,
|
||||||
|
"knowledge": self.knowledge,
|
||||||
|
"conversation": self.conversation,
|
||||||
|
"tool": self.tools,
|
||||||
|
}
|
||||||
|
return allocation_map.get(context_type, 0)
|
||||||
|
|
||||||
|
def get_used(self, context_type: ContextType | str) -> int:
|
||||||
|
"""
|
||||||
|
Get used tokens for a context type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_type: Context type to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Used token count
|
||||||
|
"""
|
||||||
|
if isinstance(context_type, ContextType):
|
||||||
|
context_type = context_type.value
|
||||||
|
return self.used.get(context_type, 0)
|
||||||
|
|
||||||
|
def remaining(self, context_type: ContextType | str) -> int:
|
||||||
|
"""
|
||||||
|
Get remaining tokens for a context type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_type: Context type to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Remaining token count
|
||||||
|
"""
|
||||||
|
allocated = self.get_allocation(context_type)
|
||||||
|
used = self.get_used(context_type)
|
||||||
|
return max(0, allocated - used)
|
||||||
|
|
||||||
|
def total_remaining(self) -> int:
|
||||||
|
"""
|
||||||
|
Get total remaining tokens across all types.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total remaining tokens
|
||||||
|
"""
|
||||||
|
total_used = sum(self.used.values())
|
||||||
|
usable = self.total - self.response_reserve - self.buffer
|
||||||
|
return max(0, usable - total_used)
|
||||||
|
|
||||||
|
def total_used(self) -> int:
|
||||||
|
"""
|
||||||
|
Get total used tokens.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total used tokens
|
||||||
|
"""
|
||||||
|
return sum(self.used.values())
|
||||||
|
|
||||||
|
def can_fit(self, context_type: ContextType | str, tokens: int) -> bool:
|
||||||
|
"""
|
||||||
|
Check if tokens fit within budget for a type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_type: Context type to check
|
||||||
|
tokens: Number of tokens to fit
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if tokens fit within remaining budget
|
||||||
|
"""
|
||||||
|
return tokens <= self.remaining(context_type)
|
||||||
|
|
||||||
|
def allocate(
|
||||||
|
self,
|
||||||
|
context_type: ContextType | str,
|
||||||
|
tokens: int,
|
||||||
|
force: bool = False,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Allocate (use) tokens from a context type's budget.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_type: Context type to allocate from
|
||||||
|
tokens: Number of tokens to allocate
|
||||||
|
force: If True, allow exceeding budget
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if allocation succeeded
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
BudgetExceededError: If tokens exceed budget and force=False
|
||||||
|
"""
|
||||||
|
if isinstance(context_type, ContextType):
|
||||||
|
context_type = context_type.value
|
||||||
|
|
||||||
|
if not force and not self.can_fit(context_type, tokens):
|
||||||
|
raise BudgetExceededError(
|
||||||
|
message=f"Token budget exceeded for {context_type}",
|
||||||
|
allocated=self.get_allocation(context_type),
|
||||||
|
requested=self.get_used(context_type) + tokens,
|
||||||
|
context_type=context_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.used[context_type] = self.used.get(context_type, 0) + tokens
|
||||||
|
return True
|
||||||
|
|
||||||
|
def deallocate(
|
||||||
|
self,
|
||||||
|
context_type: ContextType | str,
|
||||||
|
tokens: int,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Deallocate (return) tokens to a context type's budget.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_type: Context type to return to
|
||||||
|
tokens: Number of tokens to return
|
||||||
|
"""
|
||||||
|
if isinstance(context_type, ContextType):
|
||||||
|
context_type = context_type.value
|
||||||
|
|
||||||
|
current = self.used.get(context_type, 0)
|
||||||
|
self.used[context_type] = max(0, current - tokens)
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""Reset all usage tracking."""
|
||||||
|
self.used = {ct.value: 0 for ct in ContextType}
|
||||||
|
|
||||||
|
def utilization(self, context_type: ContextType | str | None = None) -> float:
|
||||||
|
"""
|
||||||
|
Get budget utilization percentage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_type: Specific type or None for total
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Utilization as a fraction (0.0 to 1.0+)
|
||||||
|
"""
|
||||||
|
if context_type is None:
|
||||||
|
usable = self.total - self.response_reserve - self.buffer
|
||||||
|
if usable <= 0:
|
||||||
|
return 0.0
|
||||||
|
return self.total_used() / usable
|
||||||
|
|
||||||
|
allocated = self.get_allocation(context_type)
|
||||||
|
if allocated <= 0:
|
||||||
|
return 0.0
|
||||||
|
return self.get_used(context_type) / allocated
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert budget to dictionary."""
|
||||||
|
return {
|
||||||
|
"total": self.total,
|
||||||
|
"allocations": {
|
||||||
|
"system": self.system,
|
||||||
|
"task": self.task,
|
||||||
|
"knowledge": self.knowledge,
|
||||||
|
"conversation": self.conversation,
|
||||||
|
"tools": self.tools,
|
||||||
|
"response_reserve": self.response_reserve,
|
||||||
|
"buffer": self.buffer,
|
||||||
|
},
|
||||||
|
"used": dict(self.used),
|
||||||
|
"remaining": {
|
||||||
|
ct.value: self.remaining(ct) for ct in ContextType
|
||||||
|
},
|
||||||
|
"total_used": self.total_used(),
|
||||||
|
"total_remaining": self.total_remaining(),
|
||||||
|
"utilization": round(self.utilization(), 3),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class BudgetAllocator:
|
||||||
|
"""
|
||||||
|
Budget allocator for context management.
|
||||||
|
|
||||||
|
Creates token budgets based on configuration and
|
||||||
|
model context window sizes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, settings: ContextSettings | None = None) -> None:
|
||||||
|
"""
|
||||||
|
Initialize budget allocator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
settings: Context settings (uses default if None)
|
||||||
|
"""
|
||||||
|
self._settings = settings or get_context_settings()
|
||||||
|
|
||||||
|
def create_budget(
|
||||||
|
self,
|
||||||
|
total_tokens: int,
|
||||||
|
custom_allocations: dict[str, float] | None = None,
|
||||||
|
) -> TokenBudget:
|
||||||
|
"""
|
||||||
|
Create a token budget with allocations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
total_tokens: Total available tokens
|
||||||
|
custom_allocations: Optional custom allocation percentages
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TokenBudget with allocations set
|
||||||
|
"""
|
||||||
|
# Use custom or default allocations
|
||||||
|
if custom_allocations:
|
||||||
|
alloc = custom_allocations
|
||||||
|
else:
|
||||||
|
alloc = self._settings.get_budget_allocation()
|
||||||
|
|
||||||
|
return TokenBudget(
|
||||||
|
total=total_tokens,
|
||||||
|
system=int(total_tokens * alloc.get("system", 0.05)),
|
||||||
|
task=int(total_tokens * alloc.get("task", 0.10)),
|
||||||
|
knowledge=int(total_tokens * alloc.get("knowledge", 0.40)),
|
||||||
|
conversation=int(total_tokens * alloc.get("conversation", 0.20)),
|
||||||
|
tools=int(total_tokens * alloc.get("tools", 0.05)),
|
||||||
|
response_reserve=int(total_tokens * alloc.get("response", 0.15)),
|
||||||
|
buffer=int(total_tokens * alloc.get("buffer", 0.05)),
|
||||||
|
)
|
||||||
|
|
||||||
|
def adjust_budget(
|
||||||
|
self,
|
||||||
|
budget: TokenBudget,
|
||||||
|
context_type: ContextType | str,
|
||||||
|
adjustment: int,
|
||||||
|
) -> TokenBudget:
|
||||||
|
"""
|
||||||
|
Adjust a specific allocation in a budget.
|
||||||
|
|
||||||
|
Takes tokens from buffer and adds to specified type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
budget: Budget to adjust
|
||||||
|
context_type: Type to adjust
|
||||||
|
adjustment: Positive to increase, negative to decrease
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Adjusted budget
|
||||||
|
"""
|
||||||
|
if isinstance(context_type, ContextType):
|
||||||
|
context_type = context_type.value
|
||||||
|
|
||||||
|
# Calculate adjustment (limited by buffer)
|
||||||
|
if adjustment > 0:
|
||||||
|
# Taking from buffer
|
||||||
|
actual_adjustment = min(adjustment, budget.buffer)
|
||||||
|
budget.buffer -= actual_adjustment
|
||||||
|
else:
|
||||||
|
# Returning to buffer
|
||||||
|
actual_adjustment = adjustment
|
||||||
|
|
||||||
|
# Apply to target type
|
||||||
|
if context_type == "system":
|
||||||
|
budget.system = max(0, budget.system + actual_adjustment)
|
||||||
|
elif context_type == "task":
|
||||||
|
budget.task = max(0, budget.task + actual_adjustment)
|
||||||
|
elif context_type == "knowledge":
|
||||||
|
budget.knowledge = max(0, budget.knowledge + actual_adjustment)
|
||||||
|
elif context_type == "conversation":
|
||||||
|
budget.conversation = max(0, budget.conversation + actual_adjustment)
|
||||||
|
elif context_type == "tool":
|
||||||
|
budget.tools = max(0, budget.tools + actual_adjustment)
|
||||||
|
|
||||||
|
return budget
|
||||||
|
|
||||||
|
def rebalance_budget(
|
||||||
|
self,
|
||||||
|
budget: TokenBudget,
|
||||||
|
prioritize: list[ContextType] | None = None,
|
||||||
|
) -> TokenBudget:
|
||||||
|
"""
|
||||||
|
Rebalance budget based on actual usage.
|
||||||
|
|
||||||
|
Moves unused allocations to prioritized types.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
budget: Budget to rebalance
|
||||||
|
prioritize: Types to prioritize (in order)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Rebalanced budget
|
||||||
|
"""
|
||||||
|
if prioritize is None:
|
||||||
|
prioritize = [ContextType.KNOWLEDGE, ContextType.TASK, ContextType.SYSTEM]
|
||||||
|
|
||||||
|
# Calculate unused tokens per type
|
||||||
|
unused: dict[str, int] = {}
|
||||||
|
for ct in ContextType:
|
||||||
|
remaining = budget.remaining(ct)
|
||||||
|
if remaining > 0:
|
||||||
|
unused[ct.value] = remaining
|
||||||
|
|
||||||
|
# Calculate total reclaimable (excluding prioritized types)
|
||||||
|
prioritize_values = {ct.value for ct in prioritize}
|
||||||
|
reclaimable = sum(
|
||||||
|
tokens for ct, tokens in unused.items()
|
||||||
|
if ct not in prioritize_values
|
||||||
|
)
|
||||||
|
|
||||||
|
# Redistribute to prioritized types that are near capacity
|
||||||
|
for ct in prioritize:
|
||||||
|
ct_value = ct.value
|
||||||
|
utilization = budget.utilization(ct)
|
||||||
|
|
||||||
|
if utilization > 0.8: # Near capacity
|
||||||
|
# Give more tokens from reclaimable pool
|
||||||
|
bonus = min(reclaimable, budget.get_allocation(ct) // 2)
|
||||||
|
self.adjust_budget(budget, ct, bonus)
|
||||||
|
reclaimable -= bonus
|
||||||
|
|
||||||
|
if reclaimable <= 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
return budget
|
||||||
|
|
||||||
|
def get_model_context_size(self, model: str) -> int:
|
||||||
|
"""
|
||||||
|
Get context window size for a model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Context window size in tokens
|
||||||
|
"""
|
||||||
|
# Common model context sizes
|
||||||
|
context_sizes = {
|
||||||
|
"claude-3-opus": 200000,
|
||||||
|
"claude-3-sonnet": 200000,
|
||||||
|
"claude-3-haiku": 200000,
|
||||||
|
"claude-3-5-sonnet": 200000,
|
||||||
|
"claude-3-5-haiku": 200000,
|
||||||
|
"claude-opus-4": 200000,
|
||||||
|
"gpt-4-turbo": 128000,
|
||||||
|
"gpt-4": 8192,
|
||||||
|
"gpt-4-32k": 32768,
|
||||||
|
"gpt-4o": 128000,
|
||||||
|
"gpt-4o-mini": 128000,
|
||||||
|
"gpt-3.5-turbo": 16385,
|
||||||
|
"gemini-1.5-pro": 2000000,
|
||||||
|
"gemini-1.5-flash": 1000000,
|
||||||
|
"gemini-2.0-flash": 1000000,
|
||||||
|
"qwen-plus": 32000,
|
||||||
|
"qwen-turbo": 8000,
|
||||||
|
"deepseek-chat": 64000,
|
||||||
|
"deepseek-reasoner": 64000,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check exact match first
|
||||||
|
model_lower = model.lower()
|
||||||
|
if model_lower in context_sizes:
|
||||||
|
return context_sizes[model_lower]
|
||||||
|
|
||||||
|
# Check prefix match
|
||||||
|
for model_name, size in context_sizes.items():
|
||||||
|
if model_lower.startswith(model_name):
|
||||||
|
return size
|
||||||
|
|
||||||
|
# Default fallback
|
||||||
|
return 8192
|
||||||
|
|
||||||
|
def create_budget_for_model(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
custom_allocations: dict[str, float] | None = None,
|
||||||
|
) -> TokenBudget:
|
||||||
|
"""
|
||||||
|
Create a budget based on model's context window.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model name
|
||||||
|
custom_allocations: Optional custom allocation percentages
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TokenBudget sized for the model
|
||||||
|
"""
|
||||||
|
context_size = self.get_model_context_size(model)
|
||||||
|
return self.create_budget(context_size, custom_allocations)
|
||||||
284
backend/app/services/context/budget/calculator.py
Normal file
284
backend/app/services/context/budget/calculator.py
Normal file
@@ -0,0 +1,284 @@
|
|||||||
|
"""
|
||||||
|
Token Calculator for Context Management.
|
||||||
|
|
||||||
|
Provides token counting with caching and fallback estimation.
|
||||||
|
Integrates with LLM Gateway for accurate counts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Any, Protocol
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.services.mcp.client_manager import MCPClientManager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TokenCounterProtocol(Protocol):
|
||||||
|
"""Protocol for token counting implementations."""
|
||||||
|
|
||||||
|
async def count_tokens(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
model: str | None = None,
|
||||||
|
) -> int:
|
||||||
|
"""Count tokens in text."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class TokenCalculator:
|
||||||
|
"""
|
||||||
|
Token calculator with LLM Gateway integration.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- In-memory caching for repeated text
|
||||||
|
- Fallback to character-based estimation
|
||||||
|
- Model-specific counting when possible
|
||||||
|
|
||||||
|
The calculator uses the LLM Gateway's count_tokens tool
|
||||||
|
for accurate counting, with a local cache to avoid
|
||||||
|
repeated calls for the same content.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Default characters per token ratio for estimation
|
||||||
|
DEFAULT_CHARS_PER_TOKEN = 4.0
|
||||||
|
|
||||||
|
# Model-specific ratios (more accurate estimation)
|
||||||
|
MODEL_CHAR_RATIOS: dict[str, float] = {
|
||||||
|
"claude": 3.5,
|
||||||
|
"gpt-4": 4.0,
|
||||||
|
"gpt-3.5": 4.0,
|
||||||
|
"gemini": 4.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
mcp_manager: "MCPClientManager | None" = None,
|
||||||
|
project_id: str = "system",
|
||||||
|
agent_id: str = "context-engine",
|
||||||
|
cache_enabled: bool = True,
|
||||||
|
cache_max_size: int = 10000,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize token calculator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp_manager: MCP client manager for LLM Gateway calls
|
||||||
|
project_id: Project ID for LLM Gateway calls
|
||||||
|
agent_id: Agent ID for LLM Gateway calls
|
||||||
|
cache_enabled: Whether to enable in-memory caching
|
||||||
|
cache_max_size: Maximum cache entries
|
||||||
|
"""
|
||||||
|
self._mcp = mcp_manager
|
||||||
|
self._project_id = project_id
|
||||||
|
self._agent_id = agent_id
|
||||||
|
self._cache_enabled = cache_enabled
|
||||||
|
self._cache_max_size = cache_max_size
|
||||||
|
|
||||||
|
# In-memory cache: hash(model:text) -> token_count
|
||||||
|
self._cache: dict[str, int] = {}
|
||||||
|
self._cache_hits = 0
|
||||||
|
self._cache_misses = 0
|
||||||
|
|
||||||
|
def _get_cache_key(self, text: str, model: str | None) -> str:
|
||||||
|
"""Generate cache key from text and model."""
|
||||||
|
# Use hash for efficient storage
|
||||||
|
content = f"{model or 'default'}:{text}"
|
||||||
|
return hashlib.sha256(content.encode()).hexdigest()[:32]
|
||||||
|
|
||||||
|
def _check_cache(self, cache_key: str) -> int | None:
|
||||||
|
"""Check cache for existing count."""
|
||||||
|
if not self._cache_enabled:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if cache_key in self._cache:
|
||||||
|
self._cache_hits += 1
|
||||||
|
return self._cache[cache_key]
|
||||||
|
|
||||||
|
self._cache_misses += 1
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _store_cache(self, cache_key: str, count: int) -> None:
|
||||||
|
"""Store count in cache."""
|
||||||
|
if not self._cache_enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Simple LRU-like eviction: remove oldest entries when full
|
||||||
|
if len(self._cache) >= self._cache_max_size:
|
||||||
|
# Remove first 10% of entries
|
||||||
|
entries_to_remove = self._cache_max_size // 10
|
||||||
|
keys_to_remove = list(self._cache.keys())[:entries_to_remove]
|
||||||
|
for key in keys_to_remove:
|
||||||
|
del self._cache[key]
|
||||||
|
|
||||||
|
self._cache[cache_key] = count
|
||||||
|
|
||||||
|
def estimate_tokens(self, text: str, model: str | None = None) -> int:
|
||||||
|
"""
|
||||||
|
Estimate token count based on character count.
|
||||||
|
|
||||||
|
This is a fast fallback when LLM Gateway is unavailable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to count
|
||||||
|
model: Optional model for more accurate ratio
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Estimated token count
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# Get model-specific ratio
|
||||||
|
ratio = self.DEFAULT_CHARS_PER_TOKEN
|
||||||
|
if model:
|
||||||
|
model_lower = model.lower()
|
||||||
|
for model_prefix, model_ratio in self.MODEL_CHAR_RATIOS.items():
|
||||||
|
if model_prefix in model_lower:
|
||||||
|
ratio = model_ratio
|
||||||
|
break
|
||||||
|
|
||||||
|
return max(1, int(len(text) / ratio))
|
||||||
|
|
||||||
|
async def count_tokens(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
model: str | None = None,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Count tokens in text.
|
||||||
|
|
||||||
|
Uses LLM Gateway for accurate counts with fallback to estimation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to count
|
||||||
|
model: Optional model for accurate counting
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Token count
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# Check cache first
|
||||||
|
cache_key = self._get_cache_key(text, model)
|
||||||
|
cached = self._check_cache(cache_key)
|
||||||
|
if cached is not None:
|
||||||
|
return cached
|
||||||
|
|
||||||
|
# Try LLM Gateway
|
||||||
|
if self._mcp is not None:
|
||||||
|
try:
|
||||||
|
result = await self._mcp.call_tool(
|
||||||
|
server="llm-gateway",
|
||||||
|
tool="count_tokens",
|
||||||
|
args={
|
||||||
|
"project_id": self._project_id,
|
||||||
|
"agent_id": self._agent_id,
|
||||||
|
"text": text,
|
||||||
|
"model": model,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse result
|
||||||
|
if result.success and result.data:
|
||||||
|
count = self._parse_token_count(result.data)
|
||||||
|
if count is not None:
|
||||||
|
self._store_cache(cache_key, count)
|
||||||
|
return count
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"LLM Gateway token count failed, using estimation: {e}")
|
||||||
|
|
||||||
|
# Fallback to estimation
|
||||||
|
count = self.estimate_tokens(text, model)
|
||||||
|
self._store_cache(cache_key, count)
|
||||||
|
return count
|
||||||
|
|
||||||
|
def _parse_token_count(self, data: Any) -> int | None:
|
||||||
|
"""Parse token count from LLM Gateway response."""
|
||||||
|
if isinstance(data, dict):
|
||||||
|
if "token_count" in data:
|
||||||
|
return int(data["token_count"])
|
||||||
|
if "tokens" in data:
|
||||||
|
return int(data["tokens"])
|
||||||
|
if "count" in data:
|
||||||
|
return int(data["count"])
|
||||||
|
|
||||||
|
if isinstance(data, int):
|
||||||
|
return data
|
||||||
|
|
||||||
|
if isinstance(data, str):
|
||||||
|
# Try to parse from text content
|
||||||
|
try:
|
||||||
|
# Handle {"token_count": 123} or just "123"
|
||||||
|
import json
|
||||||
|
|
||||||
|
parsed = json.loads(data)
|
||||||
|
if isinstance(parsed, dict) and "token_count" in parsed:
|
||||||
|
return int(parsed["token_count"])
|
||||||
|
if isinstance(parsed, int):
|
||||||
|
return parsed
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
# Try direct int conversion
|
||||||
|
try:
|
||||||
|
return int(data)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def count_tokens_batch(
|
||||||
|
self,
|
||||||
|
texts: list[str],
|
||||||
|
model: str | None = None,
|
||||||
|
) -> list[int]:
|
||||||
|
"""
|
||||||
|
Count tokens for multiple texts.
|
||||||
|
|
||||||
|
Efficient batch counting with caching.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: List of texts to count
|
||||||
|
model: Optional model for accurate counting
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of token counts (same order as input)
|
||||||
|
"""
|
||||||
|
results: list[int] = []
|
||||||
|
|
||||||
|
for text in texts:
|
||||||
|
count = await self.count_tokens(text, model)
|
||||||
|
results.append(count)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def clear_cache(self) -> None:
|
||||||
|
"""Clear the token count cache."""
|
||||||
|
self._cache.clear()
|
||||||
|
self._cache_hits = 0
|
||||||
|
self._cache_misses = 0
|
||||||
|
|
||||||
|
def get_cache_stats(self) -> dict[str, Any]:
|
||||||
|
"""Get cache statistics."""
|
||||||
|
total = self._cache_hits + self._cache_misses
|
||||||
|
hit_rate = self._cache_hits / total if total > 0 else 0.0
|
||||||
|
|
||||||
|
return {
|
||||||
|
"enabled": self._cache_enabled,
|
||||||
|
"size": len(self._cache),
|
||||||
|
"max_size": self._cache_max_size,
|
||||||
|
"hits": self._cache_hits,
|
||||||
|
"misses": self._cache_misses,
|
||||||
|
"hit_rate": round(hit_rate, 3),
|
||||||
|
}
|
||||||
|
|
||||||
|
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
|
||||||
|
"""
|
||||||
|
Set the MCP manager (for lazy initialization).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp_manager: MCP client manager instance
|
||||||
|
"""
|
||||||
|
self._mcp = mcp_manager
|
||||||
533
backend/tests/services/context/test_budget.py
Normal file
533
backend/tests/services/context/test_budget.py
Normal file
@@ -0,0 +1,533 @@
|
|||||||
|
"""Tests for token budget management."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.context.budget import (
|
||||||
|
BudgetAllocator,
|
||||||
|
TokenBudget,
|
||||||
|
TokenCalculator,
|
||||||
|
)
|
||||||
|
from app.services.context.config import ContextSettings
|
||||||
|
from app.services.context.exceptions import BudgetExceededError
|
||||||
|
from app.services.context.types import ContextType
|
||||||
|
|
||||||
|
|
||||||
|
class TestTokenBudget:
|
||||||
|
"""Tests for TokenBudget dataclass."""
|
||||||
|
|
||||||
|
def test_creation(self) -> None:
|
||||||
|
"""Test basic budget creation."""
|
||||||
|
budget = TokenBudget(total=10000)
|
||||||
|
assert budget.total == 10000
|
||||||
|
assert budget.system == 0
|
||||||
|
assert budget.total_used() == 0
|
||||||
|
|
||||||
|
def test_creation_with_allocations(self) -> None:
|
||||||
|
"""Test budget creation with allocations."""
|
||||||
|
budget = TokenBudget(
|
||||||
|
total=10000,
|
||||||
|
system=500,
|
||||||
|
task=1000,
|
||||||
|
knowledge=4000,
|
||||||
|
conversation=2000,
|
||||||
|
tools=500,
|
||||||
|
response_reserve=1500,
|
||||||
|
buffer=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert budget.system == 500
|
||||||
|
assert budget.knowledge == 4000
|
||||||
|
assert budget.response_reserve == 1500
|
||||||
|
|
||||||
|
def test_get_allocation(self) -> None:
|
||||||
|
"""Test getting allocation for a type."""
|
||||||
|
budget = TokenBudget(
|
||||||
|
total=10000,
|
||||||
|
system=500,
|
||||||
|
knowledge=4000,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert budget.get_allocation(ContextType.SYSTEM) == 500
|
||||||
|
assert budget.get_allocation(ContextType.KNOWLEDGE) == 4000
|
||||||
|
assert budget.get_allocation("system") == 500
|
||||||
|
|
||||||
|
def test_remaining(self) -> None:
|
||||||
|
"""Test remaining budget calculation."""
|
||||||
|
budget = TokenBudget(
|
||||||
|
total=10000,
|
||||||
|
system=500,
|
||||||
|
knowledge=4000,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initially full
|
||||||
|
assert budget.remaining(ContextType.SYSTEM) == 500
|
||||||
|
assert budget.remaining(ContextType.KNOWLEDGE) == 4000
|
||||||
|
|
||||||
|
# After allocation
|
||||||
|
budget.allocate(ContextType.SYSTEM, 200)
|
||||||
|
assert budget.remaining(ContextType.SYSTEM) == 300
|
||||||
|
|
||||||
|
def test_can_fit(self) -> None:
|
||||||
|
"""Test can_fit check."""
|
||||||
|
budget = TokenBudget(
|
||||||
|
total=10000,
|
||||||
|
system=500,
|
||||||
|
knowledge=4000,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert budget.can_fit(ContextType.SYSTEM, 500) is True
|
||||||
|
assert budget.can_fit(ContextType.SYSTEM, 501) is False
|
||||||
|
assert budget.can_fit(ContextType.KNOWLEDGE, 4000) is True
|
||||||
|
|
||||||
|
def test_allocate_success(self) -> None:
|
||||||
|
"""Test successful allocation."""
|
||||||
|
budget = TokenBudget(
|
||||||
|
total=10000,
|
||||||
|
system=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = budget.allocate(ContextType.SYSTEM, 200)
|
||||||
|
assert result is True
|
||||||
|
assert budget.get_used(ContextType.SYSTEM) == 200
|
||||||
|
assert budget.remaining(ContextType.SYSTEM) == 300
|
||||||
|
|
||||||
|
def test_allocate_exceeds_budget(self) -> None:
|
||||||
|
"""Test allocation exceeding budget."""
|
||||||
|
budget = TokenBudget(
|
||||||
|
total=10000,
|
||||||
|
system=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(BudgetExceededError) as exc_info:
|
||||||
|
budget.allocate(ContextType.SYSTEM, 600)
|
||||||
|
|
||||||
|
assert exc_info.value.allocated == 500
|
||||||
|
assert exc_info.value.requested == 600
|
||||||
|
|
||||||
|
def test_allocate_force(self) -> None:
|
||||||
|
"""Test forced allocation exceeding budget."""
|
||||||
|
budget = TokenBudget(
|
||||||
|
total=10000,
|
||||||
|
system=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Force should allow exceeding
|
||||||
|
result = budget.allocate(ContextType.SYSTEM, 600, force=True)
|
||||||
|
assert result is True
|
||||||
|
assert budget.get_used(ContextType.SYSTEM) == 600
|
||||||
|
|
||||||
|
def test_deallocate(self) -> None:
|
||||||
|
"""Test deallocation."""
|
||||||
|
budget = TokenBudget(
|
||||||
|
total=10000,
|
||||||
|
system=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
budget.allocate(ContextType.SYSTEM, 300)
|
||||||
|
assert budget.get_used(ContextType.SYSTEM) == 300
|
||||||
|
|
||||||
|
budget.deallocate(ContextType.SYSTEM, 100)
|
||||||
|
assert budget.get_used(ContextType.SYSTEM) == 200
|
||||||
|
|
||||||
|
def test_deallocate_below_zero(self) -> None:
|
||||||
|
"""Test deallocation doesn't go below zero."""
|
||||||
|
budget = TokenBudget(
|
||||||
|
total=10000,
|
||||||
|
system=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
budget.allocate(ContextType.SYSTEM, 100)
|
||||||
|
budget.deallocate(ContextType.SYSTEM, 200)
|
||||||
|
assert budget.get_used(ContextType.SYSTEM) == 0
|
||||||
|
|
||||||
|
def test_total_remaining(self) -> None:
|
||||||
|
"""Test total remaining calculation."""
|
||||||
|
budget = TokenBudget(
|
||||||
|
total=10000,
|
||||||
|
system=500,
|
||||||
|
knowledge=4000,
|
||||||
|
response_reserve=1500,
|
||||||
|
buffer=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Usable = total - response_reserve - buffer = 10000 - 1500 - 500 = 8000
|
||||||
|
assert budget.total_remaining() == 8000
|
||||||
|
|
||||||
|
# After allocation
|
||||||
|
budget.allocate(ContextType.SYSTEM, 200)
|
||||||
|
assert budget.total_remaining() == 7800
|
||||||
|
|
||||||
|
def test_utilization(self) -> None:
|
||||||
|
"""Test utilization calculation."""
|
||||||
|
budget = TokenBudget(
|
||||||
|
total=10000,
|
||||||
|
system=500,
|
||||||
|
response_reserve=1500,
|
||||||
|
buffer=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
# No usage = 0%
|
||||||
|
assert budget.utilization(ContextType.SYSTEM) == 0.0
|
||||||
|
|
||||||
|
# Half used = 50%
|
||||||
|
budget.allocate(ContextType.SYSTEM, 250)
|
||||||
|
assert budget.utilization(ContextType.SYSTEM) == 0.5
|
||||||
|
|
||||||
|
# Total utilization
|
||||||
|
assert budget.utilization() == 250 / 8000 # 250 / (10000 - 1500 - 500)
|
||||||
|
|
||||||
|
def test_reset(self) -> None:
|
||||||
|
"""Test reset clears usage."""
|
||||||
|
budget = TokenBudget(
|
||||||
|
total=10000,
|
||||||
|
system=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
budget.allocate(ContextType.SYSTEM, 300)
|
||||||
|
assert budget.get_used(ContextType.SYSTEM) == 300
|
||||||
|
|
||||||
|
budget.reset()
|
||||||
|
assert budget.get_used(ContextType.SYSTEM) == 0
|
||||||
|
assert budget.total_used() == 0
|
||||||
|
|
||||||
|
def test_to_dict(self) -> None:
|
||||||
|
"""Test to_dict conversion."""
|
||||||
|
budget = TokenBudget(
|
||||||
|
total=10000,
|
||||||
|
system=500,
|
||||||
|
task=1000,
|
||||||
|
knowledge=4000,
|
||||||
|
)
|
||||||
|
|
||||||
|
budget.allocate(ContextType.SYSTEM, 200)
|
||||||
|
|
||||||
|
data = budget.to_dict()
|
||||||
|
assert data["total"] == 10000
|
||||||
|
assert data["allocations"]["system"] == 500
|
||||||
|
assert data["used"]["system"] == 200
|
||||||
|
assert data["remaining"]["system"] == 300
|
||||||
|
|
||||||
|
|
||||||
|
class TestBudgetAllocator:
|
||||||
|
"""Tests for BudgetAllocator."""
|
||||||
|
|
||||||
|
def test_create_budget(self) -> None:
|
||||||
|
"""Test budget creation with default allocations."""
|
||||||
|
allocator = BudgetAllocator()
|
||||||
|
budget = allocator.create_budget(100000)
|
||||||
|
|
||||||
|
assert budget.total == 100000
|
||||||
|
assert budget.system == 5000 # 5%
|
||||||
|
assert budget.task == 10000 # 10%
|
||||||
|
assert budget.knowledge == 40000 # 40%
|
||||||
|
assert budget.conversation == 20000 # 20%
|
||||||
|
assert budget.tools == 5000 # 5%
|
||||||
|
assert budget.response_reserve == 15000 # 15%
|
||||||
|
assert budget.buffer == 5000 # 5%
|
||||||
|
|
||||||
|
def test_create_budget_custom_allocations(self) -> None:
|
||||||
|
"""Test budget creation with custom allocations."""
|
||||||
|
allocator = BudgetAllocator()
|
||||||
|
budget = allocator.create_budget(
|
||||||
|
100000,
|
||||||
|
custom_allocations={
|
||||||
|
"system": 0.10,
|
||||||
|
"task": 0.10,
|
||||||
|
"knowledge": 0.30,
|
||||||
|
"conversation": 0.25,
|
||||||
|
"tools": 0.05,
|
||||||
|
"response": 0.15,
|
||||||
|
"buffer": 0.05,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert budget.system == 10000 # 10%
|
||||||
|
assert budget.knowledge == 30000 # 30%
|
||||||
|
|
||||||
|
def test_create_budget_for_model(self) -> None:
|
||||||
|
"""Test budget creation for specific model."""
|
||||||
|
allocator = BudgetAllocator()
|
||||||
|
|
||||||
|
# Claude models have 200k context
|
||||||
|
budget = allocator.create_budget_for_model("claude-3-sonnet")
|
||||||
|
assert budget.total == 200000
|
||||||
|
|
||||||
|
# GPT-4 has 8k context
|
||||||
|
budget = allocator.create_budget_for_model("gpt-4")
|
||||||
|
assert budget.total == 8192
|
||||||
|
|
||||||
|
# GPT-4-turbo has 128k context
|
||||||
|
budget = allocator.create_budget_for_model("gpt-4-turbo")
|
||||||
|
assert budget.total == 128000
|
||||||
|
|
||||||
|
def test_get_model_context_size(self) -> None:
|
||||||
|
"""Test model context size lookup."""
|
||||||
|
allocator = BudgetAllocator()
|
||||||
|
|
||||||
|
# Known models
|
||||||
|
assert allocator.get_model_context_size("claude-3-opus") == 200000
|
||||||
|
assert allocator.get_model_context_size("gpt-4") == 8192
|
||||||
|
assert allocator.get_model_context_size("gemini-1.5-pro") == 2000000
|
||||||
|
|
||||||
|
# Unknown model gets default
|
||||||
|
assert allocator.get_model_context_size("unknown-model") == 8192
|
||||||
|
|
||||||
|
def test_adjust_budget(self) -> None:
|
||||||
|
"""Test budget adjustment."""
|
||||||
|
allocator = BudgetAllocator()
|
||||||
|
budget = allocator.create_budget(10000)
|
||||||
|
|
||||||
|
original_system = budget.system
|
||||||
|
original_buffer = budget.buffer
|
||||||
|
|
||||||
|
# Increase system by taking from buffer
|
||||||
|
budget = allocator.adjust_budget(budget, ContextType.SYSTEM, 200)
|
||||||
|
|
||||||
|
assert budget.system == original_system + 200
|
||||||
|
assert budget.buffer == original_buffer - 200
|
||||||
|
|
||||||
|
def test_adjust_budget_limited_by_buffer(self) -> None:
|
||||||
|
"""Test that adjustment is limited by buffer size."""
|
||||||
|
allocator = BudgetAllocator()
|
||||||
|
budget = allocator.create_budget(10000)
|
||||||
|
|
||||||
|
original_buffer = budget.buffer
|
||||||
|
|
||||||
|
# Try to increase more than buffer allows
|
||||||
|
budget = allocator.adjust_budget(budget, ContextType.SYSTEM, 10000)
|
||||||
|
|
||||||
|
# Should only increase by buffer amount
|
||||||
|
assert budget.buffer == 0
|
||||||
|
assert budget.system <= original_buffer + budget.system
|
||||||
|
|
||||||
|
def test_rebalance_budget(self) -> None:
|
||||||
|
"""Test budget rebalancing."""
|
||||||
|
allocator = BudgetAllocator()
|
||||||
|
budget = allocator.create_budget(10000)
|
||||||
|
|
||||||
|
# Use most of knowledge budget
|
||||||
|
budget.allocate(ContextType.KNOWLEDGE, 3500)
|
||||||
|
|
||||||
|
# Rebalance prioritizing knowledge
|
||||||
|
budget = allocator.rebalance_budget(
|
||||||
|
budget,
|
||||||
|
prioritize=[ContextType.KNOWLEDGE],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Knowledge should have gotten more tokens
|
||||||
|
# (This is a fuzzy test - just check it runs)
|
||||||
|
assert budget is not None
|
||||||
|
|
||||||
|
|
||||||
|
class TestTokenCalculator:
|
||||||
|
"""Tests for TokenCalculator."""
|
||||||
|
|
||||||
|
def test_estimate_tokens(self) -> None:
|
||||||
|
"""Test token estimation."""
|
||||||
|
calc = TokenCalculator()
|
||||||
|
|
||||||
|
# Empty string
|
||||||
|
assert calc.estimate_tokens("") == 0
|
||||||
|
|
||||||
|
# Short text (~4 chars per token)
|
||||||
|
text = "This is a test message"
|
||||||
|
estimate = calc.estimate_tokens(text)
|
||||||
|
assert 4 <= estimate <= 8
|
||||||
|
|
||||||
|
def test_estimate_tokens_model_specific(self) -> None:
|
||||||
|
"""Test model-specific estimation ratios."""
|
||||||
|
calc = TokenCalculator()
|
||||||
|
text = "a" * 100
|
||||||
|
|
||||||
|
# Claude uses 3.5 chars per token
|
||||||
|
claude_estimate = calc.estimate_tokens(text, "claude-3-sonnet")
|
||||||
|
# GPT uses 4.0 chars per token
|
||||||
|
gpt_estimate = calc.estimate_tokens(text, "gpt-4")
|
||||||
|
|
||||||
|
# Claude should estimate more tokens (smaller ratio)
|
||||||
|
assert claude_estimate >= gpt_estimate
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_count_tokens_no_mcp(self) -> None:
|
||||||
|
"""Test token counting without MCP (fallback to estimation)."""
|
||||||
|
calc = TokenCalculator()
|
||||||
|
|
||||||
|
text = "This is a test"
|
||||||
|
count = await calc.count_tokens(text)
|
||||||
|
|
||||||
|
# Should use estimation
|
||||||
|
assert count > 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_count_tokens_with_mcp_success(self) -> None:
|
||||||
|
"""Test token counting with MCP integration."""
|
||||||
|
# Mock MCP manager
|
||||||
|
mock_mcp = MagicMock()
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.success = True
|
||||||
|
mock_result.data = {"token_count": 42}
|
||||||
|
mock_mcp.call_tool = AsyncMock(return_value=mock_result)
|
||||||
|
|
||||||
|
calc = TokenCalculator(mcp_manager=mock_mcp)
|
||||||
|
count = await calc.count_tokens("test text")
|
||||||
|
|
||||||
|
assert count == 42
|
||||||
|
mock_mcp.call_tool.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_count_tokens_with_mcp_failure(self) -> None:
|
||||||
|
"""Test fallback when MCP fails."""
|
||||||
|
# Mock MCP manager that fails
|
||||||
|
mock_mcp = MagicMock()
|
||||||
|
mock_mcp.call_tool = AsyncMock(side_effect=Exception("Connection failed"))
|
||||||
|
|
||||||
|
calc = TokenCalculator(mcp_manager=mock_mcp)
|
||||||
|
count = await calc.count_tokens("test text")
|
||||||
|
|
||||||
|
# Should fall back to estimation
|
||||||
|
assert count > 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_count_tokens_caching(self) -> None:
|
||||||
|
"""Test that token counts are cached."""
|
||||||
|
mock_mcp = MagicMock()
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.success = True
|
||||||
|
mock_result.data = {"token_count": 42}
|
||||||
|
mock_mcp.call_tool = AsyncMock(return_value=mock_result)
|
||||||
|
|
||||||
|
calc = TokenCalculator(mcp_manager=mock_mcp)
|
||||||
|
|
||||||
|
# First call
|
||||||
|
count1 = await calc.count_tokens("test text")
|
||||||
|
# Second call (should use cache)
|
||||||
|
count2 = await calc.count_tokens("test text")
|
||||||
|
|
||||||
|
assert count1 == count2 == 42
|
||||||
|
# MCP should only be called once
|
||||||
|
assert mock_mcp.call_tool.call_count == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_count_tokens_batch(self) -> None:
|
||||||
|
"""Test batch token counting."""
|
||||||
|
calc = TokenCalculator()
|
||||||
|
|
||||||
|
texts = ["Hello", "World", "Test message here"]
|
||||||
|
counts = await calc.count_tokens_batch(texts)
|
||||||
|
|
||||||
|
assert len(counts) == 3
|
||||||
|
assert all(c > 0 for c in counts)
|
||||||
|
|
||||||
|
def test_cache_stats(self) -> None:
|
||||||
|
"""Test cache statistics."""
|
||||||
|
calc = TokenCalculator()
|
||||||
|
|
||||||
|
stats = calc.get_cache_stats()
|
||||||
|
assert stats["enabled"] is True
|
||||||
|
assert stats["size"] == 0
|
||||||
|
assert stats["hits"] == 0
|
||||||
|
assert stats["misses"] == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_hit_rate(self) -> None:
|
||||||
|
"""Test cache hit rate tracking."""
|
||||||
|
calc = TokenCalculator()
|
||||||
|
|
||||||
|
# Make some calls
|
||||||
|
await calc.count_tokens("text1")
|
||||||
|
await calc.count_tokens("text2")
|
||||||
|
await calc.count_tokens("text1") # Cache hit
|
||||||
|
|
||||||
|
stats = calc.get_cache_stats()
|
||||||
|
assert stats["hits"] == 1
|
||||||
|
assert stats["misses"] == 2
|
||||||
|
|
||||||
|
def test_clear_cache(self) -> None:
|
||||||
|
"""Test cache clearing."""
|
||||||
|
calc = TokenCalculator()
|
||||||
|
calc._cache["test"] = 100
|
||||||
|
calc._cache_hits = 5
|
||||||
|
|
||||||
|
calc.clear_cache()
|
||||||
|
|
||||||
|
assert len(calc._cache) == 0
|
||||||
|
assert calc._cache_hits == 0
|
||||||
|
|
||||||
|
def test_set_mcp_manager(self) -> None:
|
||||||
|
"""Test setting MCP manager after initialization."""
|
||||||
|
calc = TokenCalculator()
|
||||||
|
assert calc._mcp is None
|
||||||
|
|
||||||
|
mock_mcp = MagicMock()
|
||||||
|
calc.set_mcp_manager(mock_mcp)
|
||||||
|
|
||||||
|
assert calc._mcp is mock_mcp
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parse_token_count_formats(self) -> None:
|
||||||
|
"""Test parsing different token count response formats."""
|
||||||
|
calc = TokenCalculator()
|
||||||
|
|
||||||
|
# Dict with token_count
|
||||||
|
assert calc._parse_token_count({"token_count": 42}) == 42
|
||||||
|
|
||||||
|
# Dict with tokens
|
||||||
|
assert calc._parse_token_count({"tokens": 42}) == 42
|
||||||
|
|
||||||
|
# Dict with count
|
||||||
|
assert calc._parse_token_count({"count": 42}) == 42
|
||||||
|
|
||||||
|
# Direct int
|
||||||
|
assert calc._parse_token_count(42) == 42
|
||||||
|
|
||||||
|
# JSON string
|
||||||
|
assert calc._parse_token_count('{"token_count": 42}') == 42
|
||||||
|
|
||||||
|
# Invalid
|
||||||
|
assert calc._parse_token_count("invalid") is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestBudgetIntegration:
|
||||||
|
"""Integration tests for budget management."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_full_budget_workflow(self) -> None:
|
||||||
|
"""Test complete budget allocation workflow."""
|
||||||
|
# Create settings and allocator
|
||||||
|
settings = ContextSettings()
|
||||||
|
allocator = BudgetAllocator(settings)
|
||||||
|
|
||||||
|
# Create budget for Claude
|
||||||
|
budget = allocator.create_budget_for_model("claude-3-sonnet")
|
||||||
|
assert budget.total == 200000
|
||||||
|
|
||||||
|
# Create calculator (without MCP for test)
|
||||||
|
calc = TokenCalculator()
|
||||||
|
|
||||||
|
# Simulate context allocation
|
||||||
|
system_text = "You are a helpful assistant." * 10
|
||||||
|
system_tokens = await calc.count_tokens(system_text)
|
||||||
|
|
||||||
|
# Allocate
|
||||||
|
assert budget.can_fit(ContextType.SYSTEM, system_tokens)
|
||||||
|
budget.allocate(ContextType.SYSTEM, system_tokens)
|
||||||
|
|
||||||
|
# Check state
|
||||||
|
assert budget.get_used(ContextType.SYSTEM) == system_tokens
|
||||||
|
assert budget.remaining(ContextType.SYSTEM) == budget.system - system_tokens
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_budget_overflow_handling(self) -> None:
|
||||||
|
"""Test handling budget overflow."""
|
||||||
|
allocator = BudgetAllocator()
|
||||||
|
budget = allocator.create_budget(1000) # Small budget
|
||||||
|
|
||||||
|
# Try to allocate more than available
|
||||||
|
with pytest.raises(BudgetExceededError):
|
||||||
|
budget.allocate(ContextType.KNOWLEDGE, 500)
|
||||||
|
|
||||||
|
# Force allocation should work
|
||||||
|
budget.allocate(ContextType.KNOWLEDGE, 500, force=True)
|
||||||
|
assert budget.get_used(ContextType.KNOWLEDGE) == 500
|
||||||
Reference in New Issue
Block a user