feat(llm-gateway): implement LLM Gateway MCP Server (#56)

Implements complete LLM Gateway MCP Server with:
- FastMCP server with 4 tools: chat_completion, list_models, get_usage, count_tokens
- LiteLLM Router with multi-provider failover chains
- Circuit breaker pattern for fault tolerance
- Redis-based cost tracking per project/agent
- Comprehensive test suite (209 tests, 92% coverage)

Model groups defined per ADR-004:
- reasoning: claude-opus-4 → gpt-4.1 → gemini-2.5-pro
- code: claude-sonnet-4 → gpt-4.1 → deepseek-coder
- fast: claude-haiku → gpt-4.1-mini → gemini-2.0-flash

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-03 20:31:19 +01:00
parent 746fb7b181
commit 6e8b0b022a
23 changed files with 9794 additions and 93 deletions

View File

@@ -0,0 +1,53 @@
# Syndarix LLM Gateway MCP Server
# Multi-stage build for minimal image size
# Build stage
FROM python:3.12-slim AS builder
# Install uv for fast package management
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
WORKDIR /app
# Copy dependency files
COPY pyproject.toml ./
# Create virtual environment and install dependencies
RUN uv venv /app/.venv
ENV PATH="/app/.venv/bin:$PATH"
RUN uv pip install -e .
# Runtime stage
FROM python:3.12-slim AS runtime
# Create non-root user for security
RUN groupadd --gid 1000 appgroup && \
useradd --uid 1000 --gid appgroup --shell /bin/bash --create-home appuser
WORKDIR /app
# Copy virtual environment from builder
COPY --from=builder /app/.venv /app/.venv
ENV PATH="/app/.venv/bin:$PATH"
# Copy application code
COPY --chown=appuser:appgroup . .
# Switch to non-root user
USER appuser
# Environment variables
ENV LLM_GATEWAY_HOST=0.0.0.0
ENV LLM_GATEWAY_PORT=8001
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1
# Expose port
EXPOSE 8001
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD python -c "import httpx; httpx.get('http://localhost:8001/health').raise_for_status()" || exit 1
# Run the server
CMD ["python", "server.py"]

View File

@@ -0,0 +1,179 @@
"""
Configuration for LLM Gateway MCP Server.
Uses Pydantic Settings for type-safe environment variable handling.
"""
from functools import lru_cache
from pydantic import Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
"""LLM Gateway configuration settings."""
model_config = SettingsConfigDict(
env_prefix="LLM_GATEWAY_",
env_file=".env",
env_file_encoding="utf-8",
extra="ignore",
)
# Server settings
host: str = Field(default="0.0.0.0", description="Server host")
port: int = Field(default=8001, description="Server port")
debug: bool = Field(default=False, description="Debug mode")
# Redis settings
redis_url: str = Field(
default="redis://localhost:6379/0",
description="Redis connection URL",
)
redis_prefix: str = Field(
default="llm_gateway",
description="Redis key prefix",
)
redis_ttl_hours: int = Field(
default=24,
description="Default Redis TTL in hours",
)
# Provider API keys
anthropic_api_key: str | None = Field(
default=None,
description="Anthropic API key for Claude models",
)
openai_api_key: str | None = Field(
default=None,
description="OpenAI API key for GPT models",
)
google_api_key: str | None = Field(
default=None,
description="Google API key for Gemini models",
)
alibaba_api_key: str | None = Field(
default=None,
description="Alibaba API key for Qwen models",
)
deepseek_api_key: str | None = Field(
default=None,
description="DeepSeek API key",
)
deepseek_base_url: str | None = Field(
default=None,
description="DeepSeek API base URL (for self-hosted)",
)
# LiteLLM settings
litellm_timeout: int = Field(
default=120,
description="LiteLLM request timeout in seconds",
)
litellm_max_retries: int = Field(
default=3,
description="Maximum retries per provider",
)
litellm_cache_enabled: bool = Field(
default=True,
description="Enable Redis caching for LiteLLM",
)
litellm_cache_ttl: int = Field(
default=3600,
description="Cache TTL in seconds",
)
# Circuit breaker settings
circuit_failure_threshold: int = Field(
default=5,
description="Failures before circuit opens",
)
circuit_recovery_timeout: int = Field(
default=60,
description="Seconds before circuit half-opens",
)
circuit_half_open_max_calls: int = Field(
default=3,
description="Max calls in half-open state",
)
# Cost tracking settings
cost_tracking_enabled: bool = Field(
default=True,
description="Enable cost tracking",
)
cost_alert_threshold: float = Field(
default=100.0,
description="Cost threshold for alerts (USD)",
)
default_budget_limit: float = Field(
default=1000.0,
description="Default project budget limit (USD)",
)
# Rate limiting
rate_limit_enabled: bool = Field(
default=True,
description="Enable rate limiting",
)
rate_limit_requests_per_minute: int = Field(
default=60,
description="Max requests per minute per project",
)
@field_validator("port")
@classmethod
def validate_port(cls, v: int) -> int:
"""Validate port is in valid range."""
if not 1 <= v <= 65535:
raise ValueError("Port must be between 1 and 65535")
return v
@field_validator("redis_ttl_hours")
@classmethod
def validate_ttl(cls, v: int) -> int:
"""Validate TTL is positive."""
if v <= 0:
raise ValueError("Redis TTL must be positive")
return v
@field_validator("circuit_failure_threshold")
@classmethod
def validate_failure_threshold(cls, v: int) -> int:
"""Validate failure threshold is reasonable."""
if not 1 <= v <= 100:
raise ValueError("Failure threshold must be between 1 and 100")
return v
@field_validator("litellm_timeout")
@classmethod
def validate_timeout(cls, v: int) -> int:
"""Validate timeout is reasonable."""
if not 1 <= v <= 600:
raise ValueError("Timeout must be between 1 and 600 seconds")
return v
def get_available_providers(self) -> list[str]:
"""Get list of providers with configured API keys."""
providers = []
if self.anthropic_api_key:
providers.append("anthropic")
if self.openai_api_key:
providers.append("openai")
if self.google_api_key:
providers.append("google")
if self.alibaba_api_key:
providers.append("alibaba")
if self.deepseek_api_key or self.deepseek_base_url:
providers.append("deepseek")
return providers
def has_any_provider(self) -> bool:
"""Check if at least one provider is configured."""
return len(self.get_available_providers()) > 0
@lru_cache
def get_settings() -> Settings:
"""Get cached settings instance."""
return Settings()
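
Usage sketch (not part of this change): the settings above are read from LLM_GATEWAY_-prefixed environment variables, and get_settings() is cached, so variables must be set before the first call. The values below are placeholders, and the final assertion assumes no other provider keys are configured in the environment or a .env file.

import os

from config import get_settings

# Placeholders only; real deployments set these in the environment or .env.
os.environ["LLM_GATEWAY_PORT"] = "9001"
os.environ["LLM_GATEWAY_ANTHROPIC_API_KEY"] = "sk-placeholder"

settings = get_settings()  # cached via functools.lru_cache
assert settings.port == 9001
assert settings.get_available_providers() == ["anthropic"]  # assuming no other keys are set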

View File

@@ -0,0 +1,467 @@
"""
Cost tracking for LLM Gateway.
Tracks LLM usage costs per project and agent using Redis.
Provides aggregation by hour, day, and month with TTL-based expiry.
"""
import logging
from datetime import UTC, datetime, timedelta
from typing import Any
import redis.asyncio as redis
from config import Settings, get_settings
from models import (
MODEL_CONFIGS,
UsageReport,
)
logger = logging.getLogger(__name__)
class CostTracker:
"""
Redis-based cost tracker for LLM usage.
Key structure:
- {prefix}:cost:project:{project_id}:{date} -> Hash of usage by model
- {prefix}:cost:agent:{agent_id}:{date} -> Hash of usage by model
- {prefix}:cost:session:{session_id} -> Hash of session usage
- {prefix}:requests:{project_id}:{date} -> Request count
Date formats:
- hour: YYYYMMDDHH
- day: YYYYMMDD
- month: YYYYMM
"""
def __init__(
self,
redis_client: redis.Redis | None = None,
settings: Settings | None = None,
) -> None:
"""
Initialize cost tracker.
Args:
redis_client: Redis client (creates one if None)
settings: Application settings
"""
self._settings = settings or get_settings()
self._redis: redis.Redis | None = redis_client
self._prefix = self._settings.redis_prefix
async def _get_redis(self) -> redis.Redis:
"""Get or create Redis client."""
if self._redis is None:
self._redis = redis.from_url(
self._settings.redis_url,
decode_responses=True,
)
return self._redis
async def close(self) -> None:
"""Close Redis connection."""
if self._redis:
await self._redis.aclose()
self._redis = None
def _get_date_keys(self, timestamp: datetime | None = None) -> dict[str, str]:
"""Get date format keys for different periods."""
if timestamp is None:
timestamp = datetime.now(UTC)
return {
"hour": timestamp.strftime("%Y%m%d%H"),
"day": timestamp.strftime("%Y%m%d"),
"month": timestamp.strftime("%Y%m"),
}
def _get_ttl_seconds(self, period: str) -> int:
"""Get TTL in seconds for a period."""
ttls = {
"hour": 24 * 3600, # 24 hours
"day": 30 * 24 * 3600, # 30 days
"month": 365 * 24 * 3600, # 1 year
}
return ttls.get(period, 30 * 24 * 3600)
async def record_usage(
self,
project_id: str,
agent_id: str,
model: str,
prompt_tokens: int,
completion_tokens: int,
cost_usd: float,
session_id: str | None = None,
request_id: str | None = None, # noqa: ARG002 - reserved for future logging
) -> None:
"""
Record LLM usage.
Args:
project_id: Project ID
agent_id: Agent ID
model: Model name
prompt_tokens: Input tokens
completion_tokens: Output tokens
cost_usd: Cost in USD
session_id: Optional session ID
request_id: Optional request ID
"""
if not self._settings.cost_tracking_enabled:
return
r = await self._get_redis()
date_keys = self._get_date_keys()
pipe = r.pipeline()
# Record for each time period
for period, date_key in date_keys.items():
# Project-level tracking
project_key = f"{self._prefix}:cost:project:{project_id}:{date_key}"
await self._increment_usage(
pipe, project_key, model, prompt_tokens, completion_tokens, cost_usd
)
pipe.expire(project_key, self._get_ttl_seconds(period))
# Agent-level tracking
agent_key = f"{self._prefix}:cost:agent:{agent_id}:{date_key}"
await self._increment_usage(
pipe, agent_key, model, prompt_tokens, completion_tokens, cost_usd
)
pipe.expire(agent_key, self._get_ttl_seconds(period))
# Request counter
requests_key = f"{self._prefix}:requests:{project_id}:{date_key}"
pipe.incr(requests_key)
pipe.expire(requests_key, self._get_ttl_seconds(period))
# Session tracking (if session_id provided)
if session_id:
session_key = f"{self._prefix}:cost:session:{session_id}"
await self._increment_usage(
pipe, session_key, model, prompt_tokens, completion_tokens, cost_usd
)
pipe.expire(session_key, 24 * 3600) # 24 hour TTL for sessions
await pipe.execute()
logger.debug(
f"Recorded usage: project={project_id}, agent={agent_id}, "
f"model={model}, tokens={prompt_tokens + completion_tokens}, "
f"cost=${cost_usd:.6f}"
)
async def _increment_usage(
self,
pipe: redis.client.Pipeline,
key: str,
model: str,
prompt_tokens: int,
completion_tokens: int,
cost_usd: float,
) -> None:
"""Increment usage in a hash."""
# Per-model usage counters stored as individual hash fields
pipe.hincrby(key, f"{model}:prompt_tokens", prompt_tokens)
pipe.hincrby(key, f"{model}:completion_tokens", completion_tokens)
pipe.hincrbyfloat(key, f"{model}:cost_usd", cost_usd)
pipe.hincrby(key, f"{model}:requests", 1)
# Totals
pipe.hincrby(key, "total:prompt_tokens", prompt_tokens)
pipe.hincrby(key, "total:completion_tokens", completion_tokens)
pipe.hincrbyfloat(key, "total:cost_usd", cost_usd)
pipe.hincrby(key, "total:requests", 1)
async def get_project_usage(
self,
project_id: str,
period: str = "day",
timestamp: datetime | None = None,
) -> UsageReport:
"""
Get usage report for a project.
Args:
project_id: Project ID
period: Time period (hour, day, month)
timestamp: Specific time to query (defaults to now)
Returns:
Usage report
"""
date_keys = self._get_date_keys(timestamp)
date_key = date_keys.get(period, date_keys["day"])
key = f"{self._prefix}:cost:project:{project_id}:{date_key}"
return await self._get_usage_report(
key, project_id, "project", period, timestamp
)
async def get_agent_usage(
self,
agent_id: str,
period: str = "day",
timestamp: datetime | None = None,
) -> UsageReport:
"""
Get usage report for an agent.
Args:
agent_id: Agent ID
period: Time period (hour, day, month)
timestamp: Specific time to query (defaults to now)
Returns:
Usage report
"""
date_keys = self._get_date_keys(timestamp)
date_key = date_keys.get(period, date_keys["day"])
key = f"{self._prefix}:cost:agent:{agent_id}:{date_key}"
return await self._get_usage_report(key, agent_id, "agent", period, timestamp)
async def _get_usage_report(
self,
key: str,
entity_id: str,
entity_type: str,
period: str,
timestamp: datetime | None,
) -> UsageReport:
"""Get usage report from a Redis hash."""
r = await self._get_redis()
data = await r.hgetall(key)
# Parse the hash data
by_model: dict[str, dict[str, Any]] = {}
total_requests = 0
total_tokens = 0
total_cost = 0.0
for field, value in data.items():
parts = field.split(":")
if len(parts) != 2:
continue
model, metric = parts
if model == "total":
if metric == "requests":
total_requests = int(value)
elif metric == "prompt_tokens" or metric == "completion_tokens":
total_tokens += int(value)
elif metric == "cost_usd":
total_cost = float(value)
else:
if model not in by_model:
by_model[model] = {
"prompt_tokens": 0,
"completion_tokens": 0,
"cost_usd": 0.0,
"requests": 0,
}
if metric == "prompt_tokens":
by_model[model]["prompt_tokens"] = int(value)
elif metric == "completion_tokens":
by_model[model]["completion_tokens"] = int(value)
elif metric == "cost_usd":
by_model[model]["cost_usd"] = float(value)
elif metric == "requests":
by_model[model]["requests"] = int(value)
# Calculate period boundaries
now = timestamp or datetime.now(UTC)
if period == "hour":
period_start = now.replace(minute=0, second=0, microsecond=0)
period_end = period_start + timedelta(hours=1)
elif period == "day":
period_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
period_end = period_start + timedelta(days=1)
else: # month
period_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
# Next month
if now.month == 12:
period_end = period_start.replace(year=now.year + 1, month=1)
else:
period_end = period_start.replace(month=now.month + 1)
return UsageReport(
entity_id=entity_id,
entity_type=entity_type,
period=period,
period_start=period_start,
period_end=period_end,
total_requests=total_requests,
total_tokens=total_tokens,
total_cost_usd=round(total_cost, 6),
by_model=by_model,
)
async def get_session_usage(self, session_id: str) -> dict[str, Any]:
"""
Get usage for a specific session.
Args:
session_id: Session ID
Returns:
Session usage data
"""
r = await self._get_redis()
key = f"{self._prefix}:cost:session:{session_id}"
data = await r.hgetall(key)
# Parse similar to _get_usage_report
result: dict[str, Any] = {
"session_id": session_id,
"total_tokens": 0,
"total_cost_usd": 0.0,
"by_model": {},
}
for field, value in data.items():
parts = field.split(":")
if len(parts) != 2:
continue
model, metric = parts
if model == "total":
if metric == "prompt_tokens" or metric == "completion_tokens":
result["total_tokens"] += int(value)
elif metric == "cost_usd":
result["total_cost_usd"] = float(value)
else:
if model not in result["by_model"]:
result["by_model"][model] = {}
if metric in ("prompt_tokens", "completion_tokens", "requests"):
result["by_model"][model][metric] = int(value)
elif metric == "cost_usd":
result["by_model"][model][metric] = float(value)
return result
async def check_budget(
self,
project_id: str,
budget_limit: float | None = None,
) -> tuple[bool, float, float]:
"""
Check if project is within budget.
Args:
project_id: Project ID
budget_limit: Budget limit (uses default if None)
Returns:
Tuple of (within_budget, current_cost, limit)
"""
limit = budget_limit or self._settings.default_budget_limit
# Get current month usage
report = await self.get_project_usage(project_id, period="month")
current_cost = report.total_cost_usd
within_budget = current_cost < limit
return within_budget, current_cost, limit
async def estimate_request_cost(
self,
model: str,
prompt_tokens: int,
max_completion_tokens: int,
) -> float:
"""
Estimate cost for a request.
Args:
model: Model name
prompt_tokens: Input token count
max_completion_tokens: Maximum output tokens
Returns:
Estimated cost in USD
"""
config = MODEL_CONFIGS.get(model)
if not config:
# Use a default estimate
return (prompt_tokens + max_completion_tokens) * 0.00001
input_cost = (prompt_tokens / 1_000_000) * config.cost_per_1m_input
output_cost = (max_completion_tokens / 1_000_000) * config.cost_per_1m_output
return round(input_cost + output_cost, 6)
async def should_alert(
self,
project_id: str,
threshold: float | None = None,
) -> tuple[bool, float]:
"""
Check if cost alert should be triggered.
Args:
project_id: Project ID
threshold: Alert threshold (uses default if None)
Returns:
Tuple of (should_alert, current_cost)
"""
thresh = threshold or self._settings.cost_alert_threshold
report = await self.get_project_usage(project_id, period="day")
current_cost = report.total_cost_usd
return current_cost >= thresh, current_cost
def calculate_cost(
model: str,
prompt_tokens: int,
completion_tokens: int,
) -> float:
"""
Calculate cost for a completion.
Args:
model: Model name
prompt_tokens: Input tokens
completion_tokens: Output tokens
Returns:
Cost in USD
"""
config = MODEL_CONFIGS.get(model)
if not config:
logger.warning(f"Unknown model {model} for cost calculation")
return 0.0
input_cost = (prompt_tokens / 1_000_000) * config.cost_per_1m_input
output_cost = (completion_tokens / 1_000_000) * config.cost_per_1m_output
return round(input_cost + output_cost, 6)
# Global tracker instance (lazy initialization)
_tracker: CostTracker | None = None
def get_cost_tracker() -> CostTracker:
"""Get the global cost tracker instance."""
global _tracker
if _tracker is None:
_tracker = CostTracker()
return _tracker
async def close_cost_tracker() -> None:
"""Close the global cost tracker."""
global _tracker
if _tracker:
await _tracker.close()
_tracker = None
def reset_cost_tracker() -> None:
"""Reset the global tracker (for testing)."""
global _tracker
_tracker = None
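
A minimal end-to-end sketch of the tracker above, run against fakeredis (already a dev dependency) instead of a live Redis; the IDs and token counts are placeholders, not part of this change.

import asyncio

import fakeredis.aioredis

from cost_tracking import CostTracker, calculate_cost

async def demo() -> None:
    # In-memory Redis stand-in; decode_responses matches the tracker's own client settings.
    tracker = CostTracker(redis_client=fakeredis.aioredis.FakeRedis(decode_responses=True))
    cost = calculate_cost("claude-sonnet-4", prompt_tokens=1_000, completion_tokens=500)
    await tracker.record_usage(
        project_id="proj-1", agent_id="agent-1", model="claude-sonnet-4",
        prompt_tokens=1_000, completion_tokens=500, cost_usd=cost,
    )
    report = await tracker.get_project_usage("proj-1", period="day")
    print(report.total_requests, report.total_tokens, report.total_cost_usd)
    await tracker.close()

asyncio.run(demo())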

View File

@@ -0,0 +1,478 @@
"""
Custom exceptions for LLM Gateway MCP Server.
Provides structured error handling with error codes for consistent responses.
"""
from enum import Enum
from typing import Any
class ErrorCode(str, Enum):
"""Error codes for LLM Gateway errors."""
# General errors
UNKNOWN_ERROR = "LLM_UNKNOWN_ERROR"
INVALID_REQUEST = "LLM_INVALID_REQUEST"
CONFIGURATION_ERROR = "LLM_CONFIGURATION_ERROR"
# Provider errors
PROVIDER_ERROR = "LLM_PROVIDER_ERROR"
PROVIDER_TIMEOUT = "LLM_PROVIDER_TIMEOUT"
PROVIDER_RATE_LIMIT = "LLM_PROVIDER_RATE_LIMIT"
PROVIDER_UNAVAILABLE = "LLM_PROVIDER_UNAVAILABLE"
ALL_PROVIDERS_FAILED = "LLM_ALL_PROVIDERS_FAILED"
# Model errors
INVALID_MODEL = "LLM_INVALID_MODEL"
INVALID_MODEL_GROUP = "LLM_INVALID_MODEL_GROUP"
MODEL_NOT_AVAILABLE = "LLM_MODEL_NOT_AVAILABLE"
# Circuit breaker errors
CIRCUIT_OPEN = "LLM_CIRCUIT_OPEN"
CIRCUIT_HALF_OPEN_EXHAUSTED = "LLM_CIRCUIT_HALF_OPEN_EXHAUSTED"
# Cost errors
COST_LIMIT_EXCEEDED = "LLM_COST_LIMIT_EXCEEDED"
BUDGET_EXHAUSTED = "LLM_BUDGET_EXHAUSTED"
# Rate limiting errors
RATE_LIMIT_EXCEEDED = "LLM_RATE_LIMIT_EXCEEDED"
# Streaming errors
STREAM_ERROR = "LLM_STREAM_ERROR"
STREAM_INTERRUPTED = "LLM_STREAM_INTERRUPTED"
# Token errors
TOKEN_LIMIT_EXCEEDED = "LLM_TOKEN_LIMIT_EXCEEDED"
CONTEXT_TOO_LONG = "LLM_CONTEXT_TOO_LONG"
class LLMGatewayError(Exception):
"""Base exception for LLM Gateway errors."""
def __init__(
self,
message: str,
code: ErrorCode = ErrorCode.UNKNOWN_ERROR,
details: dict[str, Any] | None = None,
cause: Exception | None = None,
) -> None:
"""
Initialize LLM Gateway error.
Args:
message: Human-readable error message
code: Error code for programmatic handling
details: Additional error details
cause: Original exception that caused this error
"""
super().__init__(message)
self.message = message
self.code = code
self.details = details or {}
self.cause = cause
def to_dict(self) -> dict[str, Any]:
"""Convert error to dictionary for JSON response."""
result = {
"error": self.code.value,
"message": self.message,
}
if self.details:
result["details"] = self.details
return result
def __str__(self) -> str:
"""String representation."""
return f"[{self.code.value}] {self.message}"
def __repr__(self) -> str:
"""Detailed representation."""
return (
f"{self.__class__.__name__}("
f"message={self.message!r}, "
f"code={self.code.value!r}, "
f"details={self.details!r})"
)
class ProviderError(LLMGatewayError):
"""Error from an LLM provider."""
def __init__(
self,
message: str,
provider: str,
model: str | None = None,
status_code: int | None = None,
details: dict[str, Any] | None = None,
cause: Exception | None = None,
) -> None:
"""
Initialize provider error.
Args:
message: Error message
provider: Provider that failed
model: Model that was being used
status_code: HTTP status code if applicable
details: Additional details
cause: Original exception
"""
error_details = details or {}
error_details["provider"] = provider
if model:
error_details["model"] = model
if status_code:
error_details["status_code"] = status_code
super().__init__(
message=message,
code=ErrorCode.PROVIDER_ERROR,
details=error_details,
cause=cause,
)
self.provider = provider
self.model = model
self.status_code = status_code
class RateLimitError(LLMGatewayError):
"""Rate limit exceeded error."""
def __init__(
self,
message: str,
provider: str | None = None,
retry_after: int | None = None,
details: dict[str, Any] | None = None,
) -> None:
"""
Initialize rate limit error.
Args:
message: Error message
provider: Provider that rate limited (None for internal limit)
retry_after: Seconds until retry is allowed
details: Additional details
"""
error_details = details or {}
if provider:
error_details["provider"] = provider
if retry_after:
error_details["retry_after_seconds"] = retry_after
code = (
ErrorCode.PROVIDER_RATE_LIMIT
if provider
else ErrorCode.RATE_LIMIT_EXCEEDED
)
super().__init__(
message=message,
code=code,
details=error_details,
)
self.provider = provider
self.retry_after = retry_after
class CircuitOpenError(LLMGatewayError):
"""Circuit breaker is open, provider temporarily unavailable."""
def __init__(
self,
provider: str,
recovery_time: int | None = None,
details: dict[str, Any] | None = None,
) -> None:
"""
Initialize circuit open error.
Args:
provider: Provider with open circuit
recovery_time: Seconds until circuit may recover
details: Additional details
"""
error_details = details or {}
error_details["provider"] = provider
if recovery_time:
error_details["recovery_time_seconds"] = recovery_time
super().__init__(
message=f"Circuit breaker open for provider {provider}",
code=ErrorCode.CIRCUIT_OPEN,
details=error_details,
)
self.provider = provider
self.recovery_time = recovery_time
class CostLimitExceededError(LLMGatewayError):
"""Cost limit exceeded for project or agent."""
def __init__(
self,
entity_type: str,
entity_id: str,
current_cost: float,
limit: float,
details: dict[str, Any] | None = None,
) -> None:
"""
Initialize cost limit error.
Args:
entity_type: 'project' or 'agent'
entity_id: ID of the entity
current_cost: Current accumulated cost
limit: Cost limit that was exceeded
details: Additional details
"""
error_details = details or {}
error_details["entity_type"] = entity_type
error_details["entity_id"] = entity_id
error_details["current_cost_usd"] = current_cost
error_details["limit_usd"] = limit
super().__init__(
message=(
f"Cost limit exceeded for {entity_type} {entity_id}: "
f"${current_cost:.2f} >= ${limit:.2f}"
),
code=ErrorCode.COST_LIMIT_EXCEEDED,
details=error_details,
)
self.entity_type = entity_type
self.entity_id = entity_id
self.current_cost = current_cost
self.limit = limit
class InvalidModelGroupError(LLMGatewayError):
"""Invalid or unknown model group."""
def __init__(
self,
model_group: str,
available_groups: list[str] | None = None,
) -> None:
"""
Initialize invalid model group error.
Args:
model_group: The invalid group name
available_groups: List of valid group names
"""
details: dict[str, Any] = {"requested_group": model_group}
if available_groups:
details["available_groups"] = available_groups
super().__init__(
message=f"Invalid model group: {model_group}",
code=ErrorCode.INVALID_MODEL_GROUP,
details=details,
)
self.model_group = model_group
self.available_groups = available_groups
class InvalidModelError(LLMGatewayError):
"""Invalid or unknown model."""
def __init__(
self,
model: str,
reason: str | None = None,
) -> None:
"""
Initialize invalid model error.
Args:
model: The invalid model name
reason: Reason why it's invalid
"""
details: dict[str, Any] = {"requested_model": model}
if reason:
details["reason"] = reason
super().__init__(
message=f"Invalid model: {model}" + (f" ({reason})" if reason else ""),
code=ErrorCode.INVALID_MODEL,
details=details,
)
self.model = model
class ModelNotAvailableError(LLMGatewayError):
"""Model not available (provider not configured)."""
def __init__(
self,
model: str,
provider: str,
) -> None:
"""
Initialize model not available error.
Args:
model: The unavailable model
provider: The provider that's not configured
"""
super().__init__(
message=f"Model {model} not available: {provider} provider not configured",
code=ErrorCode.MODEL_NOT_AVAILABLE,
details={"model": model, "provider": provider},
)
self.model = model
self.provider = provider
class AllProvidersFailedError(LLMGatewayError):
"""All providers in the failover chain failed."""
def __init__(
self,
model_group: str,
attempted_models: list[str],
errors: list[dict[str, Any]],
) -> None:
"""
Initialize all providers failed error.
Args:
model_group: The model group that was requested
attempted_models: Models that were attempted
errors: Errors from each attempt
"""
super().__init__(
message=f"All providers failed for model group {model_group}",
code=ErrorCode.ALL_PROVIDERS_FAILED,
details={
"model_group": model_group,
"attempted_models": attempted_models,
"errors": errors,
},
)
self.model_group = model_group
self.attempted_models = attempted_models
self.errors = errors
class StreamError(LLMGatewayError):
"""Error during streaming response."""
def __init__(
self,
message: str,
chunks_received: int = 0,
cause: Exception | None = None,
) -> None:
"""
Initialize stream error.
Args:
message: Error message
chunks_received: Number of chunks received before error
cause: Original exception
"""
super().__init__(
message=message,
code=ErrorCode.STREAM_ERROR,
details={"chunks_received": chunks_received},
cause=cause,
)
self.chunks_received = chunks_received
class TokenLimitExceededError(LLMGatewayError):
"""Request exceeds model's token limit."""
def __init__(
self,
model: str,
token_count: int,
limit: int,
) -> None:
"""
Initialize token limit error.
Args:
model: Model name
token_count: Requested token count
limit: Model's token limit
"""
super().__init__(
message=f"Token count {token_count} exceeds {model} limit of {limit}",
code=ErrorCode.TOKEN_LIMIT_EXCEEDED,
details={
"model": model,
"requested_tokens": token_count,
"limit": limit,
},
)
self.model = model
self.token_count = token_count
self.limit = limit
class ContextTooLongError(LLMGatewayError):
"""Input context exceeds model's context window."""
def __init__(
self,
model: str,
context_length: int,
max_context: int,
) -> None:
"""
Initialize context too long error.
Args:
model: Model name
context_length: Input context length
max_context: Model's max context window
"""
super().__init__(
message=(
f"Context length {context_length} exceeds {model} "
f"context window of {max_context}"
),
code=ErrorCode.CONTEXT_TOO_LONG,
details={
"model": model,
"context_length": context_length,
"max_context": max_context,
},
)
self.model = model
self.context_length = context_length
self.max_context = max_context
class ConfigurationError(LLMGatewayError):
"""Configuration error."""
def __init__(
self,
message: str,
config_key: str | None = None,
) -> None:
"""
Initialize configuration error.
Args:
message: Error message
config_key: Configuration key that's problematic
"""
details: dict[str, Any] = {}
if config_key:
details["config_key"] = config_key
super().__init__(
message=message,
code=ErrorCode.CONFIGURATION_ERROR,
details=details,
)
self.config_key = config_key
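
Sketch of the error contract (illustrative only): every exception derives from LLMGatewayError and serializes to a stable code/message/details payload via to_dict(). The entity ID below is a placeholder.

from exceptions import CostLimitExceededError, LLMGatewayError

try:
    raise CostLimitExceededError(
        entity_type="project",
        entity_id="proj-1",  # placeholder ID
        current_cost=105.0,
        limit=100.0,
    )
except LLMGatewayError as exc:
    payload = exc.to_dict()
    # payload == {"error": "LLM_COST_LIMIT_EXCEEDED", "message": "...", "details": {...}}
    print(payload["error"], payload["details"]["limit_usd"])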

View File

@@ -0,0 +1,357 @@
"""
Circuit Breaker implementation for LLM Gateway.
Provides fault tolerance by tracking provider failures and
temporarily disabling providers that are experiencing issues.
"""
import asyncio
import logging
import time
from collections.abc import Callable
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, TypeVar
from config import Settings, get_settings
from exceptions import CircuitOpenError
logger = logging.getLogger(__name__)
T = TypeVar("T")
class CircuitState(str, Enum):
"""Circuit breaker states."""
CLOSED = "closed" # Normal operation, requests pass through
OPEN = "open" # Failures exceeded threshold, requests blocked
HALF_OPEN = "half_open" # Testing if service recovered
@dataclass
class CircuitStats:
"""Statistics for a circuit breaker."""
failures: int = 0
successes: int = 0
last_failure_time: float | None = None
last_success_time: float | None = None
state_changed_at: float = field(default_factory=time.time)
half_open_calls: int = 0
class CircuitBreaker:
"""
Circuit breaker for individual providers.
States:
- CLOSED: Normal operation. Failures increment counter.
- OPEN: Too many failures. Requests immediately fail.
- HALF_OPEN: Testing recovery. Limited requests allowed.
Transitions:
- CLOSED -> OPEN: When failures >= threshold
- OPEN -> HALF_OPEN: After recovery_timeout
- HALF_OPEN -> CLOSED: On success
- HALF_OPEN -> OPEN: On failure
"""
def __init__(
self,
name: str,
failure_threshold: int = 5,
recovery_timeout: int = 60,
half_open_max_calls: int = 3,
) -> None:
"""
Initialize circuit breaker.
Args:
name: Identifier for this circuit (usually provider name)
failure_threshold: Failures before opening circuit
recovery_timeout: Seconds before attempting recovery
half_open_max_calls: Max calls allowed in half-open state
"""
self.name = name
self.failure_threshold = failure_threshold
self.recovery_timeout = recovery_timeout
self.half_open_max_calls = half_open_max_calls
self._state = CircuitState.CLOSED
self._stats = CircuitStats()
self._lock = asyncio.Lock()
@property
def state(self) -> CircuitState:
"""Get current circuit state (may trigger state transition)."""
self._check_state_transition()
return self._state
@property
def stats(self) -> CircuitStats:
"""Get circuit statistics."""
return self._stats
def _check_state_transition(self) -> None:
"""Check if state should transition based on time."""
if self._state == CircuitState.OPEN:
time_in_open = time.time() - self._stats.state_changed_at
if time_in_open >= self.recovery_timeout:
self._transition_to(CircuitState.HALF_OPEN)
logger.info(
f"Circuit {self.name} transitioned to HALF_OPEN "
f"after {time_in_open:.1f}s"
)
def _transition_to(self, new_state: CircuitState) -> None:
"""Transition to a new state."""
old_state = self._state
self._state = new_state
self._stats.state_changed_at = time.time()
if new_state == CircuitState.HALF_OPEN:
self._stats.half_open_calls = 0
elif new_state == CircuitState.CLOSED:
self._stats.failures = 0
logger.debug(f"Circuit {self.name}: {old_state.value} -> {new_state.value}")
def is_available(self) -> bool:
"""Check if circuit allows requests."""
state = self.state # Triggers state check
if state == CircuitState.CLOSED:
return True
if state == CircuitState.HALF_OPEN:
return self._stats.half_open_calls < self.half_open_max_calls
return False
def time_until_recovery(self) -> int | None:
"""Get seconds until circuit may recover (None if not open)."""
if self._state != CircuitState.OPEN:
return None
elapsed = time.time() - self._stats.state_changed_at
remaining = max(0, self.recovery_timeout - int(elapsed))
return remaining if remaining > 0 else 0
async def record_success(self) -> None:
"""Record a successful call."""
async with self._lock:
self._stats.successes += 1
self._stats.last_success_time = time.time()
if self._state == CircuitState.HALF_OPEN:
# Success in half-open state closes the circuit
self._transition_to(CircuitState.CLOSED)
logger.info(f"Circuit {self.name} closed after successful recovery")
async def record_failure(self, error: Exception | None = None) -> None: # noqa: ARG002
"""
Record a failed call.
Args:
error: The exception that occurred
"""
async with self._lock:
self._stats.failures += 1
self._stats.last_failure_time = time.time()
if self._state == CircuitState.HALF_OPEN:
# Failure in half-open state opens the circuit
self._transition_to(CircuitState.OPEN)
logger.warning(
f"Circuit {self.name} reopened after failure in half-open state"
)
elif self._state == CircuitState.CLOSED:
if self._stats.failures >= self.failure_threshold:
self._transition_to(CircuitState.OPEN)
logger.warning(
f"Circuit {self.name} opened after "
f"{self._stats.failures} failures"
)
async def execute(
self,
func: Callable[..., T],
*args: Any,
**kwargs: Any,
) -> T:
"""
Execute a function with circuit breaker protection.
Args:
func: Async function to execute
*args: Positional arguments
**kwargs: Keyword arguments
Returns:
Function result
Raises:
CircuitOpenError: If circuit is open
"""
if not self.is_available():
raise CircuitOpenError(
provider=self.name,
recovery_time=self.time_until_recovery(),
)
async with self._lock:
if self._state == CircuitState.HALF_OPEN:
self._stats.half_open_calls += 1
try:
result = await func(*args, **kwargs)
await self.record_success()
return result
except Exception as e:
await self.record_failure(e)
raise
def reset(self) -> None:
"""Reset circuit to closed state."""
self._state = CircuitState.CLOSED
self._stats = CircuitStats()
logger.info(f"Circuit {self.name} reset to CLOSED")
def to_dict(self) -> dict[str, Any]:
"""Convert circuit state to dictionary."""
return {
"name": self.name,
"state": self._state.value,
"failures": self._stats.failures,
"successes": self._stats.successes,
"last_failure_time": self._stats.last_failure_time,
"last_success_time": self._stats.last_success_time,
"time_until_recovery": self.time_until_recovery(),
"is_available": self.is_available(),
}
class CircuitBreakerRegistry:
"""
Registry for managing multiple circuit breakers.
Provides centralized management of circuits for different providers/models.
"""
def __init__(self, settings: Settings | None = None) -> None:
"""
Initialize registry.
Args:
settings: Application settings (uses default if None)
"""
self._settings = settings or get_settings()
self._circuits: dict[str, CircuitBreaker] = {}
self._lock = asyncio.Lock()
async def get_circuit(self, name: str) -> CircuitBreaker:
"""
Get or create a circuit breaker.
Args:
name: Circuit identifier (e.g., provider name)
Returns:
CircuitBreaker instance
"""
async with self._lock:
if name not in self._circuits:
self._circuits[name] = CircuitBreaker(
name=name,
failure_threshold=self._settings.circuit_failure_threshold,
recovery_timeout=self._settings.circuit_recovery_timeout,
half_open_max_calls=self._settings.circuit_half_open_max_calls,
)
return self._circuits[name]
def get_circuit_sync(self, name: str) -> CircuitBreaker:
"""
Get or create a circuit breaker (sync version).
Args:
name: Circuit identifier
Returns:
CircuitBreaker instance
"""
if name not in self._circuits:
self._circuits[name] = CircuitBreaker(
name=name,
failure_threshold=self._settings.circuit_failure_threshold,
recovery_timeout=self._settings.circuit_recovery_timeout,
half_open_max_calls=self._settings.circuit_half_open_max_calls,
)
return self._circuits[name]
async def is_available(self, name: str) -> bool:
"""
Check if a circuit is available.
Args:
name: Circuit identifier
Returns:
True if circuit allows requests
"""
circuit = await self.get_circuit(name)
return circuit.is_available()
async def record_success(self, name: str) -> None:
"""Record success for a circuit."""
circuit = await self.get_circuit(name)
await circuit.record_success()
async def record_failure(self, name: str, error: Exception | None = None) -> None:
"""Record failure for a circuit."""
circuit = await self.get_circuit(name)
await circuit.record_failure(error)
async def reset(self, name: str) -> None:
"""Reset a specific circuit."""
async with self._lock:
if name in self._circuits:
self._circuits[name].reset()
async def reset_all(self) -> None:
"""Reset all circuits."""
async with self._lock:
for circuit in self._circuits.values():
circuit.reset()
def get_all_states(self) -> dict[str, dict[str, Any]]:
"""Get state of all circuits."""
return {name: circuit.to_dict() for name, circuit in self._circuits.items()}
def get_open_circuits(self) -> list[str]:
"""Get list of circuits that are currently open."""
return [
name
for name, circuit in self._circuits.items()
if circuit.state == CircuitState.OPEN
]
def get_available_circuits(self) -> list[str]:
"""Get list of circuits that are available for requests."""
return [
name for name, circuit in self._circuits.items() if circuit.is_available()
]
# Global registry instance (lazy initialization)
_registry: CircuitBreakerRegistry | None = None
def get_circuit_registry() -> CircuitBreakerRegistry:
"""Get the global circuit breaker registry."""
global _registry
if _registry is None:
_registry = CircuitBreakerRegistry()
return _registry
def reset_circuit_registry() -> None:
"""Reset the global registry (for testing)."""
global _registry
_registry = None
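
Illustrative sketch of the state machine above: two failures open the circuit, after which calls fail fast with CircuitOpenError. The diff does not show this file's name, so the import path is an assumption ("failover" is taken from the first-party module list in pyproject.toml); the breaker API and exception are as defined in this commit.

import asyncio

from exceptions import CircuitOpenError
from failover import CircuitBreaker  # assumed module name, see note above

async def flaky_call() -> str:
    raise RuntimeError("provider down")

async def demo() -> None:
    breaker = CircuitBreaker("anthropic", failure_threshold=2, recovery_timeout=60)
    for _ in range(2):
        try:
            await breaker.execute(flaky_call)
        except RuntimeError:
            pass  # counted as failures; the circuit opens on the second one
    try:
        await breaker.execute(flaky_call)  # now fails fast without calling the provider
    except CircuitOpenError as exc:
        print(exc.details["recovery_time_seconds"])  # ~60

asyncio.run(demo())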

View File

@@ -0,0 +1,442 @@
"""
Data models for LLM Gateway MCP Server.
Defines model groups, pricing, request/response structures.
Per ADR-004: LLM Provider Abstraction.
"""
from dataclasses import dataclass, field
from datetime import UTC, datetime
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field
class ModelGroup(str, Enum):
"""Model groups for routing LLM requests."""
REASONING = "reasoning" # Complex analysis, architecture decisions
CODE = "code" # Code writing and refactoring
FAST = "fast" # Quick tasks, simple queries
VISION = "vision" # Multimodal image analysis
EMBEDDING = "embedding" # Vector embeddings
COST_OPTIMIZED = "cost_optimized" # High-volume, non-critical
SELF_HOSTED = "self_hosted" # Privacy-sensitive, air-gapped
# Aliases for backward compatibility with ADR-004
HIGH_REASONING = "reasoning"
CODE_GENERATION = "code"
FAST_RESPONSE = "fast"
class Provider(str, Enum):
"""Supported LLM providers."""
ANTHROPIC = "anthropic"
OPENAI = "openai"
GOOGLE = "google"
ALIBABA = "alibaba"
DEEPSEEK = "deepseek"
@dataclass
class ModelConfig:
"""Configuration for a specific model."""
name: str # Model identifier (e.g., "claude-3-opus-20240229")
litellm_name: str # LiteLLM model string (e.g., "anthropic/claude-3-opus-20240229")
provider: Provider
cost_per_1m_input: float # USD per 1M input tokens
cost_per_1m_output: float # USD per 1M output tokens
context_window: int # Max context tokens
max_output_tokens: int # Max output tokens
supports_vision: bool = False
supports_streaming: bool = True
supports_function_calling: bool = True
# Model configurations per ADR-004
MODEL_CONFIGS: dict[str, ModelConfig] = {
# Anthropic models
"claude-opus-4": ModelConfig(
name="claude-opus-4",
litellm_name="anthropic/claude-sonnet-4-20250514", # Using sonnet-4 as opus-4 placeholder
provider=Provider.ANTHROPIC,
cost_per_1m_input=15.0,
cost_per_1m_output=75.0,
context_window=200000,
max_output_tokens=8192,
supports_vision=True,
),
"claude-sonnet-4": ModelConfig(
name="claude-sonnet-4",
litellm_name="anthropic/claude-sonnet-4-20250514",
provider=Provider.ANTHROPIC,
cost_per_1m_input=3.0,
cost_per_1m_output=15.0,
context_window=200000,
max_output_tokens=8192,
supports_vision=True,
),
"claude-haiku": ModelConfig(
name="claude-haiku",
litellm_name="anthropic/claude-3-5-haiku-20241022",
provider=Provider.ANTHROPIC,
cost_per_1m_input=1.0,
cost_per_1m_output=5.0,
context_window=200000,
max_output_tokens=8192,
supports_vision=True,
),
# OpenAI models
"gpt-4.1": ModelConfig(
name="gpt-4.1",
litellm_name="openai/gpt-4.1",
provider=Provider.OPENAI,
cost_per_1m_input=2.0,
cost_per_1m_output=8.0,
context_window=1047576,
max_output_tokens=32768,
supports_vision=True,
),
"gpt-4.1-mini": ModelConfig(
name="gpt-4.1-mini",
litellm_name="openai/gpt-4.1-mini",
provider=Provider.OPENAI,
cost_per_1m_input=0.4,
cost_per_1m_output=1.6,
context_window=1047576,
max_output_tokens=32768,
supports_vision=True,
),
# Google models
"gemini-2.5-pro": ModelConfig(
name="gemini-2.5-pro",
litellm_name="gemini/gemini-2.5-pro",
provider=Provider.GOOGLE,
cost_per_1m_input=1.25,
cost_per_1m_output=10.0,
context_window=1048576,
max_output_tokens=65536,
supports_vision=True,
),
"gemini-2.0-flash": ModelConfig(
name="gemini-2.0-flash",
litellm_name="gemini/gemini-2.0-flash",
provider=Provider.GOOGLE,
cost_per_1m_input=0.1,
cost_per_1m_output=0.4,
context_window=1048576,
max_output_tokens=8192,
supports_vision=True,
),
# Alibaba models
"qwen-max": ModelConfig(
name="qwen-max",
litellm_name="alibaba/qwen-max",
provider=Provider.ALIBABA,
cost_per_1m_input=2.0,
cost_per_1m_output=6.0,
context_window=32768,
max_output_tokens=8192,
supports_vision=False,
),
# DeepSeek models
"deepseek-coder": ModelConfig(
name="deepseek-coder",
litellm_name="deepseek/deepseek-coder",
provider=Provider.DEEPSEEK,
cost_per_1m_input=0.14,
cost_per_1m_output=0.28,
context_window=128000,
max_output_tokens=8192,
supports_vision=False,
),
"deepseek-chat": ModelConfig(
name="deepseek-chat",
litellm_name="deepseek/deepseek-chat",
provider=Provider.DEEPSEEK,
cost_per_1m_input=0.14,
cost_per_1m_output=0.28,
context_window=128000,
max_output_tokens=8192,
supports_vision=False,
),
# Embedding models
"text-embedding-3-large": ModelConfig(
name="text-embedding-3-large",
litellm_name="openai/text-embedding-3-large",
provider=Provider.OPENAI,
cost_per_1m_input=0.13,
cost_per_1m_output=0.0,
context_window=8191,
max_output_tokens=0,
supports_vision=False,
supports_streaming=False,
supports_function_calling=False,
),
"voyage-3": ModelConfig(
name="voyage-3",
litellm_name="voyage/voyage-3",
provider=Provider.ANTHROPIC, # Voyage via Anthropic partnership
cost_per_1m_input=0.06,
cost_per_1m_output=0.0,
context_window=32000,
max_output_tokens=0,
supports_vision=False,
supports_streaming=False,
supports_function_calling=False,
),
}
@dataclass
class ModelGroupConfig:
"""Configuration for a model group with failover chain."""
primary: str # Primary model name
fallbacks: list[str] # Fallback model names in order
description: str
def get_all_models(self) -> list[str]:
"""Get all models in priority order."""
return [self.primary, *self.fallbacks]
# Model group configurations per ADR-004
MODEL_GROUPS: dict[ModelGroup, ModelGroupConfig] = {
ModelGroup.REASONING: ModelGroupConfig(
primary="claude-opus-4",
fallbacks=["gpt-4.1", "gemini-2.5-pro", "qwen-max"],
description="Complex analysis, architecture decisions",
),
ModelGroup.CODE: ModelGroupConfig(
primary="claude-sonnet-4",
fallbacks=["gpt-4.1", "deepseek-coder"],
description="Code writing and refactoring",
),
ModelGroup.FAST: ModelGroupConfig(
primary="claude-haiku",
fallbacks=["gpt-4.1-mini", "gemini-2.0-flash"],
description="Quick tasks, simple queries",
),
ModelGroup.VISION: ModelGroupConfig(
primary="claude-sonnet-4",
fallbacks=["gpt-4.1", "gemini-2.5-pro"],
description="Multimodal image analysis",
),
ModelGroup.EMBEDDING: ModelGroupConfig(
primary="text-embedding-3-large",
fallbacks=["voyage-3"],
description="Vector embeddings",
),
ModelGroup.COST_OPTIMIZED: ModelGroupConfig(
primary="qwen-max",
fallbacks=["deepseek-chat"],
description="High-volume, non-critical tasks",
),
ModelGroup.SELF_HOSTED: ModelGroupConfig(
primary="deepseek-chat",
fallbacks=["qwen-max"],
description="Privacy-sensitive, air-gapped",
),
}
# Agent type to model group mapping per ADR-004
AGENT_TYPE_MODEL_PREFERENCES: dict[str, ModelGroup] = {
"product_owner": ModelGroup.REASONING,
"software_architect": ModelGroup.REASONING,
"software_engineer": ModelGroup.CODE,
"qa_engineer": ModelGroup.CODE,
"devops_engineer": ModelGroup.FAST,
"project_manager": ModelGroup.FAST,
"business_analyst": ModelGroup.REASONING,
}
class ChatMessage(BaseModel):
"""A single chat message."""
role: str = Field(..., description="Message role: system, user, assistant, tool")
content: str | list[dict[str, Any]] = Field(..., description="Message content")
name: str | None = Field(default=None, description="Optional name for the message")
tool_calls: list[dict[str, Any]] | None = Field(
default=None, description="Tool calls if role is assistant"
)
tool_call_id: str | None = Field(
default=None, description="Tool call ID if role is tool"
)
class CompletionRequest(BaseModel):
"""Request for chat completion."""
project_id: str = Field(..., description="Project ID for cost attribution")
agent_id: str = Field(..., description="Agent ID making the request")
messages: list[ChatMessage] = Field(..., description="Chat messages")
model_group: ModelGroup = Field(
default=ModelGroup.REASONING, description="Model group for routing"
)
model_override: str | None = Field(
default=None, description="Specific model to use (bypasses routing)"
)
max_tokens: int = Field(default=4096, ge=1, le=32768, description="Max output tokens")
temperature: float = Field(
default=0.7, ge=0.0, le=2.0, description="Sampling temperature"
)
stream: bool = Field(default=False, description="Enable streaming response")
session_id: str | None = Field(
default=None, description="Session ID for conversation tracking"
)
metadata: dict[str, Any] = Field(
default_factory=dict, description="Additional metadata"
)
class UsageStats(BaseModel):
"""Token usage statistics."""
prompt_tokens: int = Field(default=0, description="Input tokens used")
completion_tokens: int = Field(default=0, description="Output tokens generated")
total_tokens: int = Field(default=0, description="Total tokens")
cost_usd: float = Field(default=0.0, description="Estimated cost in USD")
@classmethod
def from_response(
cls, prompt_tokens: int, completion_tokens: int, model_config: ModelConfig
) -> "UsageStats":
"""Create usage stats from token counts and model config."""
input_cost = (prompt_tokens / 1_000_000) * model_config.cost_per_1m_input
output_cost = (completion_tokens / 1_000_000) * model_config.cost_per_1m_output
return cls(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
cost_usd=round(input_cost + output_cost, 6),
)
class CompletionResponse(BaseModel):
"""Response from chat completion."""
id: str = Field(..., description="Unique response ID")
model: str = Field(..., description="Model that generated the response")
provider: str = Field(..., description="Provider used")
content: str = Field(..., description="Generated content")
finish_reason: str = Field(
default="stop", description="Reason for completion finish"
)
usage: UsageStats = Field(default_factory=UsageStats, description="Token usage")
created_at: datetime = Field(
default_factory=lambda: datetime.now(UTC), description="Response timestamp"
)
metadata: dict[str, Any] = Field(
default_factory=dict, description="Additional metadata"
)
class StreamChunk(BaseModel):
"""A chunk from a streaming response."""
id: str = Field(..., description="Chunk ID")
delta: str = Field(default="", description="Content delta")
finish_reason: str | None = Field(default=None, description="Finish reason if done")
usage: UsageStats | None = Field(
default=None, description="Usage stats (only on final chunk)"
)
class EmbeddingRequest(BaseModel):
"""Request for text embeddings."""
project_id: str = Field(..., description="Project ID for cost attribution")
agent_id: str = Field(..., description="Agent ID making the request")
texts: list[str] = Field(..., min_length=1, description="Texts to embed")
model: str = Field(
default="text-embedding-3-large", description="Embedding model"
)
class EmbeddingResponse(BaseModel):
"""Response from embedding generation."""
model: str = Field(..., description="Model used")
embeddings: list[list[float]] = Field(..., description="Embedding vectors")
usage: UsageStats = Field(default_factory=UsageStats, description="Token usage")
@dataclass
class CostRecord:
"""A single cost record for tracking."""
project_id: str
agent_id: str
model: str
prompt_tokens: int
completion_tokens: int
cost_usd: float
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
session_id: str | None = None
request_id: str | None = None
class UsageReport(BaseModel):
"""Usage report for a project or agent."""
entity_id: str = Field(..., description="Project or agent ID")
entity_type: str = Field(..., description="'project' or 'agent'")
period: str = Field(..., description="Report period (hour, day, month)")
period_start: datetime = Field(..., description="Period start time")
period_end: datetime = Field(..., description="Period end time")
total_requests: int = Field(default=0, description="Total requests")
total_tokens: int = Field(default=0, description="Total tokens used")
total_cost_usd: float = Field(default=0.0, description="Total cost in USD")
by_model: dict[str, dict[str, Any]] = Field(
default_factory=dict, description="Breakdown by model"
)
by_agent: dict[str, dict[str, Any]] = Field(
default_factory=dict, description="Breakdown by agent (for project reports)"
)
class ModelInfo(BaseModel):
"""Information about an available model."""
name: str = Field(..., description="Model name")
provider: str = Field(..., description="Provider name")
cost_per_1m_input: float = Field(..., description="Input cost per 1M tokens")
cost_per_1m_output: float = Field(..., description="Output cost per 1M tokens")
context_window: int = Field(..., description="Max context tokens")
max_output_tokens: int = Field(..., description="Max output tokens")
supports_vision: bool = Field(..., description="Vision capability")
supports_streaming: bool = Field(..., description="Streaming capability")
supports_function_calling: bool = Field(..., description="Function calling")
available: bool = Field(default=True, description="Provider configured")
@classmethod
def from_config(cls, config: ModelConfig, available: bool = True) -> "ModelInfo":
"""Create ModelInfo from ModelConfig."""
return cls(
name=config.name,
provider=config.provider.value,
cost_per_1m_input=config.cost_per_1m_input,
cost_per_1m_output=config.cost_per_1m_output,
context_window=config.context_window,
max_output_tokens=config.max_output_tokens,
supports_vision=config.supports_vision,
supports_streaming=config.supports_streaming,
supports_function_calling=config.supports_function_calling,
available=available,
)
class ModelGroupInfo(BaseModel):
"""Information about a model group."""
name: str = Field(..., description="Group name")
description: str = Field(..., description="Group description")
primary_model: str = Field(..., description="Primary model")
fallback_models: list[str] = Field(..., description="Fallback models")
available: bool = Field(default=True, description="At least one model available")
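
Small sketch of the routing and pricing tables above (the token counts are arbitrary examples, not part of this change):

from models import MODEL_CONFIGS, MODEL_GROUPS, ModelGroup, UsageStats

# Failover chain for the code group: primary first, then fallbacks.
chain = MODEL_GROUPS[ModelGroup.CODE].get_all_models()
# -> ["claude-sonnet-4", "gpt-4.1", "deepseek-coder"]

# Cost of a 12k-in / 2k-out call on claude-sonnet-4:
# 12_000/1M * $3 + 2_000/1M * $15 = $0.036 + $0.030 = $0.066
usage = UsageStats.from_response(
    prompt_tokens=12_000,
    completion_tokens=2_000,
    model_config=MODEL_CONFIGS["claude-sonnet-4"],
)
print(chain, usage.cost_usd)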

View File

@@ -0,0 +1,331 @@
"""
LiteLLM provider configuration for LLM Gateway.
Configures the LiteLLM Router with model lists and failover chains.
"""
import logging
import os
from typing import Any
import litellm
from litellm import Router
from config import Settings, get_settings
from models import (
MODEL_CONFIGS,
MODEL_GROUPS,
ModelConfig,
ModelGroup,
Provider,
)
logger = logging.getLogger(__name__)
def configure_litellm(settings: Settings) -> None:
"""
Configure LiteLLM global settings.
Args:
settings: Application settings
"""
# Set API keys in environment (LiteLLM reads from env)
if settings.anthropic_api_key:
os.environ["ANTHROPIC_API_KEY"] = settings.anthropic_api_key
if settings.openai_api_key:
os.environ["OPENAI_API_KEY"] = settings.openai_api_key
if settings.google_api_key:
os.environ["GEMINI_API_KEY"] = settings.google_api_key
if settings.alibaba_api_key:
os.environ["DASHSCOPE_API_KEY"] = settings.alibaba_api_key
if settings.deepseek_api_key:
os.environ["DEEPSEEK_API_KEY"] = settings.deepseek_api_key
# Configure LiteLLM settings
litellm.drop_params = True # Drop unsupported params instead of erroring
litellm.set_verbose = settings.debug
# Configure caching if enabled
if settings.litellm_cache_enabled:
litellm.cache = litellm.Cache(
type="redis",
host=_parse_redis_host(settings.redis_url),
port=_parse_redis_port(settings.redis_url),
ttl=settings.litellm_cache_ttl,
)
def _parse_redis_host(redis_url: str) -> str:
"""Extract host from Redis URL."""
# redis://host:port/db
url = redis_url.replace("redis://", "")
host_port = url.split("/")[0]
return host_port.split(":")[0]
def _parse_redis_port(redis_url: str) -> int:
"""Extract port from Redis URL."""
url = redis_url.replace("redis://", "")
host_port = url.split("/")[0]
parts = host_port.split(":")
return int(parts[1]) if len(parts) > 1 else 6379
def _is_provider_available(provider: Provider, settings: Settings) -> bool:
"""Check if a provider is available (API key configured)."""
provider_key_map = {
Provider.ANTHROPIC: settings.anthropic_api_key,
Provider.OPENAI: settings.openai_api_key,
Provider.GOOGLE: settings.google_api_key,
Provider.ALIBABA: settings.alibaba_api_key,
Provider.DEEPSEEK: settings.deepseek_api_key or settings.deepseek_base_url,
}
return bool(provider_key_map.get(provider))
def _build_model_entry(
model_config: ModelConfig,
settings: Settings,
) -> dict[str, Any] | None:
"""
Build a model entry for LiteLLM Router.
Args:
model_config: Model configuration
settings: Application settings
Returns:
Model entry dict or None if provider unavailable
"""
if not _is_provider_available(model_config.provider, settings):
logger.debug(
f"Skipping model {model_config.name}: "
f"{model_config.provider.value} provider not configured"
)
return None
entry: dict[str, Any] = {
"model_name": model_config.name,
"litellm_params": {
"model": model_config.litellm_name,
"timeout": settings.litellm_timeout,
"max_retries": settings.litellm_max_retries,
},
}
# Add custom base URL for DeepSeek self-hosted
if (
model_config.provider == Provider.DEEPSEEK
and settings.deepseek_base_url
):
entry["litellm_params"]["api_base"] = settings.deepseek_base_url
return entry
def build_model_list(settings: Settings | None = None) -> list[dict[str, Any]]:
"""
Build the complete model list for LiteLLM Router.
Args:
settings: Application settings (uses default if None)
Returns:
List of model entries for Router
"""
if settings is None:
settings = get_settings()
model_list: list[dict[str, Any]] = []
for model_config in MODEL_CONFIGS.values():
entry = _build_model_entry(model_config, settings)
if entry:
model_list.append(entry)
logger.info(f"Built model list with {len(model_list)} models")
return model_list
def build_fallback_config(settings: Settings | None = None) -> dict[str, list[str]]:
"""
Build fallback configuration based on model groups.
Args:
settings: Application settings (uses default if None)
Returns:
Dict mapping model names to their fallback chains
"""
if settings is None:
settings = get_settings()
fallbacks: dict[str, list[str]] = {}
for _group, config in MODEL_GROUPS.items():
# Get available models in this group's chain
available_models = []
for model_name in config.get_all_models():
model_config = MODEL_CONFIGS.get(model_name)
if model_config and _is_provider_available(model_config.provider, settings):
available_models.append(model_name)
if len(available_models) > 1:
# First model falls back to remaining models
fallbacks[available_models[0]] = available_models[1:]
return fallbacks
def get_available_models(settings: Settings | None = None) -> dict[str, ModelConfig]:
"""
Get all available models (with configured providers).
Args:
settings: Application settings (uses default if None)
Returns:
Dict of available model configs
"""
if settings is None:
settings = get_settings()
available: dict[str, ModelConfig] = {}
for name, config in MODEL_CONFIGS.items():
if _is_provider_available(config.provider, settings):
available[name] = config
return available
def get_available_model_groups(
settings: Settings | None = None,
) -> dict[ModelGroup, list[str]]:
"""
Get available models for each model group.
Args:
settings: Application settings (uses default if None)
Returns:
Dict mapping model groups to available models
"""
if settings is None:
settings = get_settings()
result: dict[ModelGroup, list[str]] = {}
for group, config in MODEL_GROUPS.items():
available_models = []
for model_name in config.get_all_models():
model_config = MODEL_CONFIGS.get(model_name)
if model_config and _is_provider_available(model_config.provider, settings):
available_models.append(model_name)
result[group] = available_models
return result
class LLMProvider:
"""
LLM Provider wrapper around LiteLLM Router.
Provides a high-level interface for LLM operations with
automatic failover and configuration management.
"""
def __init__(self, settings: Settings | None = None) -> None:
"""
Initialize LLM Provider.
Args:
settings: Application settings (uses default if None)
"""
self._settings = settings or get_settings()
self._router: Router | None = None
self._initialized = False
def initialize(self) -> None:
"""Initialize the provider and LiteLLM Router."""
if self._initialized:
return
# Configure LiteLLM global settings
configure_litellm(self._settings)
# Build model list
model_list = build_model_list(self._settings)
if not model_list:
logger.warning("No models available - no providers configured")
self._initialized = True
return
# Build fallback config
fallbacks = build_fallback_config(self._settings)
# Create Router
self._router = Router(
model_list=model_list,
            fallbacks=[{model: chain} for model, chain in fallbacks.items()] if fallbacks else None,  # LiteLLM expects a list of {model: [fallbacks]} dicts
routing_strategy="latency-based-routing",
num_retries=self._settings.litellm_max_retries,
timeout=self._settings.litellm_timeout,
retry_after=5, # Retry after 5 seconds
allowed_fails=2, # Fail after 2 consecutive failures
)
self._initialized = True
logger.info(
f"LLM Provider initialized with {len(model_list)} models, "
f"{len(fallbacks)} fallback chains"
)
@property
def router(self) -> Router | None:
"""Get the LiteLLM Router."""
if not self._initialized:
self.initialize()
return self._router
@property
def is_available(self) -> bool:
"""Check if provider is available."""
if not self._initialized:
self.initialize()
return self._router is not None
def get_model_config(self, model_name: str) -> ModelConfig | None:
"""Get configuration for a specific model."""
return MODEL_CONFIGS.get(model_name)
def get_available_models(self) -> dict[str, ModelConfig]:
"""Get all available models."""
return get_available_models(self._settings)
def is_model_available(self, model_name: str) -> bool:
"""Check if a specific model is available."""
model_config = MODEL_CONFIGS.get(model_name)
if not model_config:
return False
return _is_provider_available(model_config.provider, self._settings)
# Global provider instance (lazy initialization)
_provider: LLMProvider | None = None
def get_provider() -> LLMProvider:
"""Get the global LLM Provider instance."""
global _provider
if _provider is None:
_provider = LLMProvider()
return _provider
def reset_provider() -> None:
"""Reset the global provider (for testing)."""
global _provider
_provider = None
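# Minimal usage sketch (illustrative only): wiring the helpers above together.
# Assumes get_settings() is already imported at the top of this module, as it
# is used by the functions above; nothing here runs on import.
if __name__ == "__main__":
    _settings = get_settings()
    _prov = get_provider()
    _prov.initialize()
    print(f"Configured providers: {_settings.get_available_providers()}")
    print(f"Router models: {len(build_model_list(_settings))}")
    print(f"Fallback chains: {build_fallback_config(_settings)}")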

View File

@@ -4,20 +4,96 @@ version = "0.1.0"
description = "Syndarix LLM Gateway MCP Server - Unified LLM access with failover and cost tracking"
requires-python = ">=3.12"
dependencies = [
"fastmcp>=0.1.0",
"fastmcp>=2.0.0",
"litellm>=1.50.0",
"redis>=5.0.0",
"pydantic>=2.0.0",
"pydantic-settings>=2.0.0",
"tiktoken>=0.7.0",
"httpx>=0.27.0",
"uvicorn>=0.30.0",
"fastapi>=0.115.0",
]
[project.optional-dependencies]
dev = [
"pytest>=8.0.0",
"pytest-asyncio>=0.23.0",
"pytest-asyncio>=0.24.0",
"pytest-cov>=5.0.0",
"respx>=0.21.0",
"fakeredis>=2.25.0",
"ruff>=0.8.0",
"mypy>=1.11.0",
]
[project.scripts]
llm-gateway = "server:main"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["."]
exclude = ["tests/", "*.md", "Dockerfile"]
[tool.hatch.build.targets.sdist]
include = ["*.py", "pyproject.toml"]
[tool.ruff]
target-version = "py312"
line-length = 88
[tool.ruff.lint]
select = [
"E", # pycodestyle errors
"W", # pycodestyle warnings
"F", # pyflakes
"I", # isort
"B", # flake8-bugbear
"C4", # flake8-comprehensions
"UP", # pyupgrade
"ARG", # flake8-unused-arguments
"SIM", # flake8-simplify
]
ignore = [
"E501", # line too long (handled by formatter)
"B008", # do not perform function calls in argument defaults
"B904", # raise from in except (too noisy)
]
[tool.ruff.lint.isort]
known-first-party = ["config", "models", "exceptions", "providers", "failover", "routing", "cost_tracking", "streaming"]
[tool.pytest.ini_options]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"
testpaths = ["tests"]
addopts = "-v --tb=short"
filterwarnings = [
"ignore::DeprecationWarning",
]
[tool.coverage.run]
source = ["."]
omit = ["tests/*", "conftest.py"]
branch = true
[tool.coverage.report]
exclude_lines = [
"pragma: no cover",
"def __repr__",
"raise NotImplementedError",
"if TYPE_CHECKING:",
"if __name__ == .__main__.:",
]
fail_under = 90
show_missing = true
[tool.mypy]
python_version = "3.12"
strict = true
warn_return_any = true
warn_unused_ignores = true
disallow_untyped_defs = true
plugins = ["pydantic.mypy"]

View File

@@ -0,0 +1,319 @@
"""
Model routing for LLM Gateway.
Handles model selection based on:
- Model group configuration
- Circuit breaker availability
- Agent type preferences
"""
import logging
from typing import Any
from config import Settings, get_settings
from exceptions import (
AllProvidersFailedError,
InvalidModelError,
InvalidModelGroupError,
ModelNotAvailableError,
)
from failover import CircuitBreakerRegistry, get_circuit_registry
from models import (
AGENT_TYPE_MODEL_PREFERENCES,
MODEL_CONFIGS,
MODEL_GROUPS,
ModelConfig,
ModelGroup,
)
from providers import get_available_models
logger = logging.getLogger(__name__)
class ModelRouter:
"""
Routes requests to appropriate models based on configuration.
Considers:
- Model group preferences
- Circuit breaker states
- Agent type defaults
- Provider availability
"""
def __init__(
self,
settings: Settings | None = None,
circuit_registry: CircuitBreakerRegistry | None = None,
) -> None:
"""
Initialize model router.
Args:
settings: Application settings
circuit_registry: Circuit breaker registry
"""
self._settings = settings or get_settings()
self._circuit_registry = circuit_registry or get_circuit_registry()
def parse_model_group(self, group_str: str) -> ModelGroup:
"""
Parse model group from string.
Args:
group_str: Group name string
Returns:
ModelGroup enum value
Raises:
InvalidModelGroupError: If group is unknown
"""
# Handle aliases
aliases = {
"high-reasoning": ModelGroup.REASONING,
"high_reasoning": ModelGroup.REASONING,
"code-generation": ModelGroup.CODE,
"code_generation": ModelGroup.CODE,
"fast-response": ModelGroup.FAST,
"fast_response": ModelGroup.FAST,
}
# Try direct enum value
try:
return ModelGroup(group_str.lower())
except ValueError:
pass
# Try aliases
if group_str.lower() in aliases:
return aliases[group_str.lower()]
# Unknown group
available = [g.value for g in ModelGroup]
raise InvalidModelGroupError(
model_group=group_str,
available_groups=available,
)
def get_model_config(self, model_name: str) -> ModelConfig:
"""
Get configuration for a specific model.
Args:
model_name: Model name
Returns:
Model configuration
Raises:
InvalidModelError: If model is unknown
"""
config = MODEL_CONFIGS.get(model_name)
if not config:
raise InvalidModelError(
model=model_name,
reason="Unknown model",
)
return config
def get_preferred_group_for_agent(self, agent_type: str) -> ModelGroup:
"""
Get preferred model group for an agent type.
Args:
agent_type: Agent type identifier
Returns:
Preferred ModelGroup
"""
return AGENT_TYPE_MODEL_PREFERENCES.get(
agent_type.lower(),
ModelGroup.REASONING, # Default to reasoning
)
async def select_model(
self,
model_group: ModelGroup | str,
model_override: str | None = None,
agent_type: str | None = None,
) -> tuple[str, ModelConfig]:
"""
Select the best available model.
Args:
model_group: Desired model group
model_override: Specific model to use (bypasses group routing)
agent_type: Agent type for preference lookup
Returns:
Tuple of (model_name, model_config)
Raises:
InvalidModelError: If override model is invalid
InvalidModelGroupError: If group is invalid
AllProvidersFailedError: If no models are available
"""
# Handle model override
if model_override:
config = MODEL_CONFIGS.get(model_override)
if not config:
raise InvalidModelError(
model=model_override,
reason="Unknown model",
)
# Check if model's provider is available (using router's settings)
available_models = get_available_models(self._settings)
if model_override not in available_models:
raise ModelNotAvailableError(
model=model_override,
provider=config.provider.value,
)
# Check circuit breaker
circuit = self._circuit_registry.get_circuit_sync(config.provider.value)
if not circuit.is_available():
raise ModelNotAvailableError(
model=model_override,
provider=f"{config.provider.value} (circuit open)",
)
return model_override, config
# Parse model group if string
if isinstance(model_group, str):
model_group = self.parse_model_group(model_group)
# Get agent type preference if no explicit group
if agent_type:
preferred = self.get_preferred_group_for_agent(agent_type)
logger.debug(
f"Agent type {agent_type} prefers {preferred.value}, "
f"requested {model_group.value}"
)
# Get group configuration
group_config = MODEL_GROUPS.get(model_group)
if not group_config:
raise InvalidModelGroupError(
model_group=model_group.value,
available_groups=[g.value for g in ModelGroup],
)
# Get available models
available_models = get_available_models(self._settings)
# Try models in priority order
errors: list[dict[str, Any]] = []
attempted: list[str] = []
for model_name in group_config.get_all_models():
attempted.append(model_name)
# Check if model provider is configured
config = MODEL_CONFIGS.get(model_name)
if not config:
errors.append({"model": model_name, "error": "Unknown model"})
continue
if model_name not in available_models:
errors.append({
"model": model_name,
"error": f"Provider {config.provider.value} not configured",
})
continue
# Check circuit breaker
circuit = self._circuit_registry.get_circuit_sync(config.provider.value)
if not circuit.is_available():
errors.append({
"model": model_name,
"error": f"Circuit open for {config.provider.value}",
})
continue
# Model is available
logger.debug(
f"Selected model {model_name} for group {model_group.value}"
)
return model_name, config
# No models available
raise AllProvidersFailedError(
model_group=model_group.value,
attempted_models=attempted,
errors=errors,
)
async def get_available_models_for_group(
self,
model_group: ModelGroup | str,
) -> list[tuple[str, ModelConfig, bool]]:
"""
Get all models for a group with availability status.
Args:
model_group: Model group
Returns:
List of (model_name, config, is_available) tuples
"""
# Parse model group if string
if isinstance(model_group, str):
model_group = self.parse_model_group(model_group)
group_config = MODEL_GROUPS.get(model_group)
if not group_config:
return []
available_models = get_available_models(self._settings)
result: list[tuple[str, ModelConfig, bool]] = []
for model_name in group_config.get_all_models():
config = MODEL_CONFIGS.get(model_name)
if not config:
continue
is_available = model_name in available_models
if is_available:
# Also check circuit breaker
circuit = self._circuit_registry.get_circuit_sync(config.provider.value)
is_available = circuit.is_available()
result.append((model_name, config, is_available))
return result
def get_all_model_groups(self) -> dict[str, dict[str, Any]]:
"""
Get information about all model groups.
Returns:
Dict of group info
"""
result: dict[str, dict[str, Any]] = {}
for group, config in MODEL_GROUPS.items():
result[group.value] = {
"description": config.description,
"primary": config.primary,
"fallbacks": config.fallbacks,
}
return result
# Global router instance (lazy initialization)
_router: ModelRouter | None = None
def get_model_router() -> ModelRouter:
"""Get the global model router instance."""
global _router
if _router is None:
_router = ModelRouter()
return _router
def reset_model_router() -> None:
"""Reset the global router (for testing)."""
global _router
_router = None
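# Minimal usage sketch (illustrative only): resolving a model for the
# "reasoning" group with the router above. select_model() is async and raises
# AllProvidersFailedError when no provider API keys are configured.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        router = get_model_router()
        name, config = await router.select_model(model_group="reasoning")
        print(f"Selected {name} via {config.provider.value}")

    asyncio.run(_demo())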

View File

@@ -4,36 +4,555 @@ Syndarix LLM Gateway MCP Server.
Provides unified LLM access with:
- Multi-provider support (Claude, GPT, Gemini, Qwen, DeepSeek)
- Automatic failover chains
- Cost tracking via LiteLLM callbacks
- Model group routing (high-reasoning, code-generation, fast-response, cost-optimized)
- Cost tracking via Redis
- Model group routing (reasoning, code, fast, vision, embedding)
- Circuit breaker protection
Per ADR-004: LLM Provider Abstraction.
"""
import os
import logging
import uuid
from contextlib import asynccontextmanager
from typing import Any
import tiktoken
from fastapi import FastAPI
from fastmcp import FastMCP
# Create MCP server
mcp = FastMCP(
"syndarix-llm-gateway",
description="Unified LLM access with failover and cost tracking",
from config import get_settings
from cost_tracking import calculate_cost, get_cost_tracker
from exceptions import (
AllProvidersFailedError,
CircuitOpenError,
CostLimitExceededError,
InvalidModelError,
InvalidModelGroupError,
LLMGatewayError,
ModelNotAvailableError,
)
from failover import get_circuit_registry
from models import (
MODEL_CONFIGS,
MODEL_GROUPS,
ModelGroup,
)
from providers import get_provider
from routing import get_model_router
# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# Create FastMCP server
mcp = FastMCP("syndarix-llm-gateway")
@asynccontextmanager
async def lifespan(_app: FastAPI):
"""Application lifespan handler."""
settings = get_settings()
logger.info(f"Starting LLM Gateway on {settings.host}:{settings.port}")
# Initialize provider
provider = get_provider()
provider.initialize()
yield
# Cleanup
from cost_tracking import close_cost_tracker
await close_cost_tracker()
logger.info("LLM Gateway shutdown complete")
# Create FastAPI app that wraps FastMCP
app = FastAPI(
title="Syndarix LLM Gateway",
description="MCP Server for unified LLM access",
version="0.1.0",
lifespan=lifespan,
)
# Configuration
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")
DATABASE_URL = os.getenv("DATABASE_URL")
# Health endpoint
@app.get("/health")
async def health_check() -> dict[str, Any]:
"""Health check endpoint."""
settings = get_settings()
provider = get_provider()
return {
"status": "healthy",
"service": "llm-gateway",
"providers_configured": settings.get_available_providers(),
"provider_available": provider.is_available,
}
# Tool discovery endpoint (for MCP client compatibility)
@app.get("/mcp/tools")
async def list_tools() -> dict[str, Any]:
"""List available MCP tools."""
return {
"tools": [
{
"name": "chat_completion",
"description": "Generate a chat completion using the specified model group",
"inputSchema": {
"type": "object",
"properties": {
"project_id": {"type": "string", "description": "Project ID"},
"agent_id": {"type": "string", "description": "Agent ID"},
"messages": {
"type": "array",
"items": {
"type": "object",
"properties": {
"role": {"type": "string"},
"content": {"type": "string"},
},
"required": ["role", "content"],
},
},
"model_group": {
"type": "string",
"enum": [g.value for g in ModelGroup],
"default": "reasoning",
},
"max_tokens": {"type": "integer", "default": 4096},
"temperature": {"type": "number", "default": 0.7},
"stream": {"type": "boolean", "default": False},
},
"required": ["project_id", "agent_id", "messages"],
},
},
{
"name": "list_models",
"description": "List available models and model groups",
"inputSchema": {
"type": "object",
"properties": {
"project_id": {"type": "string"},
"agent_id": {"type": "string"},
"model_group": {"type": "string"},
},
"required": ["project_id", "agent_id"],
},
},
{
"name": "get_usage",
"description": "Get usage statistics for a project or agent",
"inputSchema": {
"type": "object",
"properties": {
"project_id": {"type": "string"},
"agent_id": {"type": "string"},
"period": {
"type": "string",
"enum": ["hour", "day", "month"],
"default": "day",
},
},
"required": ["project_id", "agent_id"],
},
},
{
"name": "count_tokens",
"description": "Count tokens in text",
"inputSchema": {
"type": "object",
"properties": {
"project_id": {"type": "string"},
"agent_id": {"type": "string"},
"text": {"type": "string"},
"model": {"type": "string"},
},
"required": ["project_id", "agent_id", "text"],
},
},
]
}
# JSON-RPC endpoint (for MCP client compatibility)
@app.post("/mcp")
async def jsonrpc_handler(request: dict[str, Any]) -> dict[str, Any]:
"""Handle JSON-RPC 2.0 requests for MCP tools."""
# Validate JSON-RPC structure
if request.get("jsonrpc") != "2.0":
return {
"jsonrpc": "2.0",
"error": {"code": -32600, "message": "Invalid JSON-RPC version"},
"id": request.get("id"),
}
method = request.get("method")
params = request.get("params", {})
request_id = request.get("id")
# Handle tool calls
if method == "tools/call":
tool_name = params.get("name")
tool_args = params.get("arguments", {})
try:
if tool_name == "chat_completion":
result = await _impl_chat_completion(**tool_args)
elif tool_name == "list_models":
result = await _impl_list_models(**tool_args)
elif tool_name == "get_usage":
result = await _impl_get_usage(**tool_args)
elif tool_name == "count_tokens":
result = await _impl_count_tokens(**tool_args)
else:
return {
"jsonrpc": "2.0",
"error": {"code": -32601, "message": f"Unknown tool: {tool_name}"},
"id": request_id,
}
return {
"jsonrpc": "2.0",
"result": {"content": [{"type": "text", "text": str(result)}]},
"id": request_id,
}
except LLMGatewayError as e:
return {
"jsonrpc": "2.0",
"error": {"code": -32000, "message": str(e), "data": e.to_dict()},
"id": request_id,
}
except Exception as e:
logger.exception(f"Error executing tool {tool_name}")
return {
"jsonrpc": "2.0",
"error": {"code": -32603, "message": str(e)},
"id": request_id,
}
# Handle tool listing
elif method == "tools/list":
tools_response = await list_tools()
return {
"jsonrpc": "2.0",
"result": tools_response,
"id": request_id,
}
else:
return {
"jsonrpc": "2.0",
"error": {"code": -32601, "message": f"Unknown method: {method}"},
"id": request_id,
}
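# Example request body accepted by the handler above (illustrative values;
# the IDs are placeholders, not real project/agent UUIDs):
#
#   {
#       "jsonrpc": "2.0",
#       "id": 1,
#       "method": "tools/call",
#       "params": {
#           "name": "count_tokens",
#           "arguments": {
#               "project_id": "proj-123",
#               "agent_id": "agent-456",
#               "text": "Hello, world!"
#           }
#       }
#   }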
# ============================================================================
# Core Implementation Functions
# ============================================================================
async def _impl_chat_completion(
project_id: str,
agent_id: str,
messages: list[dict[str, Any]],
model_group: str = "reasoning",
model_override: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
stream: bool = False,
session_id: str | None = None,
) -> dict[str, Any]:
"""Core implementation for chat completion."""
settings = get_settings()
router = get_model_router()
tracker = get_cost_tracker()
circuit_registry = get_circuit_registry()
# Check budget before making request
if settings.cost_tracking_enabled:
within_budget, current_cost, limit = await tracker.check_budget(project_id)
if not within_budget:
raise CostLimitExceededError(
entity_type="project",
entity_id=project_id,
current_cost=current_cost,
limit=limit,
)
# Select model
try:
model_name, model_config = await router.select_model(
model_group=model_group,
model_override=model_override,
)
    except (
        InvalidModelGroupError,
        InvalidModelError,
        ModelNotAvailableError,
        AllProvidersFailedError,
    ):
        raise
# Get provider
provider = get_provider()
if not provider.router:
raise AllProvidersFailedError(
model_group=model_group,
attempted_models=[model_name],
errors=[{"error": "No providers configured"}],
)
# Generate request ID
request_id = str(uuid.uuid4())
try:
# Get circuit breaker for this provider
circuit = circuit_registry.get_circuit_sync(model_config.provider.value)
# Make completion request
if stream:
# Return streaming response info
# Actual streaming would be handled by a separate endpoint
return {
"status": "streaming_not_supported_via_tool",
"message": "Use /stream endpoint for streaming responses",
"request_id": request_id,
}
# Non-streaming completion
response = await provider.router.acompletion(
model=model_name,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
)
# Record success
await circuit.record_success()
# Extract response data
content = response.choices[0].message.content or ""
finish_reason = response.choices[0].finish_reason or "stop"
# Get usage stats
prompt_tokens = response.usage.prompt_tokens if response.usage else 0
completion_tokens = response.usage.completion_tokens if response.usage else 0
# Calculate cost
cost_usd = calculate_cost(model_name, prompt_tokens, completion_tokens)
# Record usage
await tracker.record_usage(
project_id=project_id,
agent_id=agent_id,
model=model_name,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
cost_usd=cost_usd,
session_id=session_id,
request_id=request_id,
)
return {
"id": request_id,
"model": model_name,
"provider": model_config.provider.value,
"content": content,
"finish_reason": finish_reason,
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
"cost_usd": cost_usd,
},
}
except CircuitOpenError:
raise
except Exception as e:
# Record failure
circuit = circuit_registry.get_circuit_sync(model_config.provider.value)
await circuit.record_failure(e)
logger.error(f"Completion failed: {e}")
raise AllProvidersFailedError(
model_group=model_group,
attempted_models=[model_name],
errors=[{"model": model_name, "error": str(e)}],
)
async def _impl_list_models(
project_id: str,
agent_id: str,
model_group: str | None = None,
) -> dict[str, Any]:
"""Core implementation for list_models."""
settings = get_settings()
provider = get_provider()
router = get_model_router()
# Get available providers
available_providers = settings.get_available_providers()
result: dict[str, Any] = {
"project_id": project_id,
"agent_id": agent_id,
"available_providers": available_providers,
}
if model_group:
# List models for specific group
try:
parsed_group = router.parse_model_group(model_group)
models = await router.get_available_models_for_group(parsed_group)
result["model_group"] = model_group
result["models"] = [
{
"name": name,
"provider": config.provider.value,
"available": available,
"cost_per_1m_input": config.cost_per_1m_input,
"cost_per_1m_output": config.cost_per_1m_output,
}
for name, config, available in models
]
except InvalidModelGroupError as e:
result["error"] = e.to_dict()
else:
# List all model groups
groups: dict[str, Any] = {}
for group in ModelGroup:
group_config = MODEL_GROUPS.get(group)
if group_config:
models = await router.get_available_models_for_group(group)
available_count = sum(1 for _, _, avail in models if avail)
groups[group.value] = {
"description": group_config.description,
"primary": group_config.primary,
"fallbacks": group_config.fallbacks,
"available_models": available_count,
"total_models": len(models),
}
result["model_groups"] = groups
# List all models
all_models: list[dict[str, Any]] = []
available_models = provider.get_available_models()
for name, config in MODEL_CONFIGS.items():
all_models.append({
"name": name,
"provider": config.provider.value,
"available": name in available_models,
"cost_per_1m_input": config.cost_per_1m_input,
"cost_per_1m_output": config.cost_per_1m_output,
"context_window": config.context_window,
"max_output_tokens": config.max_output_tokens,
"supports_vision": config.supports_vision,
"supports_streaming": config.supports_streaming,
})
result["models"] = all_models
return result
async def _impl_get_usage(
project_id: str,
agent_id: str,
period: str = "day",
) -> dict[str, Any]:
"""Core implementation for get_usage."""
tracker = get_cost_tracker()
# Get project usage
project_report = await tracker.get_project_usage(project_id, period=period)
# Get agent usage
agent_report = await tracker.get_agent_usage(agent_id, period=period)
return {
"project_id": project_id,
"agent_id": agent_id,
"period": period,
"project_usage": {
"total_requests": project_report.total_requests,
"total_tokens": project_report.total_tokens,
"total_cost_usd": project_report.total_cost_usd,
"by_model": project_report.by_model,
"period_start": project_report.period_start.isoformat(),
"period_end": project_report.period_end.isoformat(),
},
"agent_usage": {
"total_requests": agent_report.total_requests,
"total_tokens": agent_report.total_tokens,
"total_cost_usd": agent_report.total_cost_usd,
"by_model": agent_report.by_model,
"period_start": agent_report.period_start.isoformat(),
"period_end": agent_report.period_end.isoformat(),
},
}
async def _impl_count_tokens(
project_id: str,
agent_id: str,
text: str,
model: str | None = None,
) -> dict[str, Any]:
"""Core implementation for count_tokens."""
# Use tiktoken for token counting
    # Default to cl100k_base (GPT-4 family); used only as an approximation for
    # non-OpenAI models, which ship their own tokenizers.
    encoding_name = "cl100k_base"
    try:
        if model and model.startswith("gpt"):
            encoding = tiktoken.encoding_for_model(model)
            encoding_name = encoding.name
        else:
            encoding = tiktoken.get_encoding("cl100k_base")
        token_count = len(encoding.encode(text))
    except Exception as e:
        logger.warning(f"Token counting failed: {e}, using estimate")
        # Fallback: rough estimate of ~4 chars per token
        token_count = len(text) // 4
        encoding_name = "estimate"
# Estimate costs for different models
cost_estimates: dict[str, float] = {}
for model_name, config in MODEL_CONFIGS.items():
if config.cost_per_1m_input > 0:
cost = (token_count / 1_000_000) * config.cost_per_1m_input
cost_estimates[model_name] = round(cost, 6)
return {
"project_id": project_id,
"agent_id": agent_id,
"token_count": token_count,
"text_length": len(text),
"encoding": "cl100k_base",
"cost_estimates": cost_estimates,
}
# ============================================================================
# MCP Tools (wrappers around core implementations)
# ============================================================================
@mcp.tool()
async def chat_completion(
project_id: str,
agent_id: str,
messages: list[dict],
model_group: str = "high-reasoning",
messages: list[dict[str, Any]],
model_group: str = "reasoning",
model_override: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
) -> dict:
stream: bool = False,
session_id: str | None = None,
) -> dict[str, Any]:
"""
Generate a chat completion using the specified model group.
@@ -41,108 +560,116 @@ async def chat_completion(
project_id: UUID of the project (required for cost attribution)
agent_id: UUID of the agent instance making the request
messages: List of message dicts with 'role' and 'content'
model_group: Model routing group (high-reasoning, code-generation, fast-response, cost-optimized, self-hosted)
model_group: Model routing group (reasoning, code, fast, vision, embedding)
model_override: Specific model to use (bypasses group routing)
max_tokens: Maximum tokens to generate
temperature: Sampling temperature (0.0-2.0)
stream: Enable streaming response
session_id: Optional session ID for tracking
Returns:
Completion response with content and usage statistics
"""
# TODO: Implement with LiteLLM
# 1. Map model_group to primary model + fallbacks
# 2. Check project budget before making request
# 3. Make completion request with failover
# 4. Log usage via callback
# 5. Return response
return {
"status": "not_implemented",
"project_id": project_id,
"agent_id": agent_id,
"model_group": model_group,
}
return await _impl_chat_completion(
project_id=project_id,
agent_id=agent_id,
messages=messages,
model_group=model_group,
model_override=model_override,
max_tokens=max_tokens,
temperature=temperature,
stream=stream,
session_id=session_id,
)
@mcp.tool()
async def get_embeddings(
async def list_models(
project_id: str,
texts: list[str],
model: str = "text-embedding-3-small",
) -> dict:
agent_id: str,
model_group: str | None = None,
) -> dict[str, Any]:
"""
Generate embeddings for the given texts.
Args:
project_id: UUID of the project (required for cost attribution)
texts: List of texts to embed
model: Embedding model to use
Returns:
List of embedding vectors
"""
# TODO: Implement with LiteLLM embeddings
return {
"status": "not_implemented",
"project_id": project_id,
"text_count": len(texts),
}
@mcp.tool()
async def get_budget_status(project_id: str) -> dict:
"""
Get current budget status for a project.
List available models and model groups.
Args:
project_id: UUID of the project
agent_id: UUID of the agent instance
model_group: Optional specific group to list
Returns:
Budget status with usage, limits, and percentage
Dictionary of available models and groups
"""
# TODO: Implement budget check from Redis
return {
"status": "not_implemented",
"project_id": project_id,
}
return await _impl_list_models(
project_id=project_id,
agent_id=agent_id,
model_group=model_group,
)
@mcp.tool()
async def list_available_models() -> dict:
async def get_usage(
project_id: str,
agent_id: str,
period: str = "day",
) -> dict[str, Any]:
"""
List all available models and their capabilities.
Get usage statistics for a project or agent.
Args:
project_id: UUID of the project
agent_id: UUID of the agent
period: Time period (hour, day, month)
Returns:
Dictionary of model groups and available models
Usage statistics including tokens and costs
"""
return {
"model_groups": {
"high-reasoning": {
"primary": "claude-opus-4-5",
"fallbacks": ["gpt-5.1-codex-max", "gemini-3-pro"],
"description": "Complex analysis, architecture decisions",
},
"code-generation": {
"primary": "gpt-5.1-codex-max",
"fallbacks": ["claude-opus-4-5", "deepseek-v3.2"],
"description": "Code writing and refactoring",
},
"fast-response": {
"primary": "gemini-3-flash",
"fallbacks": ["qwen3-235b", "deepseek-v3.2"],
"description": "Quick tasks, simple queries",
},
"cost-optimized": {
"primary": "qwen3-235b",
"fallbacks": ["deepseek-v3.2"],
"description": "High-volume, non-critical tasks",
},
"self-hosted": {
"primary": "deepseek-v3.2",
"fallbacks": ["qwen3-235b"],
"description": "Privacy-sensitive, air-gapped",
},
}
}
return await _impl_get_usage(
project_id=project_id,
agent_id=agent_id,
period=period,
)
@mcp.tool()
async def count_tokens(
project_id: str,
agent_id: str,
text: str,
model: str | None = None,
) -> dict[str, Any]:
"""
Count tokens in text.
Args:
project_id: UUID of the project
agent_id: UUID of the agent
text: Text to count tokens in
model: Optional model for tokenizer selection
Returns:
Token count and estimation details
"""
return await _impl_count_tokens(
project_id=project_id,
agent_id=agent_id,
text=text,
model=model,
)
def main() -> None:
"""Run the server."""
import uvicorn
settings = get_settings()
uvicorn.run(
"server:app",
host=settings.host,
port=settings.port,
reload=settings.debug,
)
if __name__ == "__main__":
mcp.run()
main()

View File

@@ -0,0 +1,344 @@
"""
Streaming support for LLM Gateway.
Provides async streaming wrappers for LiteLLM responses.
"""
import asyncio
import json
import logging
import uuid
from collections.abc import AsyncIterator
from typing import Any
from models import StreamChunk, UsageStats
logger = logging.getLogger(__name__)
class StreamAccumulator:
"""
Accumulates streaming chunks for cost calculation.
Tracks:
- Full content for final response
- Token counts from chunks
- Timing information
"""
def __init__(self, request_id: str | None = None) -> None:
"""
Initialize accumulator.
Args:
request_id: Optional request ID for tracking
"""
self.request_id = request_id or str(uuid.uuid4())
self.content_parts: list[str] = []
self.chunks_received = 0
self.prompt_tokens = 0
self.completion_tokens = 0
self.model: str | None = None
self.finish_reason: str | None = None
self._started_at: float | None = None
self._finished_at: float | None = None
@property
def content(self) -> str:
"""Get accumulated content."""
return "".join(self.content_parts)
@property
def total_tokens(self) -> int:
"""Get total token count."""
return self.prompt_tokens + self.completion_tokens
@property
def duration_seconds(self) -> float | None:
"""Get stream duration in seconds."""
if self._started_at is None or self._finished_at is None:
return None
return self._finished_at - self._started_at
def start(self) -> None:
"""Mark stream start."""
import time
self._started_at = time.time()
def finish(self) -> None:
"""Mark stream finish."""
import time
self._finished_at = time.time()
def add_chunk(
self,
delta: str,
finish_reason: str | None = None,
model: str | None = None,
usage: dict[str, int] | None = None,
) -> None:
"""
Add a chunk to the accumulator.
Args:
delta: Content delta
finish_reason: Finish reason if this is the final chunk
model: Model name
usage: Usage stats if provided
"""
if delta:
self.content_parts.append(delta)
self.chunks_received += 1
if finish_reason:
self.finish_reason = finish_reason
if model:
self.model = model
if usage:
self.prompt_tokens = usage.get("prompt_tokens", self.prompt_tokens)
self.completion_tokens = usage.get(
"completion_tokens", self.completion_tokens
)
def get_usage_stats(self, cost_usd: float = 0.0) -> UsageStats:
"""Get usage statistics."""
return UsageStats(
prompt_tokens=self.prompt_tokens,
completion_tokens=self.completion_tokens,
total_tokens=self.total_tokens,
cost_usd=cost_usd,
)
async def wrap_litellm_stream(
stream: AsyncIterator[Any],
accumulator: StreamAccumulator | None = None,
) -> AsyncIterator[StreamChunk]:
"""
Wrap a LiteLLM stream into StreamChunk objects.
Args:
stream: LiteLLM async stream
accumulator: Optional accumulator for tracking
Yields:
StreamChunk objects
"""
if accumulator:
accumulator.start()
chunk_id = 0
try:
async for chunk in stream:
chunk_id += 1
# Extract data from LiteLLM chunk
delta = ""
finish_reason = None
usage = None
model = None
# Handle different chunk formats
if hasattr(chunk, "choices") and chunk.choices:
choice = chunk.choices[0]
if hasattr(choice, "delta"):
delta = getattr(choice.delta, "content", "") or ""
finish_reason = getattr(choice, "finish_reason", None)
if hasattr(chunk, "model"):
model = chunk.model
if hasattr(chunk, "usage") and chunk.usage:
usage = {
"prompt_tokens": getattr(chunk.usage, "prompt_tokens", 0),
"completion_tokens": getattr(chunk.usage, "completion_tokens", 0),
}
# Update accumulator
if accumulator:
accumulator.add_chunk(
delta=delta,
finish_reason=finish_reason,
model=model,
usage=usage,
)
# Create StreamChunk
stream_chunk = StreamChunk(
id=f"{accumulator.request_id if accumulator else 'stream'}-{chunk_id}",
delta=delta,
finish_reason=finish_reason,
)
# Add usage on final chunk
if finish_reason and accumulator:
stream_chunk.usage = accumulator.get_usage_stats()
yield stream_chunk
finally:
if accumulator:
accumulator.finish()
def format_sse_chunk(chunk: StreamChunk) -> str:
"""
Format a StreamChunk as SSE data.
Args:
chunk: StreamChunk to format
Returns:
SSE-formatted string
"""
data = {
"id": chunk.id,
"delta": chunk.delta,
}
if chunk.finish_reason:
data["finish_reason"] = chunk.finish_reason
if chunk.usage:
data["usage"] = chunk.usage.model_dump()
return f"data: {json.dumps(data)}\n\n"
def format_sse_done() -> str:
"""Format SSE done message."""
return "data: [DONE]\n\n"
def format_sse_error(error: str, code: str | None = None) -> str:
"""
Format an error as SSE data.
Args:
error: Error message
code: Error code
Returns:
SSE-formatted error string
"""
data = {"error": error}
if code:
data["code"] = code
return f"data: {json.dumps(data)}\n\n"
class StreamBuffer:
"""
Buffer for streaming responses with backpressure handling.
Useful when producing chunks faster than they can be consumed.
"""
def __init__(self, max_size: int = 100) -> None:
"""
Initialize buffer.
Args:
max_size: Maximum buffer size
"""
self._queue: asyncio.Queue[StreamChunk | None] = asyncio.Queue(maxsize=max_size)
self._done = False
self._error: Exception | None = None
async def put(self, chunk: StreamChunk) -> None:
"""
Put a chunk in the buffer.
Args:
chunk: Chunk to buffer
"""
if self._done:
raise RuntimeError("Buffer is closed")
await self._queue.put(chunk)
async def done(self) -> None:
"""Signal that streaming is complete."""
self._done = True
await self._queue.put(None)
async def error(self, err: Exception) -> None:
"""Signal an error occurred."""
self._error = err
self._done = True
await self._queue.put(None)
async def __aiter__(self) -> AsyncIterator[StreamChunk]:
"""Iterate over buffered chunks."""
while True:
chunk = await self._queue.get()
if chunk is None:
if self._error:
raise self._error
return
yield chunk
async def stream_to_string(stream: AsyncIterator[StreamChunk]) -> tuple[str, UsageStats | None]:
"""
Consume a stream and return full content.
Args:
stream: Stream to consume
Returns:
Tuple of (content, usage_stats)
"""
parts: list[str] = []
usage: UsageStats | None = None
async for chunk in stream:
if chunk.delta:
parts.append(chunk.delta)
if chunk.usage:
usage = chunk.usage
return "".join(parts), usage
async def merge_streams(
*streams: AsyncIterator[StreamChunk],
) -> AsyncIterator[StreamChunk]:
"""
Merge multiple streams into one.
Useful for parallel requests where results should be combined.
Args:
*streams: Streams to merge
Yields:
Chunks from all streams in arrival order
"""
pending: set[asyncio.Task[tuple[int, StreamChunk | None]]] = set()
async def next_chunk(
idx: int, stream: AsyncIterator[StreamChunk]
) -> tuple[int, StreamChunk | None]:
try:
return idx, await stream.__anext__()
except StopAsyncIteration:
return idx, None
# Start initial tasks
active_streams = list(streams)
for idx, stream in enumerate(active_streams):
task = asyncio.create_task(next_chunk(idx, stream))
pending.add(task)
while pending:
done, pending = await asyncio.wait(
pending, return_when=asyncio.FIRST_COMPLETED
)
for task in done:
idx, chunk = task.result()
if chunk is not None:
yield chunk
# Schedule next chunk from this stream
new_task = asyncio.create_task(next_chunk(idx, active_streams[idx]))
pending.add(new_task)
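# Minimal usage sketch (illustrative only): turning a raw provider stream into
# SSE frames with the helpers above. `raw_stream` stands in for the async
# iterator returned by a streaming completion call; it is an assumption, not
# something defined in this module.
async def _example_sse_events(raw_stream: AsyncIterator[Any]) -> AsyncIterator[str]:
    accumulator = StreamAccumulator()
    async for chunk in wrap_litellm_stream(raw_stream, accumulator=accumulator):
        yield format_sse_chunk(chunk)
    yield format_sse_done()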

View File

@@ -0,0 +1 @@
"""Tests for LLM Gateway MCP Server."""

View File

@@ -0,0 +1,204 @@
"""
Pytest fixtures for LLM Gateway tests.
"""
import os
import sys
from collections.abc import AsyncIterator, Iterator
from typing import Any
from unittest.mock import MagicMock, patch
import fakeredis.aioredis
import pytest
# Add parent directory to path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import Settings
from cost_tracking import CostTracker, reset_cost_tracker
from failover import CircuitBreakerRegistry, reset_circuit_registry
from providers import LLMProvider, reset_provider
from routing import ModelRouter, reset_model_router
@pytest.fixture
def test_settings() -> Settings:
"""Create test settings with mock API keys."""
return Settings(
host="127.0.0.1",
port=8001,
debug=True,
redis_url="redis://localhost:6379/0",
anthropic_api_key="test-anthropic-key",
openai_api_key="test-openai-key",
google_api_key="test-google-key",
litellm_timeout=30,
litellm_max_retries=2,
litellm_cache_enabled=False,
cost_tracking_enabled=True,
circuit_failure_threshold=3,
circuit_recovery_timeout=10,
)
@pytest.fixture
def settings_no_providers() -> Settings:
"""Create settings with no providers configured."""
return Settings(
host="127.0.0.1",
port=8001,
debug=False,
redis_url="redis://localhost:6379/0",
anthropic_api_key=None,
openai_api_key=None,
google_api_key=None,
alibaba_api_key=None,
deepseek_api_key=None,
)
@pytest.fixture
def fake_redis() -> fakeredis.aioredis.FakeRedis:
"""Create a fake Redis instance for testing."""
return fakeredis.aioredis.FakeRedis(decode_responses=True)
@pytest.fixture
async def cost_tracker(
fake_redis: fakeredis.aioredis.FakeRedis,
test_settings: Settings,
) -> AsyncIterator[CostTracker]:
"""Create a cost tracker with fake Redis."""
reset_cost_tracker()
tracker = CostTracker(redis_client=fake_redis, settings=test_settings)
yield tracker
await tracker.close()
reset_cost_tracker()
@pytest.fixture
def circuit_registry(test_settings: Settings) -> Iterator[CircuitBreakerRegistry]:
"""Create a circuit breaker registry for testing."""
reset_circuit_registry()
registry = CircuitBreakerRegistry(settings=test_settings)
yield registry
reset_circuit_registry()
@pytest.fixture
def model_router(
test_settings: Settings,
circuit_registry: CircuitBreakerRegistry,
) -> Iterator[ModelRouter]:
"""Create a model router for testing."""
reset_model_router()
router = ModelRouter(settings=test_settings, circuit_registry=circuit_registry)
yield router
reset_model_router()
@pytest.fixture
def llm_provider(test_settings: Settings) -> Iterator[LLMProvider]:
"""Create an LLM provider for testing."""
reset_provider()
provider = LLMProvider(settings=test_settings)
yield provider
reset_provider()
@pytest.fixture
def mock_litellm_response() -> dict[str, Any]:
"""Create a mock LiteLLM response."""
return {
"id": "test-response-id",
"model": "claude-opus-4",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "This is a test response.",
},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 20,
"total_tokens": 30,
},
}
@pytest.fixture
def mock_completion_response() -> MagicMock:
"""Create a mock completion response object."""
response = MagicMock()
response.id = "test-response-id"
response.model = "claude-opus-4"
choice = MagicMock()
choice.index = 0
choice.message = MagicMock()
choice.message.content = "This is a test response."
choice.finish_reason = "stop"
response.choices = [choice]
response.usage = MagicMock()
response.usage.prompt_tokens = 10
response.usage.completion_tokens = 20
response.usage.total_tokens = 30
return response
@pytest.fixture
def sample_messages() -> list[dict[str, str]]:
"""Sample chat messages for testing."""
return [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"},
]
@pytest.fixture
def sample_project_id() -> str:
"""Sample project ID."""
return "proj-12345678-1234-1234-1234-123456789abc"
@pytest.fixture
def sample_agent_id() -> str:
"""Sample agent ID."""
return "agent-87654321-4321-4321-4321-cba987654321"
@pytest.fixture
def sample_session_id() -> str:
"""Sample session ID."""
return "session-11111111-2222-3333-4444-555555555555"
# Reset all global state after each test
@pytest.fixture(autouse=True)
def reset_globals() -> Iterator[None]:
"""Reset all global state after each test."""
yield
reset_cost_tracker()
reset_circuit_registry()
reset_model_router()
reset_provider()
# Mock environment variables for tests
# Note: Not autouse=True to avoid affecting default value tests
@pytest.fixture
def mock_env_vars() -> Iterator[None]:
"""Set test environment variables."""
env_vars = {
"LLM_GATEWAY_HOST": "127.0.0.1",
"LLM_GATEWAY_PORT": "8001",
"LLM_GATEWAY_DEBUG": "true",
}
with patch.dict(os.environ, env_vars, clear=False):
yield

View File

@@ -0,0 +1,200 @@
"""
Tests for config module.
"""
import os
from unittest.mock import patch
import pytest
from config import Settings, get_settings
class TestSettings:
"""Tests for Settings class."""
def test_default_values(self) -> None:
"""Test default configuration values."""
settings = Settings()
assert settings.host == "0.0.0.0"
assert settings.port == 8001
assert settings.debug is False
assert settings.redis_url == "redis://localhost:6379/0"
assert settings.litellm_timeout == 120
assert settings.circuit_failure_threshold == 5
def test_custom_values(self) -> None:
"""Test custom configuration values."""
settings = Settings(
host="127.0.0.1",
port=9000,
debug=True,
redis_url="redis://custom:6380/1",
litellm_timeout=60,
)
assert settings.host == "127.0.0.1"
assert settings.port == 9000
assert settings.debug is True
assert settings.redis_url == "redis://custom:6380/1"
assert settings.litellm_timeout == 60
def test_port_validation_valid(self) -> None:
"""Test valid port numbers."""
settings = Settings(port=1)
assert settings.port == 1
settings = Settings(port=65535)
assert settings.port == 65535
settings = Settings(port=8080)
assert settings.port == 8080
def test_port_validation_invalid(self) -> None:
"""Test invalid port numbers."""
with pytest.raises(ValueError, match="Port must be between"):
Settings(port=0)
with pytest.raises(ValueError, match="Port must be between"):
Settings(port=65536)
with pytest.raises(ValueError, match="Port must be between"):
Settings(port=-1)
def test_ttl_validation_valid(self) -> None:
"""Test valid TTL values."""
settings = Settings(redis_ttl_hours=1)
assert settings.redis_ttl_hours == 1
settings = Settings(redis_ttl_hours=168) # 1 week
assert settings.redis_ttl_hours == 168
def test_ttl_validation_invalid(self) -> None:
"""Test invalid TTL values."""
with pytest.raises(ValueError, match="Redis TTL must be positive"):
Settings(redis_ttl_hours=0)
with pytest.raises(ValueError, match="Redis TTL must be positive"):
Settings(redis_ttl_hours=-1)
def test_failure_threshold_validation(self) -> None:
"""Test circuit failure threshold validation."""
settings = Settings(circuit_failure_threshold=1)
assert settings.circuit_failure_threshold == 1
settings = Settings(circuit_failure_threshold=100)
assert settings.circuit_failure_threshold == 100
with pytest.raises(ValueError, match="Failure threshold must be between"):
Settings(circuit_failure_threshold=0)
with pytest.raises(ValueError, match="Failure threshold must be between"):
Settings(circuit_failure_threshold=101)
def test_timeout_validation(self) -> None:
"""Test timeout validation."""
settings = Settings(litellm_timeout=1)
assert settings.litellm_timeout == 1
settings = Settings(litellm_timeout=600)
assert settings.litellm_timeout == 600
with pytest.raises(ValueError, match="Timeout must be between"):
Settings(litellm_timeout=0)
with pytest.raises(ValueError, match="Timeout must be between"):
Settings(litellm_timeout=601)
def test_get_available_providers_none(self) -> None:
"""Test getting available providers with none configured."""
settings = Settings()
providers = settings.get_available_providers()
assert providers == []
def test_get_available_providers_some(self) -> None:
"""Test getting available providers with some configured."""
settings = Settings(
anthropic_api_key="test-key",
openai_api_key="test-key",
)
providers = settings.get_available_providers()
assert "anthropic" in providers
assert "openai" in providers
assert "google" not in providers
assert len(providers) == 2
def test_get_available_providers_all(self) -> None:
"""Test getting available providers with all configured."""
settings = Settings(
anthropic_api_key="test-key",
openai_api_key="test-key",
google_api_key="test-key",
alibaba_api_key="test-key",
deepseek_api_key="test-key",
)
providers = settings.get_available_providers()
assert len(providers) == 5
assert "anthropic" in providers
assert "openai" in providers
assert "google" in providers
assert "alibaba" in providers
assert "deepseek" in providers
def test_has_any_provider_false(self) -> None:
"""Test has_any_provider when none configured."""
settings = Settings()
assert settings.has_any_provider() is False
def test_has_any_provider_true(self) -> None:
"""Test has_any_provider when at least one configured."""
settings = Settings(anthropic_api_key="test-key")
assert settings.has_any_provider() is True
def test_deepseek_base_url_counts_as_provider(self) -> None:
"""Test that DeepSeek base URL alone counts as provider."""
settings = Settings(deepseek_base_url="http://localhost:8000")
providers = settings.get_available_providers()
assert "deepseek" in providers
class TestGetSettings:
"""Tests for get_settings function."""
def test_get_settings_returns_settings(self) -> None:
"""Test that get_settings returns a Settings instance."""
# Clear the cache first
get_settings.cache_clear()
settings = get_settings()
assert isinstance(settings, Settings)
def test_get_settings_is_cached(self) -> None:
"""Test that get_settings returns cached instance."""
get_settings.cache_clear()
settings1 = get_settings()
settings2 = get_settings()
assert settings1 is settings2
def test_env_var_override(self) -> None:
"""Test that environment variables override defaults."""
get_settings.cache_clear()
with patch.dict(
os.environ,
{
"LLM_GATEWAY_HOST": "192.168.1.1",
"LLM_GATEWAY_PORT": "9999",
"LLM_GATEWAY_DEBUG": "true",
},
):
get_settings.cache_clear()
settings = get_settings()
assert settings.host == "192.168.1.1"
assert settings.port == 9999
assert settings.debug is True

View File

@@ -0,0 +1,417 @@
"""
Tests for cost_tracking module.
"""
import asyncio
import fakeredis.aioredis
import pytest
from config import Settings
from cost_tracking import (
CostTracker,
calculate_cost,
close_cost_tracker,
get_cost_tracker,
reset_cost_tracker,
)
@pytest.fixture
def tracker_settings() -> Settings:
"""Settings for cost tracker tests."""
return Settings(
redis_url="redis://localhost:6379/0",
redis_prefix="test_llm_gateway",
cost_tracking_enabled=True,
cost_alert_threshold=100.0,
default_budget_limit=1000.0,
)
@pytest.fixture
def fake_redis() -> fakeredis.aioredis.FakeRedis:
"""Create fake Redis for testing."""
return fakeredis.aioredis.FakeRedis(decode_responses=True)
@pytest.fixture
def tracker(
fake_redis: fakeredis.aioredis.FakeRedis,
tracker_settings: Settings,
) -> CostTracker:
"""Create cost tracker with fake Redis."""
return CostTracker(redis_client=fake_redis, settings=tracker_settings)
class TestCalculateCost:
"""Tests for calculate_cost function."""
def test_calculate_cost_known_model(self) -> None:
"""Test calculating cost for known model."""
# claude-opus-4: $15/1M input, $75/1M output
cost = calculate_cost(
model="claude-opus-4",
prompt_tokens=1000,
completion_tokens=500,
)
# 1000/1M * 15 + 500/1M * 75 = 0.015 + 0.0375 = 0.0525
assert cost == pytest.approx(0.0525, rel=0.001)
def test_calculate_cost_unknown_model(self) -> None:
"""Test calculating cost for unknown model."""
cost = calculate_cost(
model="unknown-model",
prompt_tokens=1000,
completion_tokens=500,
)
assert cost == 0.0
def test_calculate_cost_zero_tokens(self) -> None:
"""Test calculating cost with zero tokens."""
cost = calculate_cost(
model="claude-opus-4",
prompt_tokens=0,
completion_tokens=0,
)
assert cost == 0.0
def test_calculate_cost_large_token_counts(self) -> None:
"""Test calculating cost with large token counts."""
cost = calculate_cost(
model="claude-opus-4",
prompt_tokens=1_000_000,
completion_tokens=500_000,
)
# 1M * 15/1M + 500K * 75/1M = 15 + 37.5 = 52.5
assert cost == pytest.approx(52.5, rel=0.001)
class TestCostTracker:
"""Tests for CostTracker class."""
def test_record_usage(self, tracker: CostTracker) -> None:
"""Test recording usage."""
asyncio.run(
tracker.record_usage(
project_id="proj-123",
agent_id="agent-456",
model="claude-opus-4",
prompt_tokens=100,
completion_tokens=50,
cost_usd=0.01,
)
)
# Verify by getting usage report
report = asyncio.run(
tracker.get_project_usage("proj-123", period="day")
)
assert report.total_requests == 1
assert report.total_cost_usd == pytest.approx(0.01, rel=0.01)
def test_record_usage_disabled(self, tracker_settings: Settings) -> None:
"""Test recording is skipped when disabled."""
settings = Settings(**{
**tracker_settings.model_dump(),
"cost_tracking_enabled": False,
})
fake_redis = fakeredis.aioredis.FakeRedis(decode_responses=True)
disabled_tracker = CostTracker(redis_client=fake_redis, settings=settings)
# This should not raise and should not record
asyncio.run(
disabled_tracker.record_usage(
project_id="proj-123",
agent_id="agent-456",
model="claude-opus-4",
prompt_tokens=100,
completion_tokens=50,
cost_usd=0.01,
)
)
# Usage should be empty
report = asyncio.run(
disabled_tracker.get_project_usage("proj-123", period="day")
)
assert report.total_requests == 0
def test_record_usage_with_session(self, tracker: CostTracker) -> None:
"""Test recording usage with session ID."""
asyncio.run(
tracker.record_usage(
project_id="proj-123",
agent_id="agent-456",
model="claude-opus-4",
prompt_tokens=100,
completion_tokens=50,
cost_usd=0.01,
session_id="session-789",
)
)
# Verify session usage
session_usage = asyncio.run(
tracker.get_session_usage("session-789")
)
assert session_usage["session_id"] == "session-789"
assert session_usage["total_cost_usd"] == pytest.approx(0.01, rel=0.01)
def test_get_project_usage_empty(self, tracker: CostTracker) -> None:
"""Test getting usage for project with no data."""
report = asyncio.run(
tracker.get_project_usage("nonexistent-project", period="day")
)
assert report.entity_id == "nonexistent-project"
assert report.entity_type == "project"
assert report.total_requests == 0
assert report.total_cost_usd == 0.0
def test_get_project_usage_multiple_models(self, tracker: CostTracker) -> None:
"""Test usage tracking across multiple models."""
# Record usage for different models
asyncio.run(
tracker.record_usage(
project_id="proj-123",
agent_id="agent-456",
model="claude-opus-4",
prompt_tokens=100,
completion_tokens=50,
cost_usd=0.01,
)
)
asyncio.run(
tracker.record_usage(
project_id="proj-123",
agent_id="agent-456",
model="gpt-4.1",
prompt_tokens=200,
completion_tokens=100,
cost_usd=0.02,
)
)
report = asyncio.run(
tracker.get_project_usage("proj-123", period="day")
)
assert report.total_requests == 2
assert len(report.by_model) == 2
assert "claude-opus-4" in report.by_model
assert "gpt-4.1" in report.by_model
def test_get_agent_usage(self, tracker: CostTracker) -> None:
"""Test getting agent usage."""
asyncio.run(
tracker.record_usage(
project_id="proj-123",
agent_id="agent-456",
model="claude-opus-4",
prompt_tokens=100,
completion_tokens=50,
cost_usd=0.01,
)
)
report = asyncio.run(
tracker.get_agent_usage("agent-456", period="day")
)
assert report.entity_id == "agent-456"
assert report.entity_type == "agent"
assert report.total_requests == 1
def test_usage_periods(self, tracker: CostTracker) -> None:
"""Test different usage periods."""
asyncio.run(
tracker.record_usage(
project_id="proj-123",
agent_id="agent-456",
model="claude-opus-4",
prompt_tokens=100,
completion_tokens=50,
cost_usd=0.01,
)
)
# Check different periods
hour_report = asyncio.run(
tracker.get_project_usage("proj-123", period="hour")
)
day_report = asyncio.run(
tracker.get_project_usage("proj-123", period="day")
)
month_report = asyncio.run(
tracker.get_project_usage("proj-123", period="month")
)
assert hour_report.period == "hour"
assert day_report.period == "day"
assert month_report.period == "month"
# All should have the same data
assert hour_report.total_requests == 1
assert day_report.total_requests == 1
assert month_report.total_requests == 1
def test_check_budget_within(self, tracker: CostTracker) -> None:
"""Test budget check when within limit."""
# Record some usage
asyncio.run(
tracker.record_usage(
project_id="proj-123",
agent_id="agent-456",
model="claude-opus-4",
prompt_tokens=1000,
completion_tokens=500,
cost_usd=50.0,
)
)
within, current, limit = asyncio.run(
tracker.check_budget("proj-123", budget_limit=100.0)
)
assert within is True
assert current == pytest.approx(50.0, rel=0.01)
assert limit == 100.0
def test_check_budget_exceeded(self, tracker: CostTracker) -> None:
"""Test budget check when exceeded."""
asyncio.run(
tracker.record_usage(
project_id="proj-123",
agent_id="agent-456",
model="claude-opus-4",
prompt_tokens=1000,
completion_tokens=500,
cost_usd=150.0,
)
)
within, current, limit = asyncio.run(
tracker.check_budget("proj-123", budget_limit=100.0)
)
assert within is False
assert current >= limit
def test_check_budget_default_limit(self, tracker: CostTracker) -> None:
"""Test budget check with default limit."""
within, current, limit = asyncio.run(
tracker.check_budget("proj-123")
)
assert limit == 1000.0 # Default from settings
def test_estimate_request_cost_known_model(self, tracker: CostTracker) -> None:
"""Test estimating cost for known model."""
cost = asyncio.run(
tracker.estimate_request_cost(
model="claude-opus-4",
prompt_tokens=1000,
max_completion_tokens=500,
)
)
# 1000/1M * 15 + 500/1M * 75 = 0.015 + 0.0375 = 0.0525
assert cost == pytest.approx(0.0525, rel=0.01)
def test_estimate_request_cost_unknown_model(self, tracker: CostTracker) -> None:
"""Test estimating cost for unknown model."""
cost = asyncio.run(
tracker.estimate_request_cost(
model="unknown-model",
prompt_tokens=1000,
max_completion_tokens=500,
)
)
# Uses fallback estimate
assert cost > 0
def test_should_alert_below_threshold(self, tracker: CostTracker) -> None:
"""Test alert check when below threshold."""
asyncio.run(
tracker.record_usage(
project_id="proj-123",
agent_id="agent-456",
model="claude-opus-4",
prompt_tokens=1000,
completion_tokens=500,
cost_usd=50.0,
)
)
should_alert, current = asyncio.run(
tracker.should_alert("proj-123", threshold=100.0)
)
assert should_alert is False
assert current == pytest.approx(50.0, rel=0.01)
def test_should_alert_above_threshold(self, tracker: CostTracker) -> None:
"""Test alert check when above threshold."""
asyncio.run(
tracker.record_usage(
project_id="proj-123",
agent_id="agent-456",
model="claude-opus-4",
prompt_tokens=1000,
completion_tokens=500,
cost_usd=150.0,
)
)
should_alert, current = asyncio.run(
tracker.should_alert("proj-123", threshold=100.0)
)
assert should_alert is True
def test_close(self, tracker: CostTracker) -> None:
"""Test closing tracker."""
asyncio.run(tracker.close())
assert tracker._redis is None
class TestGlobalTracker:
"""Tests for global tracker functions."""
def test_get_cost_tracker(self) -> None:
"""Test getting global tracker."""
reset_cost_tracker()
tracker = get_cost_tracker()
assert isinstance(tracker, CostTracker)
def test_get_cost_tracker_singleton(self) -> None:
"""Test tracker is singleton."""
reset_cost_tracker()
tracker1 = get_cost_tracker()
tracker2 = get_cost_tracker()
assert tracker1 is tracker2
def test_reset_cost_tracker(self) -> None:
"""Test resetting global tracker."""
reset_cost_tracker()
tracker1 = get_cost_tracker()
reset_cost_tracker()
tracker2 = get_cost_tracker()
assert tracker1 is not tracker2
def test_close_cost_tracker(self) -> None:
"""Test closing global tracker."""
reset_cost_tracker()
_ = get_cost_tracker()
asyncio.run(close_cost_tracker())
# Getting again should create a new one
tracker2 = get_cost_tracker()
assert tracker2 is not None

View File

@@ -0,0 +1,377 @@
"""
Tests for exceptions module.
"""
from exceptions import (
AllProvidersFailedError,
CircuitOpenError,
ConfigurationError,
ContextTooLongError,
CostLimitExceededError,
ErrorCode,
InvalidModelError,
InvalidModelGroupError,
LLMGatewayError,
ModelNotAvailableError,
ProviderError,
RateLimitError,
StreamError,
TokenLimitExceededError,
)
class TestErrorCode:
"""Tests for ErrorCode enum."""
def test_error_code_values(self) -> None:
"""Test error code values."""
assert ErrorCode.UNKNOWN_ERROR.value == "LLM_UNKNOWN_ERROR"
assert ErrorCode.PROVIDER_ERROR.value == "LLM_PROVIDER_ERROR"
assert ErrorCode.CIRCUIT_OPEN.value == "LLM_CIRCUIT_OPEN"
assert ErrorCode.COST_LIMIT_EXCEEDED.value == "LLM_COST_LIMIT_EXCEEDED"
class TestLLMGatewayError:
"""Tests for LLMGatewayError base class."""
def test_basic_error(self) -> None:
"""Test basic error creation."""
error = LLMGatewayError("Something went wrong")
assert str(error) == "[LLM_UNKNOWN_ERROR] Something went wrong"
assert error.message == "Something went wrong"
assert error.code == ErrorCode.UNKNOWN_ERROR
assert error.details == {}
assert error.cause is None
def test_error_with_code(self) -> None:
"""Test error with custom code."""
error = LLMGatewayError(
"Provider failed",
code=ErrorCode.PROVIDER_ERROR,
)
assert error.code == ErrorCode.PROVIDER_ERROR
def test_error_with_details(self) -> None:
"""Test error with details."""
error = LLMGatewayError(
"Error",
details={"key": "value"},
)
assert error.details == {"key": "value"}
def test_error_with_cause(self) -> None:
"""Test error with cause exception."""
cause = ValueError("Original error")
error = LLMGatewayError("Wrapped error", cause=cause)
assert error.cause is cause
def test_to_dict(self) -> None:
"""Test converting error to dict."""
error = LLMGatewayError(
"Test error",
code=ErrorCode.INVALID_REQUEST,
details={"field": "value"},
)
result = error.to_dict()
assert result["error"] == "LLM_INVALID_REQUEST"
assert result["message"] == "Test error"
assert result["details"] == {"field": "value"}
def test_to_dict_no_details(self) -> None:
"""Test to_dict without details."""
error = LLMGatewayError("Test error")
result = error.to_dict()
assert "details" not in result
def test_repr(self) -> None:
"""Test error repr."""
error = LLMGatewayError("Test", details={"key": "val"})
repr_str = repr(error)
assert "LLMGatewayError" in repr_str
assert "Test" in repr_str
class TestProviderError:
"""Tests for ProviderError."""
def test_basic_provider_error(self) -> None:
"""Test basic provider error."""
error = ProviderError(
message="API call failed",
provider="anthropic",
)
assert error.provider == "anthropic"
assert error.model is None
assert error.status_code is None
assert error.code == ErrorCode.PROVIDER_ERROR
assert "provider" in error.details
def test_provider_error_with_model(self) -> None:
"""Test provider error with model info."""
error = ProviderError(
message="Model not found",
provider="openai",
model="gpt-5",
status_code=404,
)
assert error.model == "gpt-5"
assert error.status_code == 404
assert error.details["model"] == "gpt-5"
assert error.details["status_code"] == 404
def test_provider_error_with_cause(self) -> None:
"""Test provider error with cause."""
cause = ConnectionError("Network down")
error = ProviderError(
message="Connection failed",
provider="google",
cause=cause,
)
assert error.cause is cause
class TestRateLimitError:
"""Tests for RateLimitError."""
def test_internal_rate_limit(self) -> None:
"""Test internal rate limit error."""
error = RateLimitError(
message="Too many requests",
retry_after=60,
)
assert error.code == ErrorCode.RATE_LIMIT_EXCEEDED
assert error.provider is None
assert error.retry_after == 60
assert error.details["retry_after_seconds"] == 60
def test_provider_rate_limit(self) -> None:
"""Test provider rate limit error."""
error = RateLimitError(
message="OpenAI rate limit",
provider="openai",
retry_after=30,
)
assert error.code == ErrorCode.PROVIDER_RATE_LIMIT
assert error.provider == "openai"
assert error.details["provider"] == "openai"
class TestCircuitOpenError:
"""Tests for CircuitOpenError."""
def test_circuit_open_error(self) -> None:
"""Test circuit open error."""
error = CircuitOpenError(
provider="anthropic",
recovery_time=45,
)
assert error.provider == "anthropic"
assert error.recovery_time == 45
assert error.code == ErrorCode.CIRCUIT_OPEN
assert "Circuit breaker open" in error.message
assert error.details["recovery_time_seconds"] == 45
def test_circuit_open_no_recovery_time(self) -> None:
"""Test circuit open without recovery time."""
error = CircuitOpenError(provider="openai")
assert error.recovery_time is None
assert "recovery_time_seconds" not in error.details
class TestCostLimitExceededError:
"""Tests for CostLimitExceededError."""
def test_project_cost_limit(self) -> None:
"""Test project cost limit error."""
error = CostLimitExceededError(
entity_type="project",
entity_id="proj-123",
current_cost=150.0,
limit=100.0,
)
assert error.entity_type == "project"
assert error.entity_id == "proj-123"
assert error.current_cost == 150.0
assert error.limit == 100.0
assert error.code == ErrorCode.COST_LIMIT_EXCEEDED
assert "$150.00" in error.message
assert "$100.00" in error.message
def test_agent_cost_limit(self) -> None:
"""Test agent cost limit error."""
error = CostLimitExceededError(
entity_type="agent",
entity_id="agent-456",
current_cost=50.0,
limit=25.0,
)
assert error.entity_type == "agent"
assert error.details["entity_type"] == "agent"
class TestInvalidModelGroupError:
"""Tests for InvalidModelGroupError."""
def test_invalid_group_error(self) -> None:
"""Test invalid model group error."""
error = InvalidModelGroupError(
model_group="invalid_group",
available_groups=["reasoning", "code", "fast"],
)
assert error.model_group == "invalid_group"
assert error.available_groups == ["reasoning", "code", "fast"]
assert error.code == ErrorCode.INVALID_MODEL_GROUP
assert "invalid_group" in error.message
def test_invalid_group_no_available(self) -> None:
"""Test invalid group without available list."""
error = InvalidModelGroupError(model_group="unknown")
assert error.available_groups is None
assert "available_groups" not in error.details
class TestInvalidModelError:
"""Tests for InvalidModelError."""
def test_invalid_model_error(self) -> None:
"""Test invalid model error."""
error = InvalidModelError(
model="gpt-99",
reason="Model does not exist",
)
assert error.model == "gpt-99"
assert error.code == ErrorCode.INVALID_MODEL
assert "gpt-99" in error.message
assert "Model does not exist" in error.message
def test_invalid_model_no_reason(self) -> None:
"""Test invalid model without reason."""
error = InvalidModelError(model="unknown-model")
assert "reason" not in error.details
class TestModelNotAvailableError:
"""Tests for ModelNotAvailableError."""
def test_model_not_available(self) -> None:
"""Test model not available error."""
error = ModelNotAvailableError(
model="claude-opus-4",
provider="anthropic",
)
assert error.model == "claude-opus-4"
assert error.provider == "anthropic"
assert error.code == ErrorCode.MODEL_NOT_AVAILABLE
assert "not configured" in error.message
class TestAllProvidersFailedError:
"""Tests for AllProvidersFailedError."""
def test_all_providers_failed(self) -> None:
"""Test all providers failed error."""
errors = [
{"model": "claude-opus-4", "error": "Rate limited"},
{"model": "gpt-4.1", "error": "Timeout"},
]
error = AllProvidersFailedError(
model_group="reasoning",
attempted_models=["claude-opus-4", "gpt-4.1"],
errors=errors,
)
assert error.model_group == "reasoning"
assert error.attempted_models == ["claude-opus-4", "gpt-4.1"]
assert error.errors == errors
assert error.code == ErrorCode.ALL_PROVIDERS_FAILED
class TestStreamError:
"""Tests for StreamError."""
def test_stream_error(self) -> None:
"""Test stream error."""
cause = OSError("Connection reset")
error = StreamError(
message="Stream interrupted",
chunks_received=10,
cause=cause,
)
assert error.chunks_received == 10
assert error.cause is cause
assert error.code == ErrorCode.STREAM_ERROR
class TestTokenLimitExceededError:
"""Tests for TokenLimitExceededError."""
def test_token_limit_exceeded(self) -> None:
"""Test token limit exceeded error."""
error = TokenLimitExceededError(
model="claude-haiku",
token_count=10000,
limit=8192,
)
assert error.model == "claude-haiku"
assert error.token_count == 10000
assert error.limit == 8192
assert error.code == ErrorCode.TOKEN_LIMIT_EXCEEDED
class TestContextTooLongError:
"""Tests for ContextTooLongError."""
def test_context_too_long(self) -> None:
"""Test context too long error."""
error = ContextTooLongError(
model="gpt-4.1-mini",
context_length=150000,
max_context=100000,
)
assert error.model == "gpt-4.1-mini"
assert error.context_length == 150000
assert error.max_context == 100000
assert error.code == ErrorCode.CONTEXT_TOO_LONG
class TestConfigurationError:
"""Tests for ConfigurationError."""
def test_configuration_error(self) -> None:
"""Test configuration error."""
error = ConfigurationError(
message="Missing API key",
config_key="ANTHROPIC_API_KEY",
)
assert error.config_key == "ANTHROPIC_API_KEY"
assert error.code == ErrorCode.CONFIGURATION_ERROR
assert error.details["config_key"] == "ANTHROPIC_API_KEY"
def test_configuration_error_no_key(self) -> None:
"""Test configuration error without key."""
error = ConfigurationError(message="Invalid configuration")
assert error.config_key is None
assert "config_key" not in error.details

View File

@@ -0,0 +1,407 @@
"""
Tests for failover module (circuit breaker).
"""
import asyncio
import time
import pytest
from config import Settings
from exceptions import CircuitOpenError
from failover import (
CircuitBreaker,
CircuitBreakerRegistry,
CircuitState,
CircuitStats,
get_circuit_registry,
reset_circuit_registry,
)
class TestCircuitState:
"""Tests for CircuitState enum."""
def test_circuit_states(self) -> None:
"""Test circuit state values."""
assert CircuitState.CLOSED.value == "closed"
assert CircuitState.OPEN.value == "open"
assert CircuitState.HALF_OPEN.value == "half_open"
class TestCircuitStats:
"""Tests for CircuitStats dataclass."""
def test_default_stats(self) -> None:
"""Test default stats values."""
stats = CircuitStats()
assert stats.failures == 0
assert stats.successes == 0
assert stats.last_failure_time is None
assert stats.last_success_time is None
assert stats.half_open_calls == 0
class TestCircuitBreaker:
"""Tests for CircuitBreaker class."""
def test_initial_state(self) -> None:
"""Test circuit breaker initial state."""
cb = CircuitBreaker(name="test", failure_threshold=5)
assert cb.name == "test"
assert cb.state == CircuitState.CLOSED
assert cb.failure_threshold == 5
assert cb.is_available() is True
def test_state_remains_closed_below_threshold(self) -> None:
"""Test circuit stays closed below failure threshold."""
cb = CircuitBreaker(name="test", failure_threshold=3)
# Record 2 failures (below threshold)
asyncio.run(cb.record_failure())
asyncio.run(cb.record_failure())
assert cb.state == CircuitState.CLOSED
assert cb.stats.failures == 2
assert cb.is_available() is True
def test_state_opens_at_threshold(self) -> None:
"""Test circuit opens at failure threshold."""
cb = CircuitBreaker(name="test", failure_threshold=3)
# Record 3 failures (at threshold)
asyncio.run(cb.record_failure())
asyncio.run(cb.record_failure())
asyncio.run(cb.record_failure())
assert cb.state == CircuitState.OPEN
assert cb.is_available() is False
def test_success_resets_in_closed(self) -> None:
"""Test success in closed state records properly."""
cb = CircuitBreaker(name="test", failure_threshold=3)
asyncio.run(cb.record_failure())
asyncio.run(cb.record_success())
assert cb.state == CircuitState.CLOSED
assert cb.stats.successes == 1
assert cb.stats.last_success_time is not None
def test_half_open_transition(self) -> None:
"""Test transition to half-open after recovery timeout."""
cb = CircuitBreaker(
name="test",
failure_threshold=1,
recovery_timeout=1, # 1 second
)
# Open the circuit
asyncio.run(cb.record_failure())
assert cb.state == CircuitState.OPEN
# Wait for recovery timeout
time.sleep(1.1)
# State should transition to half-open
assert cb.state == CircuitState.HALF_OPEN
assert cb.is_available() is True
def test_half_open_success_closes(self) -> None:
"""Test success in half-open closes circuit."""
cb = CircuitBreaker(
name="test",
failure_threshold=1,
recovery_timeout=0, # Immediate recovery for testing
)
# Open and transition to half-open
asyncio.run(cb.record_failure())
time.sleep(0.1)
_ = cb.state # Trigger state check
assert cb.state == CircuitState.HALF_OPEN
# Success should close
asyncio.run(cb.record_success())
assert cb.state == CircuitState.CLOSED
def test_half_open_failure_reopens(self) -> None:
"""Test failure in half-open reopens circuit."""
cb = CircuitBreaker(
name="test",
failure_threshold=1,
recovery_timeout=0.05, # Small but non-zero for reliable timing
)
# Open and transition to half-open
asyncio.run(cb.record_failure())
assert cb.state == CircuitState.OPEN
# Wait for recovery timeout
time.sleep(0.1)
assert cb.state == CircuitState.HALF_OPEN
# Failure should reopen
asyncio.run(cb.record_failure())
assert cb.state == CircuitState.OPEN
def test_half_open_call_limit(self) -> None:
"""Test half-open call limit."""
cb = CircuitBreaker(
name="test",
failure_threshold=1,
recovery_timeout=0,
half_open_max_calls=2,
)
# Open and transition to half-open
asyncio.run(cb.record_failure())
time.sleep(0.1)
_ = cb.state
assert cb.is_available() is True
# Simulate calls in half-open
cb._stats.half_open_calls = 1
assert cb.is_available() is True
cb._stats.half_open_calls = 2
assert cb.is_available() is False
def test_time_until_recovery(self) -> None:
"""Test time until recovery calculation."""
cb = CircuitBreaker(
name="test",
failure_threshold=1,
recovery_timeout=60,
)
# Closed circuit has no recovery time
assert cb.time_until_recovery() is None
# Open circuit
asyncio.run(cb.record_failure())
assert cb.state == CircuitState.OPEN
# Should have recovery time
remaining = cb.time_until_recovery()
assert remaining is not None
assert 0 <= remaining <= 60
def test_execute_success(self) -> None:
"""Test execute with successful function."""
cb = CircuitBreaker(name="test", failure_threshold=3)
async def success_func() -> str:
return "success"
result = asyncio.run(cb.execute(success_func))
assert result == "success"
assert cb.stats.successes == 1
def test_execute_failure(self) -> None:
"""Test execute with failing function."""
cb = CircuitBreaker(name="test", failure_threshold=3)
async def fail_func() -> None:
raise ValueError("Error")
with pytest.raises(ValueError):
asyncio.run(cb.execute(fail_func))
assert cb.stats.failures == 1
def test_execute_when_open(self) -> None:
"""Test execute raises when circuit is open."""
cb = CircuitBreaker(name="test", failure_threshold=1)
# Open the circuit
asyncio.run(cb.record_failure())
assert cb.state == CircuitState.OPEN
async def success_func() -> str:
return "success"
with pytest.raises(CircuitOpenError) as exc_info:
asyncio.run(cb.execute(success_func))
assert exc_info.value.provider == "test"
def test_reset(self) -> None:
"""Test circuit reset."""
cb = CircuitBreaker(name="test", failure_threshold=1)
# Open the circuit
asyncio.run(cb.record_failure())
assert cb.state == CircuitState.OPEN
# Reset
cb.reset()
assert cb.state == CircuitState.CLOSED
assert cb.stats.failures == 0
assert cb.stats.successes == 0
def test_to_dict(self) -> None:
"""Test converting circuit to dict."""
cb = CircuitBreaker(name="test", failure_threshold=3)
asyncio.run(cb.record_failure())
asyncio.run(cb.record_success())
result = cb.to_dict()
assert result["name"] == "test"
assert result["state"] == "closed"
assert result["failures"] == 1
assert result["successes"] == 1
assert result["is_available"] is True
class TestCircuitBreakerRegistry:
"""Tests for CircuitBreakerRegistry class."""
def test_get_circuit_creates_new(self) -> None:
"""Test getting a new circuit."""
settings = Settings(circuit_failure_threshold=5)
registry = CircuitBreakerRegistry(settings=settings)
circuit = asyncio.run(registry.get_circuit("anthropic"))
assert circuit.name == "anthropic"
assert circuit.failure_threshold == 5
def test_get_circuit_returns_same(self) -> None:
"""Test getting same circuit twice."""
registry = CircuitBreakerRegistry()
circuit1 = asyncio.run(registry.get_circuit("openai"))
circuit2 = asyncio.run(registry.get_circuit("openai"))
assert circuit1 is circuit2
def test_get_circuit_sync(self) -> None:
"""Test sync circuit getter."""
registry = CircuitBreakerRegistry()
circuit = registry.get_circuit_sync("google")
assert circuit.name == "google"
def test_is_available(self) -> None:
"""Test checking if circuit is available."""
registry = CircuitBreakerRegistry()
assert asyncio.run(registry.is_available("test")) is True
# Open the circuit
circuit = asyncio.run(registry.get_circuit("test"))
for _ in range(5):
asyncio.run(circuit.record_failure())
assert asyncio.run(registry.is_available("test")) is False
def test_record_success(self) -> None:
"""Test recording success through registry."""
registry = CircuitBreakerRegistry()
asyncio.run(registry.record_success("test"))
circuit = asyncio.run(registry.get_circuit("test"))
assert circuit.stats.successes == 1
def test_record_failure(self) -> None:
"""Test recording failure through registry."""
registry = CircuitBreakerRegistry()
asyncio.run(registry.record_failure("test"))
circuit = asyncio.run(registry.get_circuit("test"))
assert circuit.stats.failures == 1
def test_reset(self) -> None:
"""Test resetting a specific circuit."""
registry = CircuitBreakerRegistry()
# Create and fail a circuit
asyncio.run(registry.record_failure("test"))
asyncio.run(registry.reset("test"))
circuit = asyncio.run(registry.get_circuit("test"))
assert circuit.stats.failures == 0
def test_reset_all(self) -> None:
"""Test resetting all circuits."""
registry = CircuitBreakerRegistry()
# Create multiple circuits with failures
asyncio.run(registry.record_failure("circuit1"))
asyncio.run(registry.record_failure("circuit2"))
asyncio.run(registry.reset_all())
circuit1 = asyncio.run(registry.get_circuit("circuit1"))
circuit2 = asyncio.run(registry.get_circuit("circuit2"))
assert circuit1.stats.failures == 0
assert circuit2.stats.failures == 0
def test_get_all_states(self) -> None:
"""Test getting all circuit states."""
registry = CircuitBreakerRegistry()
asyncio.run(registry.get_circuit("circuit1"))
asyncio.run(registry.get_circuit("circuit2"))
states = registry.get_all_states()
assert "circuit1" in states
assert "circuit2" in states
assert states["circuit1"]["state"] == "closed"
def test_get_open_circuits(self) -> None:
"""Test getting open circuits."""
settings = Settings(circuit_failure_threshold=1)
registry = CircuitBreakerRegistry(settings=settings)
asyncio.run(registry.get_circuit("healthy"))
asyncio.run(registry.record_failure("failing"))
open_circuits = registry.get_open_circuits()
assert "failing" in open_circuits
assert "healthy" not in open_circuits
def test_get_available_circuits(self) -> None:
"""Test getting available circuits."""
settings = Settings(circuit_failure_threshold=1)
registry = CircuitBreakerRegistry(settings=settings)
asyncio.run(registry.get_circuit("healthy"))
asyncio.run(registry.record_failure("failing"))
available = registry.get_available_circuits()
assert "healthy" in available
assert "failing" not in available
class TestGlobalRegistry:
"""Tests for global registry functions."""
def test_get_circuit_registry(self) -> None:
"""Test getting global registry."""
reset_circuit_registry()
registry = get_circuit_registry()
assert isinstance(registry, CircuitBreakerRegistry)
def test_get_circuit_registry_singleton(self) -> None:
"""Test registry is singleton."""
reset_circuit_registry()
registry1 = get_circuit_registry()
registry2 = get_circuit_registry()
assert registry1 is registry2
def test_reset_circuit_registry(self) -> None:
"""Test resetting global registry."""
reset_circuit_registry()
registry1 = get_circuit_registry()
reset_circuit_registry()
registry2 = get_circuit_registry()
assert registry1 is not registry2
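# Illustrative walk-through (guarded so pytest ignores it) of the lifecycle the
# tests above assert piece by piece: failures up to the threshold open the
# circuit, the recovery timeout moves it to half-open, and one success closes it
# again. Only constructor arguments and methods exercised above are used; the
# 0.05s timeout is chosen purely to keep the demo fast.
if __name__ == "__main__":
    async def _demo() -> None:
        cb = CircuitBreaker(name="demo", failure_threshold=2, recovery_timeout=0.05)
        await cb.record_failure()
        await cb.record_failure()      # reaches the threshold -> OPEN
        print(cb.state, cb.is_available())
        await asyncio.sleep(0.1)       # wait out the recovery timeout
        print(cb.state)                # next state check reports HALF_OPEN
        await cb.record_success()      # trial call succeeded -> CLOSED
        print(cb.state)
    asyncio.run(_demo())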

View File

@@ -0,0 +1,408 @@
"""
Tests for models module.
"""
from datetime import UTC, datetime
import pytest
from models import (
AGENT_TYPE_MODEL_PREFERENCES,
MODEL_CONFIGS,
MODEL_GROUPS,
ChatMessage,
CompletionRequest,
CompletionResponse,
CostRecord,
EmbeddingRequest,
ModelConfig,
ModelGroup,
ModelGroupConfig,
ModelGroupInfo,
ModelInfo,
Provider,
StreamChunk,
UsageReport,
UsageStats,
)
class TestModelGroup:
"""Tests for ModelGroup enum."""
def test_model_group_values(self) -> None:
"""Test model group enum values."""
assert ModelGroup.REASONING.value == "reasoning"
assert ModelGroup.CODE.value == "code"
assert ModelGroup.FAST.value == "fast"
assert ModelGroup.VISION.value == "vision"
assert ModelGroup.EMBEDDING.value == "embedding"
assert ModelGroup.COST_OPTIMIZED.value == "cost_optimized"
assert ModelGroup.SELF_HOSTED.value == "self_hosted"
def test_model_group_from_string(self) -> None:
"""Test creating ModelGroup from string."""
assert ModelGroup("reasoning") == ModelGroup.REASONING
assert ModelGroup("code") == ModelGroup.CODE
assert ModelGroup("fast") == ModelGroup.FAST
def test_model_group_invalid(self) -> None:
"""Test invalid model group value."""
with pytest.raises(ValueError):
ModelGroup("invalid_group")
class TestProvider:
"""Tests for Provider enum."""
def test_provider_values(self) -> None:
"""Test provider enum values."""
assert Provider.ANTHROPIC.value == "anthropic"
assert Provider.OPENAI.value == "openai"
assert Provider.GOOGLE.value == "google"
assert Provider.ALIBABA.value == "alibaba"
assert Provider.DEEPSEEK.value == "deepseek"
class TestModelConfig:
"""Tests for ModelConfig dataclass."""
def test_model_config_creation(self) -> None:
"""Test creating a ModelConfig."""
config = ModelConfig(
name="test-model",
litellm_name="provider/test-model",
provider=Provider.ANTHROPIC,
cost_per_1m_input=10.0,
cost_per_1m_output=30.0,
context_window=100000,
max_output_tokens=4096,
supports_vision=True,
)
assert config.name == "test-model"
assert config.provider == Provider.ANTHROPIC
assert config.cost_per_1m_input == 10.0
assert config.supports_vision is True
assert config.supports_streaming is True # default
def test_model_configs_exist(self) -> None:
"""Test that model configs are defined."""
assert len(MODEL_CONFIGS) > 0
assert "claude-opus-4" in MODEL_CONFIGS
assert "gpt-4.1" in MODEL_CONFIGS
assert "gemini-2.5-pro" in MODEL_CONFIGS
class TestModelGroupConfig:
"""Tests for ModelGroupConfig dataclass."""
def test_model_group_config_creation(self) -> None:
"""Test creating a ModelGroupConfig."""
config = ModelGroupConfig(
primary="model-a",
fallbacks=["model-b", "model-c"],
description="Test group",
)
assert config.primary == "model-a"
assert config.fallbacks == ["model-b", "model-c"]
assert config.description == "Test group"
def test_get_all_models(self) -> None:
"""Test getting all models in order."""
config = ModelGroupConfig(
primary="model-a",
fallbacks=["model-b", "model-c"],
description="Test group",
)
models = config.get_all_models()
assert models == ["model-a", "model-b", "model-c"]
def test_model_groups_exist(self) -> None:
"""Test that model groups are defined."""
assert len(MODEL_GROUPS) > 0
assert ModelGroup.REASONING in MODEL_GROUPS
assert ModelGroup.CODE in MODEL_GROUPS
assert ModelGroup.FAST in MODEL_GROUPS
class TestAgentTypePreferences:
"""Tests for agent type model preferences."""
def test_agent_preferences_exist(self) -> None:
"""Test that agent preferences are defined."""
assert len(AGENT_TYPE_MODEL_PREFERENCES) > 0
assert "product_owner" in AGENT_TYPE_MODEL_PREFERENCES
assert "software_engineer" in AGENT_TYPE_MODEL_PREFERENCES
def test_agent_preference_values(self) -> None:
"""Test agent preference values."""
assert AGENT_TYPE_MODEL_PREFERENCES["product_owner"] == ModelGroup.REASONING
assert AGENT_TYPE_MODEL_PREFERENCES["software_engineer"] == ModelGroup.CODE
assert AGENT_TYPE_MODEL_PREFERENCES["devops_engineer"] == ModelGroup.FAST
class TestChatMessage:
"""Tests for ChatMessage model."""
def test_chat_message_creation(self) -> None:
"""Test creating a ChatMessage."""
msg = ChatMessage(role="user", content="Hello")
assert msg.role == "user"
assert msg.content == "Hello"
assert msg.name is None
assert msg.tool_calls is None
def test_chat_message_with_optional(self) -> None:
"""Test ChatMessage with optional fields."""
msg = ChatMessage(
role="assistant",
content="Response",
name="assistant_1",
tool_calls=[{"id": "call_1", "function": {"name": "test"}}],
)
assert msg.name == "assistant_1"
assert msg.tool_calls is not None
def test_chat_message_list_content(self) -> None:
"""Test ChatMessage with list content (for images)."""
msg = ChatMessage(
role="user",
content=[
{"type": "text", "text": "What's in this image?"},
{"type": "image_url", "image_url": {"url": "http://example.com/img.jpg"}},
],
)
assert isinstance(msg.content, list)
class TestCompletionRequest:
"""Tests for CompletionRequest model."""
def test_completion_request_minimal(self) -> None:
"""Test minimal CompletionRequest."""
req = CompletionRequest(
project_id="proj-123",
agent_id="agent-456",
messages=[ChatMessage(role="user", content="Hi")],
)
assert req.project_id == "proj-123"
assert req.agent_id == "agent-456"
assert len(req.messages) == 1
assert req.model_group == ModelGroup.REASONING # default
assert req.max_tokens == 4096 # default
assert req.temperature == 0.7 # default
def test_completion_request_full(self) -> None:
"""Test full CompletionRequest."""
req = CompletionRequest(
project_id="proj-123",
agent_id="agent-456",
messages=[ChatMessage(role="user", content="Hi")],
model_group=ModelGroup.CODE,
model_override="claude-sonnet-4",
max_tokens=8192,
temperature=0.5,
stream=True,
session_id="session-789",
metadata={"key": "value"},
)
assert req.model_group == ModelGroup.CODE
assert req.model_override == "claude-sonnet-4"
assert req.max_tokens == 8192
assert req.stream is True
def test_completion_request_validation(self) -> None:
"""Test CompletionRequest validation."""
with pytest.raises(ValueError):
CompletionRequest(
project_id="proj-123",
agent_id="agent-456",
messages=[ChatMessage(role="user", content="Hi")],
max_tokens=0, # Invalid
)
with pytest.raises(ValueError):
CompletionRequest(
project_id="proj-123",
agent_id="agent-456",
messages=[ChatMessage(role="user", content="Hi")],
temperature=-0.1, # Invalid
)
class TestUsageStats:
"""Tests for UsageStats model."""
def test_usage_stats_default(self) -> None:
"""Test default UsageStats."""
stats = UsageStats()
assert stats.prompt_tokens == 0
assert stats.completion_tokens == 0
assert stats.total_tokens == 0
assert stats.cost_usd == 0.0
def test_usage_stats_custom(self) -> None:
"""Test custom UsageStats."""
stats = UsageStats(
prompt_tokens=100,
completion_tokens=50,
total_tokens=150,
cost_usd=0.001,
)
assert stats.prompt_tokens == 100
assert stats.total_tokens == 150
def test_usage_stats_from_response(self) -> None:
"""Test creating UsageStats from response."""
config = MODEL_CONFIGS["claude-opus-4"]
stats = UsageStats.from_response(
prompt_tokens=1000,
completion_tokens=500,
model_config=config,
)
assert stats.prompt_tokens == 1000
assert stats.completion_tokens == 500
assert stats.total_tokens == 1500
# 1000/1M * 15 + 500/1M * 75 = 0.015 + 0.0375 = 0.0525
assert stats.cost_usd == pytest.approx(0.0525, rel=0.01)
class TestCompletionResponse:
"""Tests for CompletionResponse model."""
def test_completion_response_creation(self) -> None:
"""Test creating a CompletionResponse."""
response = CompletionResponse(
id="resp-123",
model="claude-opus-4",
provider="anthropic",
content="Hello, world!",
)
assert response.id == "resp-123"
assert response.model == "claude-opus-4"
assert response.provider == "anthropic"
assert response.content == "Hello, world!"
assert response.finish_reason == "stop"
class TestStreamChunk:
"""Tests for StreamChunk model."""
def test_stream_chunk_creation(self) -> None:
"""Test creating a StreamChunk."""
chunk = StreamChunk(id="chunk-1", delta="Hello")
assert chunk.id == "chunk-1"
assert chunk.delta == "Hello"
assert chunk.finish_reason is None
def test_stream_chunk_final(self) -> None:
"""Test final StreamChunk."""
chunk = StreamChunk(
id="chunk-last",
delta="",
finish_reason="stop",
usage=UsageStats(prompt_tokens=10, completion_tokens=5, total_tokens=15),
)
assert chunk.finish_reason == "stop"
assert chunk.usage is not None
class TestEmbeddingRequest:
"""Tests for EmbeddingRequest model."""
def test_embedding_request_creation(self) -> None:
"""Test creating an EmbeddingRequest."""
req = EmbeddingRequest(
project_id="proj-123",
agent_id="agent-456",
texts=["Hello", "World"],
)
assert req.project_id == "proj-123"
assert len(req.texts) == 2
assert req.model == "text-embedding-3-large" # default
def test_embedding_request_validation(self) -> None:
"""Test EmbeddingRequest validation."""
with pytest.raises(ValueError):
EmbeddingRequest(
project_id="proj-123",
agent_id="agent-456",
texts=[], # Invalid - must have at least 1
)
class TestCostRecord:
"""Tests for CostRecord dataclass."""
def test_cost_record_creation(self) -> None:
"""Test creating a CostRecord."""
record = CostRecord(
project_id="proj-123",
agent_id="agent-456",
model="claude-opus-4",
prompt_tokens=100,
completion_tokens=50,
cost_usd=0.01,
)
assert record.project_id == "proj-123"
assert record.cost_usd == 0.01
assert record.timestamp is not None
class TestUsageReport:
"""Tests for UsageReport model."""
def test_usage_report_creation(self) -> None:
"""Test creating a UsageReport."""
now = datetime.now(UTC)
report = UsageReport(
entity_id="proj-123",
entity_type="project",
period="day",
period_start=now,
period_end=now,
)
assert report.entity_id == "proj-123"
assert report.entity_type == "project"
assert report.total_requests == 0
assert report.total_cost_usd == 0.0
class TestModelInfo:
"""Tests for ModelInfo model."""
def test_model_info_from_config(self) -> None:
"""Test creating ModelInfo from ModelConfig."""
config = MODEL_CONFIGS["claude-opus-4"]
info = ModelInfo.from_config(config, available=True)
assert info.name == "claude-opus-4"
assert info.provider == "anthropic"
assert info.available is True
assert info.supports_vision is True
class TestModelGroupInfo:
"""Tests for ModelGroupInfo model."""
def test_model_group_info_creation(self) -> None:
"""Test creating ModelGroupInfo."""
info = ModelGroupInfo(
name="reasoning",
description="Complex analysis",
primary_model="claude-opus-4",
fallback_models=["gpt-4.1"],
)
assert info.name == "reasoning"
assert len(info.fallback_models) == 1
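# Quick illustrative dump (guarded from pytest collection) of the data these
# tests pin down: the reasoning group's primary-then-fallback order and the cost
# UsageStats.from_response derives from a model's per-1M rates. The printed
# values depend on whatever MODEL_GROUPS / MODEL_CONFIGS ship in this revision.
if __name__ == "__main__":
    reasoning = MODEL_GROUPS[ModelGroup.REASONING]
    print("failover order:", reasoning.get_all_models())
    stats = UsageStats.from_response(
        prompt_tokens=1000,
        completion_tokens=500,
        model_config=MODEL_CONFIGS["claude-opus-4"],
    )
    print("estimated cost (USD):", stats.cost_usd)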

View File

@@ -0,0 +1,308 @@
"""
Tests for providers module.
"""
import os
from unittest.mock import patch
import pytest
from config import Settings
from models import MODEL_CONFIGS, ModelGroup, Provider
from providers import (
LLMProvider,
build_fallback_config,
build_model_list,
configure_litellm,
get_available_model_groups,
get_available_models,
get_provider,
reset_provider,
)
@pytest.fixture
def full_settings() -> Settings:
"""Settings with all providers configured."""
return Settings(
anthropic_api_key="test-anthropic-key",
openai_api_key="test-openai-key",
google_api_key="test-google-key",
alibaba_api_key="test-alibaba-key",
deepseek_api_key="test-deepseek-key",
litellm_timeout=60,
litellm_cache_enabled=False,
)
@pytest.fixture
def partial_settings() -> Settings:
"""Settings with only some providers configured."""
return Settings(
anthropic_api_key="test-anthropic-key",
openai_api_key=None,
google_api_key=None,
alibaba_api_key=None,
deepseek_api_key=None,
)
@pytest.fixture
def empty_settings() -> Settings:
"""Settings with no providers configured."""
return Settings(
anthropic_api_key=None,
openai_api_key=None,
google_api_key=None,
alibaba_api_key=None,
deepseek_api_key=None,
)
class TestConfigureLiteLLM:
"""Tests for configure_litellm function."""
def test_sets_api_keys(self, full_settings: Settings) -> None:
"""Test that API keys are set in environment."""
with patch.dict(os.environ, {}, clear=True):
configure_litellm(full_settings)
assert os.environ.get("ANTHROPIC_API_KEY") == "test-anthropic-key"
assert os.environ.get("OPENAI_API_KEY") == "test-openai-key"
assert os.environ.get("GEMINI_API_KEY") == "test-google-key"
def test_skips_none_keys(self, partial_settings: Settings) -> None:
"""Test that None keys are not set."""
with patch.dict(os.environ, {}, clear=True):
configure_litellm(partial_settings)
assert os.environ.get("ANTHROPIC_API_KEY") == "test-anthropic-key"
assert "OPENAI_API_KEY" not in os.environ
class TestBuildModelList:
"""Tests for build_model_list function."""
def test_build_with_all_providers(self, full_settings: Settings) -> None:
"""Test building model list with all providers."""
model_list = build_model_list(full_settings)
assert len(model_list) > 0
# Check structure
for entry in model_list:
assert "model_name" in entry
assert "litellm_params" in entry
assert "model" in entry["litellm_params"]
assert "timeout" in entry["litellm_params"]
def test_build_with_partial_providers(self, partial_settings: Settings) -> None:
"""Test building model list with partial providers."""
model_list = build_model_list(partial_settings)
# Should only include Anthropic models
providers = set()
for entry in model_list:
model_name = entry["model_name"]
config = MODEL_CONFIGS.get(model_name)
if config:
providers.add(config.provider)
assert Provider.ANTHROPIC in providers
assert Provider.OPENAI not in providers
def test_build_with_no_providers(self, empty_settings: Settings) -> None:
"""Test building model list with no providers."""
model_list = build_model_list(empty_settings)
assert len(model_list) == 0
def test_build_includes_timeout(self, full_settings: Settings) -> None:
"""Test that model entries include timeout."""
model_list = build_model_list(full_settings)
for entry in model_list:
assert entry["litellm_params"]["timeout"] == 60
class TestBuildFallbackConfig:
"""Tests for build_fallback_config function."""
def test_build_fallbacks_full(self, full_settings: Settings) -> None:
"""Test building fallback config with all providers."""
fallbacks = build_fallback_config(full_settings)
assert len(fallbacks) > 0
# Primary models should have fallbacks
for _primary, chain in fallbacks.items():
assert isinstance(chain, list)
assert len(chain) > 0
def test_build_fallbacks_partial(self, partial_settings: Settings) -> None:
"""Test building fallback config with partial providers."""
fallbacks = build_fallback_config(partial_settings)
        # With only Anthropic configured there may be no fallback chains at all
        # (a chain needs at least two available models); any chains that do
        # remain must contain only Anthropic models.
for primary, chain in fallbacks.items():
# All models in chain should be from Anthropic
for model in [primary] + chain:
config = MODEL_CONFIGS.get(model)
if config:
assert config.provider == Provider.ANTHROPIC
class TestGetAvailableModels:
"""Tests for get_available_models function."""
def test_get_available_full(self, full_settings: Settings) -> None:
"""Test getting available models with all providers."""
models = get_available_models(full_settings)
assert len(models) > 0
assert "claude-opus-4" in models
assert "gpt-4.1" in models
def test_get_available_partial(self, partial_settings: Settings) -> None:
"""Test getting available models with partial providers."""
models = get_available_models(partial_settings)
assert "claude-opus-4" in models
assert "gpt-4.1" not in models
def test_get_available_empty(self, empty_settings: Settings) -> None:
"""Test getting available models with no providers."""
models = get_available_models(empty_settings)
assert len(models) == 0
class TestGetAvailableModelGroups:
"""Tests for get_available_model_groups function."""
def test_get_groups_full(self, full_settings: Settings) -> None:
"""Test getting groups with all providers."""
groups = get_available_model_groups(full_settings)
assert len(groups) == len(ModelGroup)
assert ModelGroup.REASONING in groups
assert len(groups[ModelGroup.REASONING]) > 0
def test_get_groups_partial(self, partial_settings: Settings) -> None:
"""Test getting groups with partial providers."""
groups = get_available_model_groups(partial_settings)
# Only Anthropic models should be available
for _group, models in groups.items():
for model in models:
config = MODEL_CONFIGS.get(model)
if config:
assert config.provider == Provider.ANTHROPIC
class TestLLMProvider:
"""Tests for LLMProvider class."""
def test_initialization(self, full_settings: Settings) -> None:
"""Test provider initialization."""
provider = LLMProvider(settings=full_settings)
assert provider._initialized is False
assert provider._router is None
def test_initialize(self, full_settings: Settings) -> None:
"""Test provider initialize."""
with patch("providers.Router") as mock_router:
provider = LLMProvider(settings=full_settings)
provider.initialize()
assert provider._initialized is True
mock_router.assert_called_once()
def test_initialize_idempotent(self, full_settings: Settings) -> None:
"""Test that initialize is idempotent."""
with patch("providers.Router") as mock_router:
provider = LLMProvider(settings=full_settings)
provider.initialize()
provider.initialize()
# Should only be called once
assert mock_router.call_count == 1
def test_initialize_no_providers(self, empty_settings: Settings) -> None:
"""Test initialization with no providers."""
provider = LLMProvider(settings=empty_settings)
provider.initialize()
assert provider._initialized is True
assert provider._router is None
def test_router_property(self, full_settings: Settings) -> None:
"""Test router property triggers initialization."""
with patch("providers.Router"):
provider = LLMProvider(settings=full_settings)
_ = provider.router
assert provider._initialized is True
def test_is_available(self, full_settings: Settings) -> None:
"""Test is_available property."""
with patch("providers.Router"):
provider = LLMProvider(settings=full_settings)
assert provider.is_available is True
def test_is_not_available(self, empty_settings: Settings) -> None:
"""Test is_available when no providers."""
provider = LLMProvider(settings=empty_settings)
assert provider.is_available is False
def test_get_model_config(self, full_settings: Settings) -> None:
"""Test getting model config."""
provider = LLMProvider(settings=full_settings)
config = provider.get_model_config("claude-opus-4")
assert config is not None
assert config.name == "claude-opus-4"
assert provider.get_model_config("nonexistent") is None
def test_get_available_models(self, full_settings: Settings) -> None:
"""Test getting available models."""
provider = LLMProvider(settings=full_settings)
models = provider.get_available_models()
assert "claude-opus-4" in models
assert "gpt-4.1" in models
def test_is_model_available(self, full_settings: Settings) -> None:
"""Test checking model availability."""
provider = LLMProvider(settings=full_settings)
assert provider.is_model_available("claude-opus-4") is True
assert provider.is_model_available("nonexistent") is False
class TestGlobalProvider:
"""Tests for global provider functions."""
def test_get_provider(self) -> None:
"""Test getting global provider."""
reset_provider()
provider = get_provider()
assert isinstance(provider, LLMProvider)
def test_get_provider_singleton(self) -> None:
"""Test provider is singleton."""
reset_provider()
provider1 = get_provider()
provider2 = get_provider()
assert provider1 is provider2
def test_reset_provider(self) -> None:
"""Test resetting global provider."""
reset_provider()
provider1 = get_provider()
reset_provider()
provider2 = get_provider()
assert provider1 is not provider2
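# Small sketch (kept out of pytest collection) of how provider availability is
# derived purely from which API keys are present in Settings, mirroring the
# partial_settings fixture above. The key value is a placeholder; with only an
# Anthropic key set, only Anthropic-backed entries should appear in the router's
# model list.
if __name__ == "__main__":
    demo_settings = Settings(
        anthropic_api_key="placeholder-key",
        openai_api_key=None,
        google_api_key=None,
        alibaba_api_key=None,
        deepseek_api_key=None,
    )
    print("available:", sorted(get_available_models(demo_settings)))
    print("router entries:", [e["model_name"] for e in build_model_list(demo_settings)])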

View File

@@ -0,0 +1,243 @@
"""
Tests for routing module.
"""
import asyncio
import pytest
from config import Settings
from exceptions import (
AllProvidersFailedError,
InvalidModelError,
InvalidModelGroupError,
ModelNotAvailableError,
)
from failover import CircuitBreakerRegistry, reset_circuit_registry
from models import ModelGroup
from providers import reset_provider
from routing import ModelRouter, get_model_router, reset_model_router
@pytest.fixture
def router_settings() -> Settings:
"""Settings for routing tests."""
return Settings(
anthropic_api_key="test-key",
openai_api_key="test-key",
google_api_key="test-key",
)
@pytest.fixture
def router(router_settings: Settings) -> ModelRouter:
"""Create model router for testing."""
reset_circuit_registry()
reset_model_router()
reset_provider()
registry = CircuitBreakerRegistry(settings=router_settings)
return ModelRouter(settings=router_settings, circuit_registry=registry)
class TestModelRouter:
"""Tests for ModelRouter class."""
def test_parse_model_group_valid(self, router: ModelRouter) -> None:
"""Test parsing valid model groups."""
assert router.parse_model_group("reasoning") == ModelGroup.REASONING
assert router.parse_model_group("code") == ModelGroup.CODE
assert router.parse_model_group("fast") == ModelGroup.FAST
assert router.parse_model_group("REASONING") == ModelGroup.REASONING
def test_parse_model_group_aliases(self, router: ModelRouter) -> None:
"""Test parsing model group aliases."""
assert router.parse_model_group("high-reasoning") == ModelGroup.REASONING
assert router.parse_model_group("high_reasoning") == ModelGroup.REASONING
assert router.parse_model_group("code-generation") == ModelGroup.CODE
assert router.parse_model_group("fast-response") == ModelGroup.FAST
def test_parse_model_group_invalid(self, router: ModelRouter) -> None:
"""Test parsing invalid model group."""
with pytest.raises(InvalidModelGroupError) as exc_info:
router.parse_model_group("invalid_group")
assert exc_info.value.model_group == "invalid_group"
assert exc_info.value.available_groups is not None
def test_get_model_config_valid(self, router: ModelRouter) -> None:
"""Test getting valid model config."""
config = router.get_model_config("claude-opus-4")
assert config.name == "claude-opus-4"
assert config.provider.value == "anthropic"
def test_get_model_config_invalid(self, router: ModelRouter) -> None:
"""Test getting invalid model config."""
with pytest.raises(InvalidModelError) as exc_info:
router.get_model_config("nonexistent-model")
assert exc_info.value.model == "nonexistent-model"
def test_get_preferred_group_for_agent(self, router: ModelRouter) -> None:
"""Test getting preferred group for agent types."""
assert router.get_preferred_group_for_agent("product_owner") == ModelGroup.REASONING
assert router.get_preferred_group_for_agent("software_engineer") == ModelGroup.CODE
assert router.get_preferred_group_for_agent("devops_engineer") == ModelGroup.FAST
def test_get_preferred_group_unknown_agent(self, router: ModelRouter) -> None:
"""Test getting preferred group for unknown agent."""
# Should default to REASONING
assert router.get_preferred_group_for_agent("unknown_type") == ModelGroup.REASONING
def test_select_model_by_group(self, router: ModelRouter) -> None:
"""Test selecting model by group."""
model_name, config = asyncio.run(
router.select_model(model_group=ModelGroup.REASONING)
)
assert model_name == "claude-opus-4"
assert config.provider.value == "anthropic"
def test_select_model_by_group_string(self, router: ModelRouter) -> None:
"""Test selecting model by group string."""
model_name, config = asyncio.run(
router.select_model(model_group="code")
)
assert model_name == "claude-sonnet-4"
def test_select_model_with_override(self, router: ModelRouter) -> None:
"""Test selecting specific model override."""
model_name, config = asyncio.run(
router.select_model(
model_group="reasoning",
model_override="gpt-4.1",
)
)
assert model_name == "gpt-4.1"
assert config.provider.value == "openai"
def test_select_model_override_invalid(self, router: ModelRouter) -> None:
"""Test selecting invalid model override."""
with pytest.raises(InvalidModelError):
asyncio.run(
router.select_model(
model_group="reasoning",
model_override="nonexistent-model",
)
)
def test_select_model_override_unavailable(self, router: ModelRouter) -> None: # noqa: ARG002
"""Test selecting unavailable model override."""
# Create router without Alibaba key
settings = Settings(
anthropic_api_key="test-key",
alibaba_api_key=None,
)
registry = CircuitBreakerRegistry(settings=settings)
limited_router = ModelRouter(settings=settings, circuit_registry=registry)
with pytest.raises(ModelNotAvailableError):
asyncio.run(
limited_router.select_model(
model_group="reasoning",
model_override="qwen-max",
)
)
def test_select_model_fallback_on_circuit_open(
self,
router: ModelRouter,
) -> None:
"""Test fallback when primary circuit is open."""
# Open circuit for anthropic
circuit = router._circuit_registry.get_circuit_sync("anthropic")
for _ in range(5):
asyncio.run(circuit.record_failure())
# Should fall back to OpenAI
model_name, config = asyncio.run(
router.select_model(model_group=ModelGroup.REASONING)
)
assert model_name == "gpt-4.1"
assert config.provider.value == "openai"
def test_select_model_all_unavailable(self) -> None:
"""Test when all providers are unavailable."""
settings = Settings(
anthropic_api_key=None,
openai_api_key=None,
google_api_key=None,
)
registry = CircuitBreakerRegistry(settings=settings)
limited_router = ModelRouter(settings=settings, circuit_registry=registry)
with pytest.raises(AllProvidersFailedError) as exc_info:
asyncio.run(
limited_router.select_model(model_group=ModelGroup.REASONING)
)
assert exc_info.value.model_group == "reasoning"
assert len(exc_info.value.attempted_models) > 0
def test_get_available_models_for_group(self, router: ModelRouter) -> None:
"""Test getting available models for a group."""
models = asyncio.run(
router.get_available_models_for_group(ModelGroup.REASONING)
)
assert len(models) > 0
# Should be (name, config, available) tuples
for name, config, _available in models:
assert isinstance(name, str)
assert config is not None
def test_get_available_models_for_group_string(self, router: ModelRouter) -> None:
"""Test getting available models with string group."""
models = asyncio.run(
router.get_available_models_for_group("code")
)
assert len(models) > 0
def test_get_available_models_invalid_group(self, router: ModelRouter) -> None:
"""Test getting models for invalid group."""
with pytest.raises(InvalidModelGroupError):
asyncio.run(
router.get_available_models_for_group("invalid")
)
def test_get_all_model_groups(self, router: ModelRouter) -> None:
"""Test getting all model groups info."""
groups = router.get_all_model_groups()
assert len(groups) == len(ModelGroup)
assert "reasoning" in groups
assert "code" in groups
assert groups["reasoning"]["primary"] == "claude-opus-4"
class TestGlobalRouter:
"""Tests for global router functions."""
def test_get_model_router(self) -> None:
"""Test getting global router."""
reset_model_router()
router = get_model_router()
assert isinstance(router, ModelRouter)
def test_get_model_router_singleton(self) -> None:
"""Test router is singleton."""
reset_model_router()
router1 = get_model_router()
router2 = get_model_router()
assert router1 is router2
def test_reset_model_router(self) -> None:
"""Test resetting global router."""
reset_model_router()
router1 = get_model_router()
reset_model_router()
router2 = get_model_router()
assert router1 is not router2
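# Illustrative end-to-end of the failover behaviour asserted above (guarded so
# pytest ignores it): with placeholder keys for Anthropic and OpenAI, the
# reasoning group resolves to claude-opus-4 until the "anthropic" circuit is
# forced open, after which select_model falls back to gpt-4.1.
if __name__ == "__main__":
    demo_settings = Settings(anthropic_api_key="placeholder", openai_api_key="placeholder")
    demo_registry = CircuitBreakerRegistry(settings=demo_settings)
    demo_router = ModelRouter(settings=demo_settings, circuit_registry=demo_registry)
    print(asyncio.run(demo_router.select_model(model_group=ModelGroup.REASONING))[0])
    circuit = demo_registry.get_circuit_sync("anthropic")
    for _ in range(5):  # exceed the default failure threshold to open the circuit
        asyncio.run(circuit.record_failure())
    print(asyncio.run(demo_router.select_model(model_group=ModelGroup.REASONING))[0])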

View File

@@ -0,0 +1,412 @@
"""
Tests for server module.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from config import Settings
from models import ModelGroup
@pytest.fixture
def test_settings() -> Settings:
"""Test settings with mock API keys."""
return Settings(
anthropic_api_key="test-anthropic-key",
openai_api_key="test-openai-key",
google_api_key="test-google-key",
cost_tracking_enabled=False, # Disable for most tests
litellm_cache_enabled=False,
)
@pytest.fixture
def test_client(test_settings: Settings) -> TestClient:
"""Create test client with mocked dependencies."""
with (
patch("server.get_settings", return_value=test_settings),
patch("server.get_provider") as mock_provider,
):
mock_provider.return_value = MagicMock()
mock_provider.return_value.is_available = True
mock_provider.return_value.router = MagicMock()
mock_provider.return_value.get_available_models.return_value = {}
from server import app
return TestClient(app)
class TestHealthEndpoint:
"""Tests for health check endpoint."""
def test_health_check(self, test_client: TestClient) -> None:
"""Test health check returns healthy status."""
with patch("server.get_settings") as mock_settings:
mock_settings.return_value = Settings(
anthropic_api_key="test-key",
)
with patch("server.get_provider") as mock_provider:
mock_provider.return_value = MagicMock()
mock_provider.return_value.is_available = True
response = test_client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "healthy"
assert data["service"] == "llm-gateway"
class TestToolDiscoveryEndpoint:
"""Tests for tool discovery endpoint."""
def test_list_tools(self, test_client: TestClient) -> None:
"""Test listing available tools."""
response = test_client.get("/mcp/tools")
assert response.status_code == 200
data = response.json()
assert "tools" in data
assert len(data["tools"]) == 4 # 4 tools defined
tool_names = [t["name"] for t in data["tools"]]
assert "chat_completion" in tool_names
assert "list_models" in tool_names
assert "get_usage" in tool_names
assert "count_tokens" in tool_names
def test_tool_has_schema(self, test_client: TestClient) -> None:
"""Test that tools have input schemas."""
response = test_client.get("/mcp/tools")
data = response.json()
for tool in data["tools"]:
assert "inputSchema" in tool
assert "type" in tool["inputSchema"]
assert tool["inputSchema"]["type"] == "object"
class TestJSONRPCEndpoint:
"""Tests for JSON-RPC endpoint."""
def test_invalid_jsonrpc_version(self, test_client: TestClient) -> None:
"""Test invalid JSON-RPC version."""
response = test_client.post(
"/mcp",
json={
"jsonrpc": "1.0", # Invalid
"method": "tools/list",
"id": 1,
},
)
assert response.status_code == 200
data = response.json()
assert "error" in data
assert data["error"]["code"] == -32600
def test_tools_list(self, test_client: TestClient) -> None:
"""Test tools/list method."""
response = test_client.post(
"/mcp",
json={
"jsonrpc": "2.0",
"method": "tools/list",
"id": 1,
},
)
assert response.status_code == 200
data = response.json()
assert "result" in data
assert "tools" in data["result"]
def test_unknown_method(self, test_client: TestClient) -> None:
"""Test unknown method."""
response = test_client.post(
"/mcp",
json={
"jsonrpc": "2.0",
"method": "unknown/method",
"id": 1,
},
)
assert response.status_code == 200
data = response.json()
assert "error" in data
assert data["error"]["code"] == -32601
def test_unknown_tool(self, test_client: TestClient) -> None:
"""Test unknown tool."""
response = test_client.post(
"/mcp",
json={
"jsonrpc": "2.0",
"method": "tools/call",
"params": {
"name": "unknown_tool",
"arguments": {},
},
"id": 1,
},
)
assert response.status_code == 200
data = response.json()
assert "error" in data
assert "Unknown tool" in data["error"]["message"]
class TestCountTokensTool:
"""Tests for count_tokens tool."""
def test_count_tokens(self, test_client: TestClient) -> None:
"""Test counting tokens."""
response = test_client.post(
"/mcp",
json={
"jsonrpc": "2.0",
"method": "tools/call",
"params": {
"name": "count_tokens",
"arguments": {
"project_id": "proj-123",
"agent_id": "agent-456",
"text": "Hello, world!",
},
},
"id": 1,
},
)
assert response.status_code == 200
data = response.json()
assert "result" in data
def test_count_tokens_with_model(self, test_client: TestClient) -> None:
"""Test counting tokens with specific model."""
response = test_client.post(
"/mcp",
json={
"jsonrpc": "2.0",
"method": "tools/call",
"params": {
"name": "count_tokens",
"arguments": {
"project_id": "proj-123",
"agent_id": "agent-456",
"text": "Hello, world!",
"model": "gpt-4",
},
},
"id": 1,
},
)
assert response.status_code == 200
class TestListModelsTool:
"""Tests for list_models tool."""
def test_list_all_models(self, test_client: TestClient) -> None:
"""Test listing all models."""
with patch("server.get_model_router") as mock_router:
mock_router.return_value = MagicMock()
mock_router.return_value.get_available_models_for_group = AsyncMock(
return_value=[]
)
response = test_client.post(
"/mcp",
json={
"jsonrpc": "2.0",
"method": "tools/call",
"params": {
"name": "list_models",
"arguments": {
"project_id": "proj-123",
"agent_id": "agent-456",
},
},
"id": 1,
},
)
assert response.status_code == 200
def test_list_models_by_group(self, test_client: TestClient) -> None:
"""Test listing models by group."""
with patch("server.get_model_router") as mock_router:
mock_router.return_value = MagicMock()
mock_router.return_value.parse_model_group.return_value = ModelGroup.REASONING
mock_router.return_value.get_available_models_for_group = AsyncMock(
return_value=[]
)
response = test_client.post(
"/mcp",
json={
"jsonrpc": "2.0",
"method": "tools/call",
"params": {
"name": "list_models",
"arguments": {
"project_id": "proj-123",
"agent_id": "agent-456",
"model_group": "reasoning",
},
},
"id": 1,
},
)
assert response.status_code == 200
class TestGetUsageTool:
"""Tests for get_usage tool."""
def test_get_usage(self, test_client: TestClient) -> None:
"""Test getting usage."""
with patch("server.get_cost_tracker") as mock_tracker:
mock_report = MagicMock()
mock_report.total_requests = 10
mock_report.total_tokens = 1000
mock_report.total_cost_usd = 0.50
mock_report.by_model = {}
mock_report.period_start.isoformat.return_value = "2024-01-01T00:00:00"
mock_report.period_end.isoformat.return_value = "2024-01-02T00:00:00"
mock_tracker.return_value = MagicMock()
mock_tracker.return_value.get_project_usage = AsyncMock(
return_value=mock_report
)
mock_tracker.return_value.get_agent_usage = AsyncMock(
return_value=mock_report
)
response = test_client.post(
"/mcp",
json={
"jsonrpc": "2.0",
"method": "tools/call",
"params": {
"name": "get_usage",
"arguments": {
"project_id": "proj-123",
"agent_id": "agent-456",
"period": "day",
},
},
"id": 1,
},
)
assert response.status_code == 200
class TestChatCompletionTool:
"""Tests for chat_completion tool."""
def test_chat_completion_streaming_not_supported(
self,
test_client: TestClient,
) -> None:
"""Test that streaming returns info message."""
with patch("server.get_model_router") as mock_router:
mock_router.return_value = MagicMock()
mock_router.return_value.select_model = AsyncMock(
return_value=("claude-opus-4", MagicMock())
)
with patch("server.get_cost_tracker") as mock_tracker:
mock_tracker.return_value = MagicMock()
mock_tracker.return_value.check_budget = AsyncMock(
return_value=(True, 0.0, 100.0)
)
response = test_client.post(
"/mcp",
json={
"jsonrpc": "2.0",
"method": "tools/call",
"params": {
"name": "chat_completion",
"arguments": {
"project_id": "proj-123",
"agent_id": "agent-456",
"messages": [
{"role": "user", "content": "Hello"}
],
"stream": True,
},
},
"id": 1,
},
)
assert response.status_code == 200
def test_chat_completion_success(self, test_client: TestClient) -> None:
"""Test successful chat completion."""
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Hello, world!"
mock_response.choices[0].finish_reason = "stop"
mock_response.usage = MagicMock()
mock_response.usage.prompt_tokens = 10
mock_response.usage.completion_tokens = 5
with patch("server.get_model_router") as mock_router:
mock_model_config = MagicMock()
mock_model_config.provider.value = "anthropic"
mock_router.return_value = MagicMock()
mock_router.return_value.select_model = AsyncMock(
return_value=("claude-opus-4", mock_model_config)
)
with patch("server.get_cost_tracker") as mock_tracker:
mock_tracker.return_value = MagicMock()
mock_tracker.return_value.check_budget = AsyncMock(
return_value=(True, 0.0, 100.0)
)
mock_tracker.return_value.record_usage = AsyncMock()
with patch("server.get_provider") as mock_prov:
mock_prov.return_value = MagicMock()
mock_prov.return_value.router = MagicMock()
mock_prov.return_value.router.acompletion = AsyncMock(
return_value=mock_response
)
with patch("server.get_circuit_registry") as mock_reg:
mock_circuit = MagicMock()
mock_circuit.record_success = AsyncMock()
mock_reg.return_value = MagicMock()
mock_reg.return_value.get_circuit_sync.return_value = mock_circuit
response = test_client.post(
"/mcp",
json={
"jsonrpc": "2.0",
"method": "tools/call",
"params": {
"name": "chat_completion",
"arguments": {
"project_id": "proj-123",
"agent_id": "agent-456",
"messages": [
{"role": "user", "content": "Hello"}
],
},
},
"id": 1,
},
)
assert response.status_code == 200
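# Sketch of the same JSON-RPC call an external client would make against a
# running gateway (guarded from pytest). The base URL and the availability of
# httpx are assumptions for a local deployment; the payload shape matches the
# tools/call requests used in the tests above.
if __name__ == "__main__":
    import httpx
    payload = {
        "jsonrpc": "2.0",
        "method": "tools/call",
        "params": {
            "name": "chat_completion",
            "arguments": {
                "project_id": "proj-123",
                "agent_id": "agent-456",
                "messages": [{"role": "user", "content": "Hello"}],
            },
        },
        "id": 1,
    }
    resp = httpx.post("http://localhost:8001/mcp", json=payload, timeout=30.0)
    print(resp.status_code, resp.json())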

View File

@@ -0,0 +1,312 @@
"""
Tests for streaming module.
"""
import asyncio
import json
from unittest.mock import MagicMock
import pytest
from models import StreamChunk, UsageStats
from streaming import (
StreamAccumulator,
StreamBuffer,
format_sse_chunk,
format_sse_done,
format_sse_error,
stream_to_string,
wrap_litellm_stream,
)
class TestStreamAccumulator:
"""Tests for StreamAccumulator class."""
def test_initial_state(self) -> None:
"""Test initial accumulator state."""
acc = StreamAccumulator()
assert acc.request_id is not None
assert acc.content == ""
assert acc.chunks_received == 0
assert acc.prompt_tokens == 0
assert acc.completion_tokens == 0
assert acc.model is None
assert acc.finish_reason is None
def test_custom_request_id(self) -> None:
"""Test accumulator with custom request ID."""
acc = StreamAccumulator(request_id="custom-id")
assert acc.request_id == "custom-id"
def test_add_chunk_text(self) -> None:
"""Test adding text chunks."""
acc = StreamAccumulator()
acc.add_chunk("Hello")
acc.add_chunk(", ")
acc.add_chunk("world!")
assert acc.content == "Hello, world!"
assert acc.chunks_received == 3
def test_add_chunk_with_finish_reason(self) -> None:
"""Test adding chunk with finish reason."""
acc = StreamAccumulator()
acc.add_chunk("Final", finish_reason="stop")
assert acc.finish_reason == "stop"
def test_add_chunk_with_model(self) -> None:
"""Test adding chunk with model info."""
acc = StreamAccumulator()
acc.add_chunk("Text", model="claude-opus-4")
assert acc.model == "claude-opus-4"
def test_add_chunk_with_usage(self) -> None:
"""Test adding chunk with usage stats."""
acc = StreamAccumulator()
acc.add_chunk(
"Text",
usage={"prompt_tokens": 10, "completion_tokens": 5},
)
assert acc.prompt_tokens == 10
assert acc.completion_tokens == 5
assert acc.total_tokens == 15
def test_start_and_finish(self) -> None:
"""Test start and finish timing."""
acc = StreamAccumulator()
assert acc.duration_seconds is None
acc.start()
acc.finish()
assert acc.duration_seconds is not None
assert acc.duration_seconds >= 0
def test_get_usage_stats(self) -> None:
"""Test getting usage stats."""
acc = StreamAccumulator()
acc.add_chunk("", usage={"prompt_tokens": 100, "completion_tokens": 50})
stats = acc.get_usage_stats(cost_usd=0.01)
assert stats.prompt_tokens == 100
assert stats.completion_tokens == 50
assert stats.total_tokens == 150
assert stats.cost_usd == 0.01
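# Contract exercised above: StreamAccumulator concatenates streamed deltas
# into .content, counts chunks and prompt/completion tokens, records the
# model name and finish_reason as they arrive, measures the duration between
# start() and finish(), and reports totals via get_usage_stats(cost_usd=...).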
class TestWrapLiteLLMStream:
"""Tests for wrap_litellm_stream function."""
async def test_wrap_stream_basic(self) -> None:
"""Test wrapping a basic stream."""
# Create mock stream chunks
async def mock_stream():
chunk1 = MagicMock()
chunk1.choices = [MagicMock()]
chunk1.choices[0].delta = MagicMock()
chunk1.choices[0].delta.content = "Hello"
chunk1.choices[0].finish_reason = None
chunk1.model = "test-model"
chunk1.usage = None
yield chunk1
chunk2 = MagicMock()
chunk2.choices = [MagicMock()]
chunk2.choices[0].delta = MagicMock()
chunk2.choices[0].delta.content = " World"
chunk2.choices[0].finish_reason = "stop"
chunk2.model = "test-model"
chunk2.usage = MagicMock()
chunk2.usage.prompt_tokens = 5
chunk2.usage.completion_tokens = 2
yield chunk2
accumulator = StreamAccumulator()
chunks = []
async for chunk in wrap_litellm_stream(mock_stream(), accumulator):
chunks.append(chunk)
assert len(chunks) == 2
assert chunks[0].delta == "Hello"
assert chunks[1].delta == " World"
assert chunks[1].finish_reason == "stop"
assert accumulator.content == "Hello World"
async def test_wrap_stream_without_accumulator(self) -> None:
"""Test wrapping stream without accumulator."""
async def mock_stream():
chunk = MagicMock()
chunk.choices = [MagicMock()]
chunk.choices[0].delta = MagicMock()
chunk.choices[0].delta.content = "Test"
chunk.choices[0].finish_reason = None
chunk.model = None
chunk.usage = None
yield chunk
chunks = []
async for chunk in wrap_litellm_stream(mock_stream()):
chunks.append(chunk)
assert len(chunks) == 1
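# Contract exercised above: wrap_litellm_stream() adapts raw LiteLLM chunks
# (choices[0].delta.content, finish_reason, and optional model/usage fields)
# into StreamChunk objects, feeding an optional StreamAccumulator along the
# way so content and token totals are available once the stream ends.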
class TestSSEFormatting:
"""Tests for SSE formatting functions."""
def test_format_sse_chunk_basic(self) -> None:
"""Test formatting basic chunk."""
chunk = StreamChunk(id="chunk-1", delta="Hello")
result = format_sse_chunk(chunk)
assert result.startswith("data: ")
assert result.endswith("\n\n")
# Parse the JSON
data = json.loads(result[6:-2])
assert data["id"] == "chunk-1"
assert data["delta"] == "Hello"
def test_format_sse_chunk_with_finish(self) -> None:
"""Test formatting chunk with finish reason."""
chunk = StreamChunk(
id="chunk-2",
delta="",
finish_reason="stop",
)
result = format_sse_chunk(chunk)
data = json.loads(result[6:-2])
assert data["finish_reason"] == "stop"
def test_format_sse_chunk_with_usage(self) -> None:
"""Test formatting chunk with usage."""
chunk = StreamChunk(
id="chunk-3",
delta="",
finish_reason="stop",
usage=UsageStats(
prompt_tokens=10,
completion_tokens=5,
total_tokens=15,
cost_usd=0.001,
),
)
result = format_sse_chunk(chunk)
data = json.loads(result[6:-2])
assert "usage" in data
assert data["usage"]["prompt_tokens"] == 10
def test_format_sse_done(self) -> None:
"""Test formatting done message."""
result = format_sse_done()
assert result == "data: [DONE]\n\n"
def test_format_sse_error(self) -> None:
"""Test formatting error message."""
result = format_sse_error("Something went wrong", code="ERROR_CODE")
data = json.loads(result[6:-2])
assert data["error"] == "Something went wrong"
assert data["code"] == "ERROR_CODE"
def test_format_sse_error_without_code(self) -> None:
"""Test formatting error without code."""
result = format_sse_error("Error message")
data = json.loads(result[6:-2])
assert data["error"] == "Error message"
assert "code" not in data
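# Wire format asserted above (Server-Sent Events framing):
#   format_sse_chunk(chunk)          -> 'data: {"id": ..., "delta": ..., ...}\n\n'
#   format_sse_done()                -> 'data: [DONE]\n\n'
#   format_sse_error(msg, code=...)  -> 'data: {"error": ..., "code": ...}\n\n'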
class TestStreamBuffer:
"""Tests for StreamBuffer class."""
async def test_buffer_basic(self) -> None:
"""Test basic buffer operations."""
buffer = StreamBuffer(max_size=10)
# Producer
async def produce():
await buffer.put(StreamChunk(id="1", delta="Hello"))
await buffer.put(StreamChunk(id="2", delta=" World"))
await buffer.done()
# Consumer
chunks = []
asyncio.create_task(produce())
async for chunk in buffer:
chunks.append(chunk)
assert len(chunks) == 2
assert chunks[0].delta == "Hello"
assert chunks[1].delta == " World"
async def test_buffer_error(self) -> None:
"""Test buffer with error."""
buffer = StreamBuffer()
async def produce():
await buffer.put(StreamChunk(id="1", delta="Hello"))
await buffer.error(ValueError("Test error"))
asyncio.create_task(produce())
with pytest.raises(ValueError, match="Test error"):
async for _ in buffer:
pass
async def test_buffer_put_after_done(self) -> None:
"""Test putting after done raises."""
buffer = StreamBuffer()
await buffer.done()
with pytest.raises(RuntimeError, match="closed"):
await buffer.put(StreamChunk(id="1", delta="Test"))
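# Contract exercised above: StreamBuffer is an async producer/consumer queue.
# put() enqueues chunks for the async iterator, done() ends iteration,
# error() propagates the given exception to the consumer, and put() after
# done() raises RuntimeError because the buffer is closed.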
class TestStreamToString:
"""Tests for stream_to_string function."""
async def test_stream_to_string_basic(self) -> None:
"""Test converting stream to string."""
async def mock_stream():
yield StreamChunk(id="1", delta="Hello")
yield StreamChunk(id="2", delta=" ")
yield StreamChunk(id="3", delta="World")
yield StreamChunk(
id="4",
delta="",
finish_reason="stop",
usage=UsageStats(prompt_tokens=5, completion_tokens=3),
)
content, usage = await stream_to_string(mock_stream())
assert content == "Hello World"
assert usage is not None
assert usage.prompt_tokens == 5
async def test_stream_to_string_no_usage(self) -> None:
"""Test stream without usage stats."""
async def mock_stream():
yield StreamChunk(id="1", delta="Test")
content, usage = await stream_to_string(mock_stream())
assert content == "Test"
assert usage is None
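# Putting the pieces together: a minimal, hypothetical sketch of how these
# helpers could compose into an SSE event generator. The chat_completion tool
# currently answers stream=True with an info message, so this is illustrative
# only and not part of the tested server path.
async def _example_sse_events(litellm_stream):
    """Hypothetical generator yielding SSE frames from a raw LiteLLM stream."""
    acc = StreamAccumulator()
    acc.start()
    async for chunk in wrap_litellm_stream(litellm_stream, acc):
        yield format_sse_chunk(chunk)
    acc.finish()
    yield format_sse_done()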

mcp-servers/llm-gateway/uv.lock generated Normal file

File diff suppressed because it is too large