diff --git a/backend/app/services/context/__init__.py b/backend/app/services/context/__init__.py
index 40034e0..add3366 100644
--- a/backend/app/services/context/__init__.py
+++ b/backend/app/services/context/__init__.py
@@ -14,11 +14,18 @@ Usage:
         ConversationContext,
         TaskContext,
         ToolContext,
+        TokenBudget,
+        BudgetAllocator,
+        TokenCalculator,
     )
 
     # Get settings
     settings = get_context_settings()
 
+    # Create budget for a model
+    allocator = BudgetAllocator(settings)
+    budget = allocator.create_budget_for_model("claude-3-sonnet")
+
     # Create context instances
     system_ctx = SystemContext.create_persona(
         name="Code Assistant",
@@ -27,6 +34,13 @@ Usage:
 
 """
 
+# Budget Management
+from .budget import (
+    BudgetAllocator,
+    TokenBudget,
+    TokenCalculator,
+)
+
 # Configuration
 from .config import (
     ContextSettings,
@@ -67,6 +81,10 @@ from .types import (
 )
 
 __all__ = [
+    # Budget Management
+    "BudgetAllocator",
+    "TokenBudget",
+    "TokenCalculator",
     # Configuration
     "ContextSettings",
     "get_context_settings",
diff --git a/backend/app/services/context/budget/__init__.py b/backend/app/services/context/budget/__init__.py
index f3a675b..b1eb0ab 100644
--- a/backend/app/services/context/budget/__init__.py
+++ b/backend/app/services/context/budget/__init__.py
@@ -3,3 +3,12 @@ Token Budget Management Module.
 
 Provides token counting and budget allocation.
 """
+
+from .allocator import BudgetAllocator, TokenBudget
+from .calculator import TokenCalculator
+
+__all__ = [
+    "BudgetAllocator",
+    "TokenBudget",
+    "TokenCalculator",
+]
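For reviewers, a minimal usage sketch of the new package-level exports (the model name and printed fields are illustrative only; without an MCP manager attached, the calculator falls back to character-based estimation):

from app.services.context import BudgetAllocator, TokenCalculator, get_context_settings

settings = get_context_settings()
allocator = BudgetAllocator(settings)
budget = allocator.create_budget_for_model("claude-3-sonnet")

calculator = TokenCalculator()  # no MCP manager wired in yet: estimation path only
print(budget.total, budget.knowledge, calculator.estimate_tokens("hello world"))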
diff --git a/backend/app/services/context/budget/allocator.py b/backend/app/services/context/budget/allocator.py
new file mode 100644
index 0000000..00e5cc9
--- /dev/null
+++ b/backend/app/services/context/budget/allocator.py
@@ -0,0 +1,433 @@
+"""
+Token Budget Allocator for Context Management.
+
+Manages token budget allocation across context types.
+"""
+
+from dataclasses import dataclass, field
+from typing import Any
+
+from ..config import ContextSettings, get_context_settings
+from ..exceptions import BudgetExceededError
+from ..types import ContextType
+
+
+@dataclass
+class TokenBudget:
+    """
+    Token budget allocation and tracking.
+
+    Tracks allocated tokens per context type and
+    monitors usage to prevent overflows.
+    """
+
+    # Total budget
+    total: int
+
+    # Allocated per type
+    system: int = 0
+    task: int = 0
+    knowledge: int = 0
+    conversation: int = 0
+    tools: int = 0
+    response_reserve: int = 0
+    buffer: int = 0
+
+    # Usage tracking
+    used: dict[str, int] = field(default_factory=dict)
+
+    def __post_init__(self) -> None:
+        """Initialize usage tracking."""
+        if not self.used:
+            self.used = {ct.value: 0 for ct in ContextType}
+
+    def get_allocation(self, context_type: ContextType | str) -> int:
+        """
+        Get allocated tokens for a context type.
+
+        Args:
+            context_type: Context type to get allocation for
+
+        Returns:
+            Allocated token count
+        """
+        if isinstance(context_type, ContextType):
+            context_type = context_type.value
+
+        allocation_map = {
+            "system": self.system,
+            "task": self.task,
+            "knowledge": self.knowledge,
+            "conversation": self.conversation,
+            "tool": self.tools,
+        }
+        return allocation_map.get(context_type, 0)
+
+    def get_used(self, context_type: ContextType | str) -> int:
+        """
+        Get used tokens for a context type.
+
+        Args:
+            context_type: Context type to check
+
+        Returns:
+            Used token count
+        """
+        if isinstance(context_type, ContextType):
+            context_type = context_type.value
+        return self.used.get(context_type, 0)
+
+    def remaining(self, context_type: ContextType | str) -> int:
+        """
+        Get remaining tokens for a context type.
+
+        Args:
+            context_type: Context type to check
+
+        Returns:
+            Remaining token count
+        """
+        allocated = self.get_allocation(context_type)
+        used = self.get_used(context_type)
+        return max(0, allocated - used)
+
+    def total_remaining(self) -> int:
+        """
+        Get total remaining tokens across all types.
+
+        Returns:
+            Total remaining tokens
+        """
+        total_used = sum(self.used.values())
+        usable = self.total - self.response_reserve - self.buffer
+        return max(0, usable - total_used)
+
+    def total_used(self) -> int:
+        """
+        Get total used tokens.
+
+        Returns:
+            Total used tokens
+        """
+        return sum(self.used.values())
+
+    def can_fit(self, context_type: ContextType | str, tokens: int) -> bool:
+        """
+        Check if tokens fit within budget for a type.
+
+        Args:
+            context_type: Context type to check
+            tokens: Number of tokens to fit
+
+        Returns:
+            True if tokens fit within remaining budget
+        """
+        return tokens <= self.remaining(context_type)
+
+    def allocate(
+        self,
+        context_type: ContextType | str,
+        tokens: int,
+        force: bool = False,
+    ) -> bool:
+        """
+        Allocate (use) tokens from a context type's budget.
+
+        Args:
+            context_type: Context type to allocate from
+            tokens: Number of tokens to allocate
+            force: If True, allow exceeding budget
+
+        Returns:
+            True if allocation succeeded
+
+        Raises:
+            BudgetExceededError: If tokens exceed budget and force=False
+        """
+        if isinstance(context_type, ContextType):
+            context_type = context_type.value
+
+        if not force and not self.can_fit(context_type, tokens):
+            raise BudgetExceededError(
+                message=f"Token budget exceeded for {context_type}",
+                allocated=self.get_allocation(context_type),
+                requested=self.get_used(context_type) + tokens,
+                context_type=context_type,
+            )
+
+        self.used[context_type] = self.used.get(context_type, 0) + tokens
+        return True
+
+    def deallocate(
+        self,
+        context_type: ContextType | str,
+        tokens: int,
+    ) -> None:
+        """
+        Deallocate (return) tokens to a context type's budget.
+
+        Args:
+            context_type: Context type to return to
+            tokens: Number of tokens to return
+        """
+        if isinstance(context_type, ContextType):
+            context_type = context_type.value
+
+        current = self.used.get(context_type, 0)
+        self.used[context_type] = max(0, current - tokens)
+
+    def reset(self) -> None:
+        """Reset all usage tracking."""
+        self.used = {ct.value: 0 for ct in ContextType}
+
+    def utilization(self, context_type: ContextType | str | None = None) -> float:
+        """
+        Get budget utilization percentage.
+
+        Args:
+            context_type: Specific type or None for total
+
+        Returns:
+            Utilization as a fraction (0.0 to 1.0+)
+        """
+        if context_type is None:
+            usable = self.total - self.response_reserve - self.buffer
+            if usable <= 0:
+                return 0.0
+            return self.total_used() / usable
+
+        allocated = self.get_allocation(context_type)
+        if allocated <= 0:
+            return 0.0
+        return self.get_used(context_type) / allocated
+
+    def to_dict(self) -> dict[str, Any]:
+        """Convert budget to dictionary."""
+        return {
+            "total": self.total,
+            "allocations": {
+                "system": self.system,
+                "task": self.task,
+                "knowledge": self.knowledge,
+                "conversation": self.conversation,
+                "tools": self.tools,
+                "response_reserve": self.response_reserve,
+                "buffer": self.buffer,
+            },
+            "used": dict(self.used),
+            "remaining": {
+                ct.value: self.remaining(ct) for ct in ContextType
+            },
+            "total_used": self.total_used(),
+            "total_remaining": self.total_remaining(),
+            "utilization": round(self.utilization(), 3),
+        }
+
+
+class BudgetAllocator:
+    """
+    Budget allocator for context management.
+
+    Creates token budgets based on configuration and
+    model context window sizes.
+    """
+
+    def __init__(self, settings: ContextSettings | None = None) -> None:
+        """
+        Initialize budget allocator.
+
+        Args:
+            settings: Context settings (uses default if None)
+        """
+        self._settings = settings or get_context_settings()
+
+    def create_budget(
+        self,
+        total_tokens: int,
+        custom_allocations: dict[str, float] | None = None,
+    ) -> TokenBudget:
+        """
+        Create a token budget with allocations.
+
+        Args:
+            total_tokens: Total available tokens
+            custom_allocations: Optional custom allocation percentages
+
+        Returns:
+            TokenBudget with allocations set
+        """
+        # Use custom or default allocations
+        if custom_allocations:
+            alloc = custom_allocations
+        else:
+            alloc = self._settings.get_budget_allocation()
+
+        return TokenBudget(
+            total=total_tokens,
+            system=int(total_tokens * alloc.get("system", 0.05)),
+            task=int(total_tokens * alloc.get("task", 0.10)),
+            knowledge=int(total_tokens * alloc.get("knowledge", 0.40)),
+            conversation=int(total_tokens * alloc.get("conversation", 0.20)),
+            tools=int(total_tokens * alloc.get("tools", 0.05)),
+            response_reserve=int(total_tokens * alloc.get("response", 0.15)),
+            buffer=int(total_tokens * alloc.get("buffer", 0.05)),
+        )
+
+    def adjust_budget(
+        self,
+        budget: TokenBudget,
+        context_type: ContextType | str,
+        adjustment: int,
+    ) -> TokenBudget:
+        """
+        Adjust a specific allocation in a budget.
+
+        Takes tokens from buffer and adds to specified type.
+
+        Args:
+            budget: Budget to adjust
+            context_type: Type to adjust
+            adjustment: Positive to increase, negative to decrease
+
+        Returns:
+            Adjusted budget
+        """
+        if isinstance(context_type, ContextType):
+            context_type = context_type.value
+
+        # Calculate adjustment (limited by buffer)
+        if adjustment > 0:
+            # Taking from buffer
+            actual_adjustment = min(adjustment, budget.buffer)
+            budget.buffer -= actual_adjustment
+        else:
+            # Returning to buffer
+            actual_adjustment = adjustment
+
+        # Apply to target type
+        if context_type == "system":
+            budget.system = max(0, budget.system + actual_adjustment)
+        elif context_type == "task":
+            budget.task = max(0, budget.task + actual_adjustment)
+        elif context_type == "knowledge":
+            budget.knowledge = max(0, budget.knowledge + actual_adjustment)
+        elif context_type == "conversation":
+            budget.conversation = max(0, budget.conversation + actual_adjustment)
+        elif context_type == "tool":
+            budget.tools = max(0, budget.tools + actual_adjustment)
+
+        return budget
+
+    def rebalance_budget(
+        self,
+        budget: TokenBudget,
+        prioritize: list[ContextType] | None = None,
+    ) -> TokenBudget:
+        """
+        Rebalance budget based on actual usage.
+
+        Moves unused allocations to prioritized types.
+
+        Args:
+            budget: Budget to rebalance
+            prioritize: Types to prioritize (in order)
+
+        Returns:
+            Rebalanced budget
+        """
+        if prioritize is None:
+            prioritize = [ContextType.KNOWLEDGE, ContextType.TASK, ContextType.SYSTEM]
+
+        # Calculate unused tokens per type
+        unused: dict[str, int] = {}
+        for ct in ContextType:
+            remaining = budget.remaining(ct)
+            if remaining > 0:
+                unused[ct.value] = remaining
+
+        # Calculate total reclaimable (excluding prioritized types)
+        prioritize_values = {ct.value for ct in prioritize}
+        reclaimable = sum(
+            tokens for ct, tokens in unused.items()
+            if ct not in prioritize_values
+        )
+
+        # Redistribute to prioritized types that are near capacity
+        for ct in prioritize:
+            ct_value = ct.value
+            utilization = budget.utilization(ct)
+
+            if utilization > 0.8:  # Near capacity
+                # Give more tokens from reclaimable pool
+                bonus = min(reclaimable, budget.get_allocation(ct) // 2)
+                self.adjust_budget(budget, ct, bonus)
+                reclaimable -= bonus
+
+                if reclaimable <= 0:
+                    break
+
+        return budget
+
+    def get_model_context_size(self, model: str) -> int:
+        """
+        Get context window size for a model.
+
+        Args:
+            model: Model name
+
+        Returns:
+            Context window size in tokens
+        """
+        # Common model context sizes
+        context_sizes = {
+            "claude-3-opus": 200000,
+            "claude-3-sonnet": 200000,
+            "claude-3-haiku": 200000,
+            "claude-3-5-sonnet": 200000,
+            "claude-3-5-haiku": 200000,
+            "claude-opus-4": 200000,
+            "gpt-4-turbo": 128000,
+            "gpt-4": 8192,
+            "gpt-4-32k": 32768,
+            "gpt-4o": 128000,
+            "gpt-4o-mini": 128000,
+            "gpt-3.5-turbo": 16385,
+            "gemini-1.5-pro": 2000000,
+            "gemini-1.5-flash": 1000000,
+            "gemini-2.0-flash": 1000000,
+            "qwen-plus": 32000,
+            "qwen-turbo": 8000,
+            "deepseek-chat": 64000,
+            "deepseek-reasoner": 64000,
+        }
+
+        # Check exact match first
+        model_lower = model.lower()
+        if model_lower in context_sizes:
+            return context_sizes[model_lower]
+
+        # Check prefix match
+        for model_name, size in context_sizes.items():
+            if model_lower.startswith(model_name):
+                return size
+
+        # Default fallback
+        return 8192
+
+    def create_budget_for_model(
+        self,
+        model: str,
+        custom_allocations: dict[str, float] | None = None,
+    ) -> TokenBudget:
+        """
+        Create a budget based on model's context window.
+
+        Args:
+            model: Model name
+            custom_allocations: Optional custom allocation percentages
+
+        Returns:
+            TokenBudget sized for the model
+        """
+        context_size = self.get_model_context_size(model)
+        return self.create_budget(context_size, custom_allocations)
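The allocate/overflow flow above in a small sketch (token numbers are illustrative; the percentage split comes from ContextSettings.get_budget_allocation(), with the literals in create_budget only as fallbacks):

from app.services.context.budget import BudgetAllocator
from app.services.context.exceptions import BudgetExceededError
from app.services.context.types import ContextType

allocator = BudgetAllocator()
budget = allocator.create_budget(10_000)

# Reserve space for retrieved knowledge only if it fits
if budget.can_fit(ContextType.KNOWLEDGE, 1_500):
    budget.allocate(ContextType.KNOWLEDGE, 1_500)

# Exceeding an allocation raises unless force=True is passed
try:
    budget.allocate(ContextType.SYSTEM, budget.get_allocation(ContextType.SYSTEM) + 1)
except BudgetExceededError:
    budget.allocate(ContextType.SYSTEM, 10, force=True)

print(budget.utilization(ContextType.KNOWLEDGE), budget.total_remaining())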
diff --git a/backend/app/services/context/budget/calculator.py b/backend/app/services/context/budget/calculator.py
new file mode 100644
index 0000000..23c498b
--- /dev/null
+++ b/backend/app/services/context/budget/calculator.py
@@ -0,0 +1,284 @@
+"""
+Token Calculator for Context Management.
+
+Provides token counting with caching and fallback estimation.
+Integrates with LLM Gateway for accurate counts.
+"""
+
+import hashlib
+import logging
+from typing import TYPE_CHECKING, Any, Protocol
+
+if TYPE_CHECKING:
+    from app.services.mcp.client_manager import MCPClientManager
+
+logger = logging.getLogger(__name__)
+
+
+class TokenCounterProtocol(Protocol):
+    """Protocol for token counting implementations."""
+
+    async def count_tokens(
+        self,
+        text: str,
+        model: str | None = None,
+    ) -> int:
+        """Count tokens in text."""
+        ...
+
+
+class TokenCalculator:
+    """
+    Token calculator with LLM Gateway integration.
+
+    Features:
+    - In-memory caching for repeated text
+    - Fallback to character-based estimation
+    - Model-specific counting when possible
+
+    The calculator uses the LLM Gateway's count_tokens tool
+    for accurate counting, with a local cache to avoid
+    repeated calls for the same content.
+    """
+
+    # Default characters per token ratio for estimation
+    DEFAULT_CHARS_PER_TOKEN = 4.0
+
+    # Model-specific ratios (more accurate estimation)
+    MODEL_CHAR_RATIOS: dict[str, float] = {
+        "claude": 3.5,
+        "gpt-4": 4.0,
+        "gpt-3.5": 4.0,
+        "gemini": 4.0,
+    }
+
+    def __init__(
+        self,
+        mcp_manager: "MCPClientManager | None" = None,
+        project_id: str = "system",
+        agent_id: str = "context-engine",
+        cache_enabled: bool = True,
+        cache_max_size: int = 10000,
+    ) -> None:
+        """
+        Initialize token calculator.
+
+        Args:
+            mcp_manager: MCP client manager for LLM Gateway calls
+            project_id: Project ID for LLM Gateway calls
+            agent_id: Agent ID for LLM Gateway calls
+            cache_enabled: Whether to enable in-memory caching
+            cache_max_size: Maximum cache entries
+        """
+        self._mcp = mcp_manager
+        self._project_id = project_id
+        self._agent_id = agent_id
+        self._cache_enabled = cache_enabled
+        self._cache_max_size = cache_max_size
+
+        # In-memory cache: hash(model:text) -> token_count
+        self._cache: dict[str, int] = {}
+        self._cache_hits = 0
+        self._cache_misses = 0
+
+    def _get_cache_key(self, text: str, model: str | None) -> str:
+        """Generate cache key from text and model."""
+        # Use hash for efficient storage
+        content = f"{model or 'default'}:{text}"
+        return hashlib.sha256(content.encode()).hexdigest()[:32]
+
+    def _check_cache(self, cache_key: str) -> int | None:
+        """Check cache for existing count."""
+        if not self._cache_enabled:
+            return None
+
+        if cache_key in self._cache:
+            self._cache_hits += 1
+            return self._cache[cache_key]
+
+        self._cache_misses += 1
+        return None
+
+    def _store_cache(self, cache_key: str, count: int) -> None:
+        """Store count in cache."""
+        if not self._cache_enabled:
+            return
+
+        # Simple LRU-like eviction: remove oldest entries when full
+        if len(self._cache) >= self._cache_max_size:
+            # Remove first 10% of entries
+            entries_to_remove = self._cache_max_size // 10
+            keys_to_remove = list(self._cache.keys())[:entries_to_remove]
+            for key in keys_to_remove:
+                del self._cache[key]
+
+        self._cache[cache_key] = count
+
+    def estimate_tokens(self, text: str, model: str | None = None) -> int:
+        """
+        Estimate token count based on character count.
+
+        This is a fast fallback when LLM Gateway is unavailable.
+
+        Args:
+            text: Text to count
+            model: Optional model for more accurate ratio
+
+        Returns:
+            Estimated token count
+        """
+        if not text:
+            return 0
+
+        # Get model-specific ratio
+        ratio = self.DEFAULT_CHARS_PER_TOKEN
+        if model:
+            model_lower = model.lower()
+            for model_prefix, model_ratio in self.MODEL_CHAR_RATIOS.items():
+                if model_prefix in model_lower:
+                    ratio = model_ratio
+                    break
+
+        return max(1, int(len(text) / ratio))
+
+    async def count_tokens(
+        self,
+        text: str,
+        model: str | None = None,
+    ) -> int:
+        """
+        Count tokens in text.
+
+        Uses LLM Gateway for accurate counts with fallback to estimation.
+
+        Args:
+            text: Text to count
+            model: Optional model for accurate counting
+
+        Returns:
+            Token count
+        """
+        if not text:
+            return 0
+
+        # Check cache first
+        cache_key = self._get_cache_key(text, model)
+        cached = self._check_cache(cache_key)
+        if cached is not None:
+            return cached
+
+        # Try LLM Gateway
+        if self._mcp is not None:
+            try:
+                result = await self._mcp.call_tool(
+                    server="llm-gateway",
+                    tool="count_tokens",
+                    args={
+                        "project_id": self._project_id,
+                        "agent_id": self._agent_id,
+                        "text": text,
+                        "model": model,
+                    },
+                )
+
+                # Parse result
+                if result.success and result.data:
+                    count = self._parse_token_count(result.data)
+                    if count is not None:
+                        self._store_cache(cache_key, count)
+                        return count
+
+            except Exception as e:
+                logger.warning(f"LLM Gateway token count failed, using estimation: {e}")
+
+        # Fallback to estimation
+        count = self.estimate_tokens(text, model)
+        self._store_cache(cache_key, count)
+        return count
+
+    def _parse_token_count(self, data: Any) -> int | None:
+        """Parse token count from LLM Gateway response."""
+        if isinstance(data, dict):
+            if "token_count" in data:
+                return int(data["token_count"])
+            if "tokens" in data:
+                return int(data["tokens"])
+            if "count" in data:
+                return int(data["count"])
+
+        if isinstance(data, int):
+            return data
+
+        if isinstance(data, str):
+            # Try to parse from text content
+            try:
+                # Handle {"token_count": 123} or just "123"
+                import json
+
+                parsed = json.loads(data)
+                if isinstance(parsed, dict) and "token_count" in parsed:
+                    return int(parsed["token_count"])
+                if isinstance(parsed, int):
+                    return parsed
+            except (json.JSONDecodeError, ValueError):
+                # Try direct int conversion
+                try:
+                    return int(data)
+                except ValueError:
+                    pass
+
+        return None
+
+    async def count_tokens_batch(
+        self,
+        texts: list[str],
+        model: str | None = None,
+    ) -> list[int]:
+        """
+        Count tokens for multiple texts.
+
+        Efficient batch counting with caching.
+
+        Args:
+            texts: List of texts to count
+            model: Optional model for accurate counting
+
+        Returns:
+            List of token counts (same order as input)
+        """
+        results: list[int] = []
+
+        for text in texts:
+            count = await self.count_tokens(text, model)
+            results.append(count)
+
+        return results
+
+    def clear_cache(self) -> None:
+        """Clear the token count cache."""
+        self._cache.clear()
+        self._cache_hits = 0
+        self._cache_misses = 0
+
+    def get_cache_stats(self) -> dict[str, Any]:
+        """Get cache statistics."""
+        total = self._cache_hits + self._cache_misses
+        hit_rate = self._cache_hits / total if total > 0 else 0.0
+
+        return {
+            "enabled": self._cache_enabled,
+            "size": len(self._cache),
+            "max_size": self._cache_max_size,
+            "hits": self._cache_hits,
+            "misses": self._cache_misses,
+            "hit_rate": round(hit_rate, 3),
+        }
+
+    def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
+        """
+        Set the MCP manager (for lazy initialization).
+
+        Args:
+            mcp_manager: MCP client manager instance
+        """
+        self._mcp = mcp_manager
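A minimal sketch of the counting API, assuming no MCPClientManager is wired in yet (calling set_mcp_manager later switches these same calls over to the LLM Gateway's count_tokens tool):

import asyncio

from app.services.context.budget import TokenCalculator


async def main() -> None:
    calc = TokenCalculator()  # estimation-only until an MCP manager is set
    single = await calc.count_tokens("How many tokens is this?", model="claude-3-sonnet")
    batch = await calc.count_tokens_batch(["first text", "second text"], model="gpt-4")
    print(single, batch, calc.get_cache_stats())


asyncio.run(main())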
diff --git a/backend/tests/services/context/test_budget.py b/backend/tests/services/context/test_budget.py
new file mode 100644
index 0000000..253c067
--- /dev/null
+++ b/backend/tests/services/context/test_budget.py
@@ -0,0 +1,533 @@
+"""Tests for token budget management."""
+
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+
+from app.services.context.budget import (
+    BudgetAllocator,
+    TokenBudget,
+    TokenCalculator,
+)
+from app.services.context.config import ContextSettings
+from app.services.context.exceptions import BudgetExceededError
+from app.services.context.types import ContextType
+
+
+class TestTokenBudget:
+    """Tests for TokenBudget dataclass."""
+
+    def test_creation(self) -> None:
+        """Test basic budget creation."""
+        budget = TokenBudget(total=10000)
+        assert budget.total == 10000
+        assert budget.system == 0
+        assert budget.total_used() == 0
+
+    def test_creation_with_allocations(self) -> None:
+        """Test budget creation with allocations."""
+        budget = TokenBudget(
+            total=10000,
+            system=500,
+            task=1000,
+            knowledge=4000,
+            conversation=2000,
+            tools=500,
+            response_reserve=1500,
+            buffer=500,
+        )
+
+        assert budget.system == 500
+        assert budget.knowledge == 4000
+        assert budget.response_reserve == 1500
+
+    def test_get_allocation(self) -> None:
+        """Test getting allocation for a type."""
+        budget = TokenBudget(
+            total=10000,
+            system=500,
+            knowledge=4000,
+        )
+
+        assert budget.get_allocation(ContextType.SYSTEM) == 500
+        assert budget.get_allocation(ContextType.KNOWLEDGE) == 4000
+        assert budget.get_allocation("system") == 500
+
+    def test_remaining(self) -> None:
+        """Test remaining budget calculation."""
+        budget = TokenBudget(
+            total=10000,
+            system=500,
+            knowledge=4000,
+        )
+
+        # Initially full
+        assert budget.remaining(ContextType.SYSTEM) == 500
+        assert budget.remaining(ContextType.KNOWLEDGE) == 4000
+
+        # After allocation
+        budget.allocate(ContextType.SYSTEM, 200)
+        assert budget.remaining(ContextType.SYSTEM) == 300
+
+    def test_can_fit(self) -> None:
+        """Test can_fit check."""
+        budget = TokenBudget(
+            total=10000,
+            system=500,
+            knowledge=4000,
+        )
+
+        assert budget.can_fit(ContextType.SYSTEM, 500) is True
+        assert budget.can_fit(ContextType.SYSTEM, 501) is False
+        assert budget.can_fit(ContextType.KNOWLEDGE, 4000) is True
+
+    def test_allocate_success(self) -> None:
+        """Test successful allocation."""
+        budget = TokenBudget(
+            total=10000,
+            system=500,
+        )
+
+        result = budget.allocate(ContextType.SYSTEM, 200)
+        assert result is True
+        assert budget.get_used(ContextType.SYSTEM) == 200
+        assert budget.remaining(ContextType.SYSTEM) == 300
+
+    def test_allocate_exceeds_budget(self) -> None:
+        """Test allocation exceeding budget."""
+        budget = TokenBudget(
+            total=10000,
+            system=500,
+        )
+
+        with pytest.raises(BudgetExceededError) as exc_info:
+            budget.allocate(ContextType.SYSTEM, 600)
+
+        assert exc_info.value.allocated == 500
+        assert exc_info.value.requested == 600
+
+    def test_allocate_force(self) -> None:
+        """Test forced allocation exceeding budget."""
+        budget = TokenBudget(
+            total=10000,
+            system=500,
+        )
+
+        # Force should allow exceeding
+        result = budget.allocate(ContextType.SYSTEM, 600, force=True)
+        assert result is True
+        assert budget.get_used(ContextType.SYSTEM) == 600
+
+    def test_deallocate(self) -> None:
+        """Test deallocation."""
+        budget = TokenBudget(
+            total=10000,
+            system=500,
+        )
+
+        budget.allocate(ContextType.SYSTEM, 300)
+        assert budget.get_used(ContextType.SYSTEM) == 300
+
+        budget.deallocate(ContextType.SYSTEM, 100)
+        assert budget.get_used(ContextType.SYSTEM) == 200
+
+    def test_deallocate_below_zero(self) -> None:
+        """Test deallocation doesn't go below zero."""
+        budget = TokenBudget(
+            total=10000,
+            system=500,
+        )
+
+        budget.allocate(ContextType.SYSTEM, 100)
+        budget.deallocate(ContextType.SYSTEM, 200)
+        assert budget.get_used(ContextType.SYSTEM) == 0
+
+    def test_total_remaining(self) -> None:
+        """Test total remaining calculation."""
+        budget = TokenBudget(
+            total=10000,
+            system=500,
+            knowledge=4000,
+            response_reserve=1500,
+            buffer=500,
+        )
+
+        # Usable = total - response_reserve - buffer = 10000 - 1500 - 500 = 8000
+        assert budget.total_remaining() == 8000
+
+        # After allocation
+        budget.allocate(ContextType.SYSTEM, 200)
+        assert budget.total_remaining() == 7800
+
+    def test_utilization(self) -> None:
+        """Test utilization calculation."""
+        budget = TokenBudget(
+            total=10000,
+            system=500,
+            response_reserve=1500,
+            buffer=500,
+        )
+
+        # No usage = 0%
+        assert budget.utilization(ContextType.SYSTEM) == 0.0
+
+        # Half used = 50%
+        budget.allocate(ContextType.SYSTEM, 250)
+        assert budget.utilization(ContextType.SYSTEM) == 0.5
+
+        # Total utilization
+        assert budget.utilization() == 250 / 8000  # 250 / (10000 - 1500 - 500)
+
+    def test_reset(self) -> None:
+        """Test reset clears usage."""
+        budget = TokenBudget(
+            total=10000,
+            system=500,
+        )
+
+        budget.allocate(ContextType.SYSTEM, 300)
+        assert budget.get_used(ContextType.SYSTEM) == 300
+
+        budget.reset()
+        assert budget.get_used(ContextType.SYSTEM) == 0
+        assert budget.total_used() == 0
+
+    def test_to_dict(self) -> None:
+        """Test to_dict conversion."""
+        budget = TokenBudget(
+            total=10000,
+            system=500,
+            task=1000,
+            knowledge=4000,
+        )
+
+        budget.allocate(ContextType.SYSTEM, 200)
+
+        data = budget.to_dict()
+        assert data["total"] == 10000
+        assert data["allocations"]["system"] == 500
+        assert data["used"]["system"] == 200
+        assert data["remaining"]["system"] == 300
+
+
+class TestBudgetAllocator:
+    """Tests for BudgetAllocator."""
+
+    def test_create_budget(self) -> None:
+        """Test budget creation with default allocations."""
+        allocator = BudgetAllocator()
+        budget = allocator.create_budget(100000)
+
+        assert budget.total == 100000
+        assert budget.system == 5000  # 5%
+        assert budget.task == 10000  # 10%
+        assert budget.knowledge == 40000  # 40%
+        assert budget.conversation == 20000  # 20%
+        assert budget.tools == 5000  # 5%
+        assert budget.response_reserve == 15000  # 15%
+        assert budget.buffer == 5000  # 5%
+
+    def test_create_budget_custom_allocations(self) -> None:
+        """Test budget creation with custom allocations."""
+        allocator = BudgetAllocator()
+        budget = allocator.create_budget(
+            100000,
+            custom_allocations={
+                "system": 0.10,
+                "task": 0.10,
+                "knowledge": 0.30,
+                "conversation": 0.25,
+                "tools": 0.05,
+                "response": 0.15,
+                "buffer": 0.05,
+            },
+        )
+
+        assert budget.system == 10000  # 10%
+        assert budget.knowledge == 30000  # 30%
+
+    def test_create_budget_for_model(self) -> None:
+        """Test budget creation for specific model."""
+        allocator = BudgetAllocator()
+
+        # Claude models have 200k context
+        budget = allocator.create_budget_for_model("claude-3-sonnet")
+        assert budget.total == 200000
+
+        # GPT-4 has 8k context
+        budget = allocator.create_budget_for_model("gpt-4")
+        assert budget.total == 8192
+
+        # GPT-4-turbo has 128k context
+        budget = allocator.create_budget_for_model("gpt-4-turbo")
+        assert budget.total == 128000
+
+    def test_get_model_context_size(self) -> None:
+        """Test model context size lookup."""
+        allocator = BudgetAllocator()
+
+        # Known models
+        assert allocator.get_model_context_size("claude-3-opus") == 200000
+        assert allocator.get_model_context_size("gpt-4") == 8192
+        assert allocator.get_model_context_size("gemini-1.5-pro") == 2000000
+
+        # Unknown model gets default
+        assert allocator.get_model_context_size("unknown-model") == 8192
+
+    def test_adjust_budget(self) -> None:
+        """Test budget adjustment."""
+        allocator = BudgetAllocator()
+        budget = allocator.create_budget(10000)
+
+        original_system = budget.system
+        original_buffer = budget.buffer
+
+        # Increase system by taking from buffer
+        budget = allocator.adjust_budget(budget, ContextType.SYSTEM, 200)
+
+        assert budget.system == original_system + 200
+        assert budget.buffer == original_buffer - 200
+
+    def test_adjust_budget_limited_by_buffer(self) -> None:
+        """Test that adjustment is limited by buffer size."""
+        allocator = BudgetAllocator()
+        budget = allocator.create_budget(10000)
+
+        original_system = budget.system
+        original_buffer = budget.buffer
+
+        # Try to increase more than buffer allows
+        budget = allocator.adjust_budget(budget, ContextType.SYSTEM, 10000)
+
+        # Should only increase by buffer amount
+        assert budget.buffer == 0
+        assert budget.system == original_system + original_buffer
+
+    def test_rebalance_budget(self) -> None:
+        """Test budget rebalancing."""
+        allocator = BudgetAllocator()
+        budget = allocator.create_budget(10000)
+
+        # Use most of knowledge budget
+        budget.allocate(ContextType.KNOWLEDGE, 3500)
+
+        # Rebalance prioritizing knowledge
+        budget = allocator.rebalance_budget(
+            budget,
+            prioritize=[ContextType.KNOWLEDGE],
+        )
+
+        # Knowledge should have gotten more tokens
+        # (This is a fuzzy test - just check it runs)
+        assert budget is not None
+
+
+class TestTokenCalculator:
+    """Tests for TokenCalculator."""
+
+    def test_estimate_tokens(self) -> None:
+        """Test token estimation."""
+        calc = TokenCalculator()
+
+        # Empty string
+        assert calc.estimate_tokens("") == 0
+
+        # Short text (~4 chars per token)
+        text = "This is a test message"
+        estimate = calc.estimate_tokens(text)
+        assert 4 <= estimate <= 8
+
+    def test_estimate_tokens_model_specific(self) -> None:
+        """Test model-specific estimation ratios."""
+        calc = TokenCalculator()
+        text = "a" * 100
+
+        # Claude uses 3.5 chars per token
+        claude_estimate = calc.estimate_tokens(text, "claude-3-sonnet")
+        # GPT uses 4.0 chars per token
+        gpt_estimate = calc.estimate_tokens(text, "gpt-4")
+
+        # Claude should estimate more tokens (smaller ratio)
+        assert claude_estimate >= gpt_estimate
+
+    @pytest.mark.asyncio
+    async def test_count_tokens_no_mcp(self) -> None:
+        """Test token counting without MCP (fallback to estimation)."""
+        calc = TokenCalculator()
+
+        text = "This is a test"
+        count = await calc.count_tokens(text)
+
+        # Should use estimation
+        assert count > 0
+
+    @pytest.mark.asyncio
+    async def test_count_tokens_with_mcp_success(self) -> None:
+        """Test token counting with MCP integration."""
+        # Mock MCP manager
+        mock_mcp = MagicMock()
+        mock_result = MagicMock()
+        mock_result.success = True
+        mock_result.data = {"token_count": 42}
+        mock_mcp.call_tool = AsyncMock(return_value=mock_result)
+
+        calc = TokenCalculator(mcp_manager=mock_mcp)
+        count = await calc.count_tokens("test text")
+
+        assert count == 42
+        mock_mcp.call_tool.assert_called_once()
+
+    @pytest.mark.asyncio
+    async def test_count_tokens_with_mcp_failure(self) -> None:
+        """Test fallback when MCP fails."""
+        # Mock MCP manager that fails
+        mock_mcp = MagicMock()
+        mock_mcp.call_tool = AsyncMock(side_effect=Exception("Connection failed"))
+
+        calc = TokenCalculator(mcp_manager=mock_mcp)
+        count = await calc.count_tokens("test text")
+
+        # Should fall back to estimation
+        assert count > 0
+
+    @pytest.mark.asyncio
+    async def test_count_tokens_caching(self) -> None:
+        """Test that token counts are cached."""
+        mock_mcp = MagicMock()
+        mock_result = MagicMock()
+        mock_result.success = True
+        mock_result.data = {"token_count": 42}
+        mock_mcp.call_tool = AsyncMock(return_value=mock_result)
+
+        calc = TokenCalculator(mcp_manager=mock_mcp)
+
+        # First call
+        count1 = await calc.count_tokens("test text")
+        # Second call (should use cache)
+        count2 = await calc.count_tokens("test text")
+
+        assert count1 == count2 == 42
+        # MCP should only be called once
+        assert mock_mcp.call_tool.call_count == 1
+
+    @pytest.mark.asyncio
+    async def test_count_tokens_batch(self) -> None:
+        """Test batch token counting."""
+        calc = TokenCalculator()
+
+        texts = ["Hello", "World", "Test message here"]
+        counts = await calc.count_tokens_batch(texts)
+
+        assert len(counts) == 3
+        assert all(c > 0 for c in counts)
+
+    def test_cache_stats(self) -> None:
+        """Test cache statistics."""
+        calc = TokenCalculator()
+
+        stats = calc.get_cache_stats()
+        assert stats["enabled"] is True
+        assert stats["size"] == 0
+        assert stats["hits"] == 0
+        assert stats["misses"] == 0
+
+    @pytest.mark.asyncio
+    async def test_cache_hit_rate(self) -> None:
+        """Test cache hit rate tracking."""
+        calc = TokenCalculator()
+
+        # Make some calls
+        await calc.count_tokens("text1")
+        await calc.count_tokens("text2")
+        await calc.count_tokens("text1")  # Cache hit
+
+        stats = calc.get_cache_stats()
+        assert stats["hits"] == 1
+        assert stats["misses"] == 2
+
+    def test_clear_cache(self) -> None:
+        """Test cache clearing."""
+        calc = TokenCalculator()
+        calc._cache["test"] = 100
+        calc._cache_hits = 5
+
+        calc.clear_cache()
+
+        assert len(calc._cache) == 0
+        assert calc._cache_hits == 0
+
+    def test_set_mcp_manager(self) -> None:
+        """Test setting MCP manager after initialization."""
+        calc = TokenCalculator()
+        assert calc._mcp is None
+
+        mock_mcp = MagicMock()
+        calc.set_mcp_manager(mock_mcp)
+
+        assert calc._mcp is mock_mcp
+
+    @pytest.mark.asyncio
+    async def test_parse_token_count_formats(self) -> None:
+        """Test parsing different token count response formats."""
+        calc = TokenCalculator()
+
+        # Dict with token_count
+        assert calc._parse_token_count({"token_count": 42}) == 42
+
+        # Dict with tokens
+        assert calc._parse_token_count({"tokens": 42}) == 42
+
+        # Dict with count
+        assert calc._parse_token_count({"count": 42}) == 42
+
+        # Direct int
+        assert calc._parse_token_count(42) == 42
+
+        # JSON string
+        assert calc._parse_token_count('{"token_count": 42}') == 42
+
+        # Invalid
+        assert calc._parse_token_count("invalid") is None
+
+
+class TestBudgetIntegration:
+    """Integration tests for budget management."""
+
+    @pytest.mark.asyncio
+    async def test_full_budget_workflow(self) -> None:
+        """Test complete budget allocation workflow."""
+        # Create settings and allocator
+        settings = ContextSettings()
+        allocator = BudgetAllocator(settings)
+
+        # Create budget for Claude
+        budget = allocator.create_budget_for_model("claude-3-sonnet")
+        assert budget.total == 200000
+
+        # Create calculator (without MCP for test)
+        calc = TokenCalculator()
+
+        # Simulate context allocation
+        system_text = "You are a helpful assistant." * 10
+        system_tokens = await calc.count_tokens(system_text)
+
+        # Allocate
+        assert budget.can_fit(ContextType.SYSTEM, system_tokens)
+        budget.allocate(ContextType.SYSTEM, system_tokens)
+
+        # Check state
+        assert budget.get_used(ContextType.SYSTEM) == system_tokens
+        assert budget.remaining(ContextType.SYSTEM) == budget.system - system_tokens
+
+    @pytest.mark.asyncio
+    async def test_budget_overflow_handling(self) -> None:
+        """Test handling budget overflow."""
+        allocator = BudgetAllocator()
+        budget = allocator.create_budget(1000)  # Small budget
+
+        # Try to allocate more than available
+        with pytest.raises(BudgetExceededError):
+            budget.allocate(ContextType.KNOWLEDGE, 500)
+
+        # Force allocation should work
+        budget.allocate(ContextType.KNOWLEDGE, 500, force=True)
+        assert budget.get_used(ContextType.KNOWLEDGE) == 500
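Putting the pieces together, the intended end-to-end flow mirrors the integration test above. This is a sketch only: build_budget is a hypothetical helper name, and wiring an MCPClientManager via set_mcp_manager is optional.

import asyncio

from app.services.context.budget import BudgetAllocator, TokenCalculator
from app.services.context.types import ContextType


async def build_budget(model: str, system_prompt: str) -> None:
    allocator = BudgetAllocator()
    calculator = TokenCalculator()  # call set_mcp_manager(...) for gateway-accurate counts

    budget = allocator.create_budget_for_model(model)
    needed = await calculator.count_tokens(system_prompt, model=model)

    if budget.can_fit(ContextType.SYSTEM, needed):
        budget.allocate(ContextType.SYSTEM, needed)

    print(budget.to_dict())


asyncio.run(build_budget("claude-3-sonnet", "You are a helpful assistant."))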