Merge pull request #71 from feature/56-llm-gateway-mcp-server
feat(llm-gateway): implement LLM Gateway MCP Server (#56)
🤖 Generated with [Claude Code](https://claude.com/claude-code)
mcp-servers/llm-gateway/Dockerfile (new file, 53 lines)
@@ -0,0 +1,53 @@
# Syndarix LLM Gateway MCP Server
# Multi-stage build for minimal image size

# Build stage
FROM python:3.12-slim AS builder

# Install uv for fast package management
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv

WORKDIR /app

# Copy dependency files
COPY pyproject.toml ./

# Create virtual environment and install dependencies
RUN uv venv /app/.venv
ENV PATH="/app/.venv/bin:$PATH"
RUN uv pip install -e .

# Runtime stage
FROM python:3.12-slim AS runtime

# Create non-root user for security
RUN groupadd --gid 1000 appgroup && \
    useradd --uid 1000 --gid appgroup --shell /bin/bash --create-home appuser

WORKDIR /app

# Copy virtual environment from builder
COPY --from=builder /app/.venv /app/.venv
ENV PATH="/app/.venv/bin:$PATH"

# Copy application code
COPY --chown=appuser:appgroup . .

# Switch to non-root user
USER appuser

# Environment variables
ENV LLM_GATEWAY_HOST=0.0.0.0
ENV LLM_GATEWAY_PORT=8001
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1

# Expose port
EXPOSE 8001

# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD python -c "import httpx; httpx.get('http://localhost:8001/health').raise_for_status()" || exit 1

# Run the server
CMD ["python", "server.py"]
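The HEALTHCHECK above shells out to Python and httpx to probe GET /health on port 8001. The same probe can be run by hand when debugging a container; a minimal sketch, assuming the server exposes /health as the HEALTHCHECK implies (server.py itself is not part of this diff):

# healthcheck_probe.py - hand-run version of the container HEALTHCHECK.
# Assumes the gateway answers GET /health on LLM_GATEWAY_PORT (default 8001);
# any failure is treated as "unhealthy", mirroring the Dockerfile's || exit 1.
import os
import sys

import httpx

port = os.environ.get("LLM_GATEWAY_PORT", "8001")

try:
    httpx.get(f"http://localhost:{port}/health", timeout=10).raise_for_status()
except Exception as exc:  # noqa: BLE001 - any error means unhealthy
    print(f"unhealthy: {exc}", file=sys.stderr)
    sys.exit(1)
print("healthy")

Run it inside the container (for example via docker exec) so localhost resolves to the gateway process.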
mcp-servers/llm-gateway/config.py (new file, 179 lines)
@@ -0,0 +1,179 @@
"""
Configuration for LLM Gateway MCP Server.

Uses Pydantic Settings for type-safe environment variable handling.
"""

from functools import lru_cache

from pydantic import Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    """LLM Gateway configuration settings."""

    model_config = SettingsConfigDict(
        env_prefix="LLM_GATEWAY_",
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",
    )

    # Server settings
    host: str = Field(default="0.0.0.0", description="Server host")
    port: int = Field(default=8001, description="Server port")
    debug: bool = Field(default=False, description="Debug mode")

    # Redis settings
    redis_url: str = Field(
        default="redis://localhost:6379/0",
        description="Redis connection URL",
    )
    redis_prefix: str = Field(
        default="llm_gateway",
        description="Redis key prefix",
    )
    redis_ttl_hours: int = Field(
        default=24,
        description="Default Redis TTL in hours",
    )

    # Provider API keys
    anthropic_api_key: str | None = Field(
        default=None,
        description="Anthropic API key for Claude models",
    )
    openai_api_key: str | None = Field(
        default=None,
        description="OpenAI API key for GPT models",
    )
    google_api_key: str | None = Field(
        default=None,
        description="Google API key for Gemini models",
    )
    alibaba_api_key: str | None = Field(
        default=None,
        description="Alibaba API key for Qwen models",
    )
    deepseek_api_key: str | None = Field(
        default=None,
        description="DeepSeek API key",
    )
    deepseek_base_url: str | None = Field(
        default=None,
        description="DeepSeek API base URL (for self-hosted)",
    )

    # LiteLLM settings
    litellm_timeout: int = Field(
        default=120,
        description="LiteLLM request timeout in seconds",
    )
    litellm_max_retries: int = Field(
        default=3,
        description="Maximum retries per provider",
    )
    litellm_cache_enabled: bool = Field(
        default=True,
        description="Enable Redis caching for LiteLLM",
    )
    litellm_cache_ttl: int = Field(
        default=3600,
        description="Cache TTL in seconds",
    )

    # Circuit breaker settings
    circuit_failure_threshold: int = Field(
        default=5,
        description="Failures before circuit opens",
    )
    circuit_recovery_timeout: int = Field(
        default=60,
        description="Seconds before circuit half-opens",
    )
    circuit_half_open_max_calls: int = Field(
        default=3,
        description="Max calls in half-open state",
    )

    # Cost tracking settings
    cost_tracking_enabled: bool = Field(
        default=True,
        description="Enable cost tracking",
    )
    cost_alert_threshold: float = Field(
        default=100.0,
        description="Cost threshold for alerts (USD)",
    )
    default_budget_limit: float = Field(
        default=1000.0,
        description="Default project budget limit (USD)",
    )

    # Rate limiting
    rate_limit_enabled: bool = Field(
        default=True,
        description="Enable rate limiting",
    )
    rate_limit_requests_per_minute: int = Field(
        default=60,
        description="Max requests per minute per project",
    )

    @field_validator("port")
    @classmethod
    def validate_port(cls, v: int) -> int:
        """Validate port is in valid range."""
        if not 1 <= v <= 65535:
            raise ValueError("Port must be between 1 and 65535")
        return v

    @field_validator("redis_ttl_hours")
    @classmethod
    def validate_ttl(cls, v: int) -> int:
        """Validate TTL is positive."""
        if v <= 0:
            raise ValueError("Redis TTL must be positive")
        return v

    @field_validator("circuit_failure_threshold")
    @classmethod
    def validate_failure_threshold(cls, v: int) -> int:
        """Validate failure threshold is reasonable."""
        if not 1 <= v <= 100:
            raise ValueError("Failure threshold must be between 1 and 100")
        return v

    @field_validator("litellm_timeout")
    @classmethod
    def validate_timeout(cls, v: int) -> int:
        """Validate timeout is reasonable."""
        if not 1 <= v <= 600:
            raise ValueError("Timeout must be between 1 and 600 seconds")
        return v

    def get_available_providers(self) -> list[str]:
        """Get list of providers with configured API keys."""
        providers = []
        if self.anthropic_api_key:
            providers.append("anthropic")
        if self.openai_api_key:
            providers.append("openai")
        if self.google_api_key:
            providers.append("google")
        if self.alibaba_api_key:
            providers.append("alibaba")
        if self.deepseek_api_key or self.deepseek_base_url:
            providers.append("deepseek")
        return providers

    def has_any_provider(self) -> bool:
        """Check if at least one provider is configured."""
        return len(self.get_available_providers()) > 0


@lru_cache
def get_settings() -> Settings:
    """Get cached settings instance."""
    return Settings()
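A minimal sketch of how a caller, for example server startup code, might consume these settings. get_settings(), get_available_providers(), has_any_provider(), and ConfigurationError all come from this diff; treating "no provider configured" as fatal at startup is an assumption.

# Sketch: consuming config.py at startup. get_settings() is lru_cached, so
# every caller shares one Settings instance built from LLM_GATEWAY_* env vars
# and the optional .env file.
from config import get_settings
from exceptions import ConfigurationError

settings = get_settings()

if not settings.has_any_provider():
    raise ConfigurationError("No LLM provider API key configured")

print(settings.port)                       # 8001 unless LLM_GATEWAY_PORT overrides it
print(settings.get_available_providers())  # e.g. ["anthropic"] if LLM_GATEWAY_ANTHROPIC_API_KEY is set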
mcp-servers/llm-gateway/cost_tracking.py (new file, 467 lines)
@@ -0,0 +1,467 @@
|
||||
"""
|
||||
Cost tracking for LLM Gateway.
|
||||
|
||||
Tracks LLM usage costs per project and agent using Redis.
|
||||
Provides aggregation by hour, day, and month with TTL-based expiry.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
import redis.asyncio as redis
|
||||
|
||||
from config import Settings, get_settings
|
||||
from models import (
|
||||
MODEL_CONFIGS,
|
||||
UsageReport,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CostTracker:
|
||||
"""
|
||||
Redis-based cost tracker for LLM usage.
|
||||
|
||||
Key structure:
|
||||
- {prefix}:cost:project:{project_id}:{date} -> Hash of usage by model
|
||||
- {prefix}:cost:agent:{agent_id}:{date} -> Hash of usage by model
|
||||
- {prefix}:cost:session:{session_id} -> Hash of session usage
|
||||
- {prefix}:requests:{project_id}:{date} -> Request count
|
||||
|
||||
Date formats:
|
||||
- hour: YYYYMMDDHH
|
||||
- day: YYYYMMDD
|
||||
- month: YYYYMM
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_client: redis.Redis | None = None,
|
||||
settings: Settings | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize cost tracker.
|
||||
|
||||
Args:
|
||||
redis_client: Redis client (creates one if None)
|
||||
settings: Application settings
|
||||
"""
|
||||
self._settings = settings or get_settings()
|
||||
self._redis: redis.Redis | None = redis_client
|
||||
self._prefix = self._settings.redis_prefix
|
||||
|
||||
async def _get_redis(self) -> redis.Redis:
|
||||
"""Get or create Redis client."""
|
||||
if self._redis is None:
|
||||
self._redis = redis.from_url(
|
||||
self._settings.redis_url,
|
||||
decode_responses=True,
|
||||
)
|
||||
return self._redis
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close Redis connection."""
|
||||
if self._redis:
|
||||
await self._redis.aclose()
|
||||
self._redis = None
|
||||
|
||||
def _get_date_keys(self, timestamp: datetime | None = None) -> dict[str, str]:
|
||||
"""Get date format keys for different periods."""
|
||||
if timestamp is None:
|
||||
timestamp = datetime.now(UTC)
|
||||
return {
|
||||
"hour": timestamp.strftime("%Y%m%d%H"),
|
||||
"day": timestamp.strftime("%Y%m%d"),
|
||||
"month": timestamp.strftime("%Y%m"),
|
||||
}
|
||||
|
||||
def _get_ttl_seconds(self, period: str) -> int:
|
||||
"""Get TTL in seconds for a period."""
|
||||
ttls = {
|
||||
"hour": 24 * 3600, # 24 hours
|
||||
"day": 30 * 24 * 3600, # 30 days
|
||||
"month": 365 * 24 * 3600, # 1 year
|
||||
}
|
||||
return ttls.get(period, 30 * 24 * 3600)
|
||||
|
||||
async def record_usage(
|
||||
self,
|
||||
project_id: str,
|
||||
agent_id: str,
|
||||
model: str,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
cost_usd: float,
|
||||
session_id: str | None = None,
|
||||
request_id: str | None = None, # noqa: ARG002 - reserved for future logging
|
||||
) -> None:
|
||||
"""
|
||||
Record LLM usage.
|
||||
|
||||
Args:
|
||||
project_id: Project ID
|
||||
agent_id: Agent ID
|
||||
model: Model name
|
||||
prompt_tokens: Input tokens
|
||||
completion_tokens: Output tokens
|
||||
cost_usd: Cost in USD
|
||||
session_id: Optional session ID
|
||||
request_id: Optional request ID
|
||||
"""
|
||||
if not self._settings.cost_tracking_enabled:
|
||||
return
|
||||
|
||||
r = await self._get_redis()
|
||||
date_keys = self._get_date_keys()
|
||||
pipe = r.pipeline()
|
||||
|
||||
# Record for each time period
|
||||
for period, date_key in date_keys.items():
|
||||
# Project-level tracking
|
||||
project_key = f"{self._prefix}:cost:project:{project_id}:{date_key}"
|
||||
await self._increment_usage(
|
||||
pipe, project_key, model, prompt_tokens, completion_tokens, cost_usd
|
||||
)
|
||||
pipe.expire(project_key, self._get_ttl_seconds(period))
|
||||
|
||||
# Agent-level tracking
|
||||
agent_key = f"{self._prefix}:cost:agent:{agent_id}:{date_key}"
|
||||
await self._increment_usage(
|
||||
pipe, agent_key, model, prompt_tokens, completion_tokens, cost_usd
|
||||
)
|
||||
pipe.expire(agent_key, self._get_ttl_seconds(period))
|
||||
|
||||
# Request counter
|
||||
requests_key = f"{self._prefix}:requests:{project_id}:{date_key}"
|
||||
pipe.incr(requests_key)
|
||||
pipe.expire(requests_key, self._get_ttl_seconds(period))
|
||||
|
||||
# Session tracking (if session_id provided)
|
||||
if session_id:
|
||||
session_key = f"{self._prefix}:cost:session:{session_id}"
|
||||
await self._increment_usage(
|
||||
pipe, session_key, model, prompt_tokens, completion_tokens, cost_usd
|
||||
)
|
||||
pipe.expire(session_key, 24 * 3600) # 24 hour TTL for sessions
|
||||
|
||||
await pipe.execute()
|
||||
|
||||
logger.debug(
|
||||
f"Recorded usage: project={project_id}, agent={agent_id}, "
|
||||
f"model={model}, tokens={prompt_tokens + completion_tokens}, "
|
||||
f"cost=${cost_usd:.6f}"
|
||||
)
|
||||
|
||||
async def _increment_usage(
|
||||
self,
|
||||
pipe: redis.client.Pipeline,
|
||||
key: str,
|
||||
model: str,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
cost_usd: float,
|
||||
) -> None:
|
||||
"""Increment usage in a hash."""
|
||||
# Store as JSON fields within the hash
|
||||
pipe.hincrby(key, f"{model}:prompt_tokens", prompt_tokens)
|
||||
pipe.hincrby(key, f"{model}:completion_tokens", completion_tokens)
|
||||
pipe.hincrbyfloat(key, f"{model}:cost_usd", cost_usd)
|
||||
pipe.hincrby(key, f"{model}:requests", 1)
|
||||
# Totals
|
||||
pipe.hincrby(key, "total:prompt_tokens", prompt_tokens)
|
||||
pipe.hincrby(key, "total:completion_tokens", completion_tokens)
|
||||
pipe.hincrbyfloat(key, "total:cost_usd", cost_usd)
|
||||
pipe.hincrby(key, "total:requests", 1)
|
||||
|
||||
async def get_project_usage(
|
||||
self,
|
||||
project_id: str,
|
||||
period: str = "day",
|
||||
timestamp: datetime | None = None,
|
||||
) -> UsageReport:
|
||||
"""
|
||||
Get usage report for a project.
|
||||
|
||||
Args:
|
||||
project_id: Project ID
|
||||
period: Time period (hour, day, month)
|
||||
timestamp: Specific time to query (defaults to now)
|
||||
|
||||
Returns:
|
||||
Usage report
|
||||
"""
|
||||
date_keys = self._get_date_keys(timestamp)
|
||||
date_key = date_keys.get(period, date_keys["day"])
|
||||
|
||||
key = f"{self._prefix}:cost:project:{project_id}:{date_key}"
|
||||
return await self._get_usage_report(
|
||||
key, project_id, "project", period, timestamp
|
||||
)
|
||||
|
||||
async def get_agent_usage(
|
||||
self,
|
||||
agent_id: str,
|
||||
period: str = "day",
|
||||
timestamp: datetime | None = None,
|
||||
) -> UsageReport:
|
||||
"""
|
||||
Get usage report for an agent.
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
period: Time period (hour, day, month)
|
||||
timestamp: Specific time to query (defaults to now)
|
||||
|
||||
Returns:
|
||||
Usage report
|
||||
"""
|
||||
date_keys = self._get_date_keys(timestamp)
|
||||
date_key = date_keys.get(period, date_keys["day"])
|
||||
|
||||
key = f"{self._prefix}:cost:agent:{agent_id}:{date_key}"
|
||||
return await self._get_usage_report(key, agent_id, "agent", period, timestamp)
|
||||
|
||||
async def _get_usage_report(
|
||||
self,
|
||||
key: str,
|
||||
entity_id: str,
|
||||
entity_type: str,
|
||||
period: str,
|
||||
timestamp: datetime | None,
|
||||
) -> UsageReport:
|
||||
"""Get usage report from a Redis hash."""
|
||||
r = await self._get_redis()
|
||||
data = await r.hgetall(key) # type: ignore[misc]
|
||||
|
||||
# Parse the hash data
|
||||
by_model: dict[str, dict[str, Any]] = {}
|
||||
total_requests = 0
|
||||
total_tokens = 0
|
||||
total_cost = 0.0
|
||||
|
||||
for field, value in data.items():
|
||||
parts = field.split(":")
|
||||
if len(parts) != 2:
|
||||
continue
|
||||
|
||||
model, metric = parts
|
||||
if model == "total":
|
||||
if metric == "requests":
|
||||
total_requests = int(value)
|
||||
elif metric == "prompt_tokens" or metric == "completion_tokens":
|
||||
total_tokens += int(value)
|
||||
elif metric == "cost_usd":
|
||||
total_cost = float(value)
|
||||
else:
|
||||
if model not in by_model:
|
||||
by_model[model] = {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"cost_usd": 0.0,
|
||||
"requests": 0,
|
||||
}
|
||||
if metric == "prompt_tokens":
|
||||
by_model[model]["prompt_tokens"] = int(value)
|
||||
elif metric == "completion_tokens":
|
||||
by_model[model]["completion_tokens"] = int(value)
|
||||
elif metric == "cost_usd":
|
||||
by_model[model]["cost_usd"] = float(value)
|
||||
elif metric == "requests":
|
||||
by_model[model]["requests"] = int(value)
|
||||
|
||||
# Calculate period boundaries
|
||||
now = timestamp or datetime.now(UTC)
|
||||
if period == "hour":
|
||||
period_start = now.replace(minute=0, second=0, microsecond=0)
|
||||
period_end = period_start + timedelta(hours=1)
|
||||
elif period == "day":
|
||||
period_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
period_end = period_start + timedelta(days=1)
|
||||
else: # month
|
||||
period_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
|
||||
# Next month
|
||||
if now.month == 12:
|
||||
period_end = period_start.replace(year=now.year + 1, month=1)
|
||||
else:
|
||||
period_end = period_start.replace(month=now.month + 1)
|
||||
|
||||
return UsageReport(
|
||||
entity_id=entity_id,
|
||||
entity_type=entity_type,
|
||||
period=period,
|
||||
period_start=period_start,
|
||||
period_end=period_end,
|
||||
total_requests=total_requests,
|
||||
total_tokens=total_tokens,
|
||||
total_cost_usd=round(total_cost, 6),
|
||||
by_model=by_model,
|
||||
)
|
||||
|
||||
async def get_session_usage(self, session_id: str) -> dict[str, Any]:
|
||||
"""
|
||||
Get usage for a specific session.
|
||||
|
||||
Args:
|
||||
session_id: Session ID
|
||||
|
||||
Returns:
|
||||
Session usage data
|
||||
"""
|
||||
r = await self._get_redis()
|
||||
key = f"{self._prefix}:cost:session:{session_id}"
|
||||
data = await r.hgetall(key) # type: ignore[misc]
|
||||
|
||||
# Parse similar to _get_usage_report
|
||||
result: dict[str, Any] = {
|
||||
"session_id": session_id,
|
||||
"total_tokens": 0,
|
||||
"total_cost_usd": 0.0,
|
||||
"by_model": {},
|
||||
}
|
||||
|
||||
for field, value in data.items():
|
||||
parts = field.split(":")
|
||||
if len(parts) != 2:
|
||||
continue
|
||||
|
||||
model, metric = parts
|
||||
if model == "total":
|
||||
if metric == "prompt_tokens" or metric == "completion_tokens":
|
||||
result["total_tokens"] += int(value)
|
||||
elif metric == "cost_usd":
|
||||
result["total_cost_usd"] = float(value)
|
||||
else:
|
||||
if model not in result["by_model"]:
|
||||
result["by_model"][model] = {}
|
||||
if metric in ("prompt_tokens", "completion_tokens", "requests"):
|
||||
result["by_model"][model][metric] = int(value)
|
||||
elif metric == "cost_usd":
|
||||
result["by_model"][model][metric] = float(value)
|
||||
|
||||
return result
|
||||
|
||||
async def check_budget(
|
||||
self,
|
||||
project_id: str,
|
||||
budget_limit: float | None = None,
|
||||
) -> tuple[bool, float, float]:
|
||||
"""
|
||||
Check if project is within budget.
|
||||
|
||||
Args:
|
||||
project_id: Project ID
|
||||
budget_limit: Budget limit (uses default if None)
|
||||
|
||||
Returns:
|
||||
Tuple of (within_budget, current_cost, limit)
|
||||
"""
|
||||
limit = budget_limit or self._settings.default_budget_limit
|
||||
|
||||
# Get current month usage
|
||||
report = await self.get_project_usage(project_id, period="month")
|
||||
current_cost = report.total_cost_usd
|
||||
|
||||
within_budget = current_cost < limit
|
||||
return within_budget, current_cost, limit
|
||||
|
||||
async def estimate_request_cost(
|
||||
self,
|
||||
model: str,
|
||||
prompt_tokens: int,
|
||||
max_completion_tokens: int,
|
||||
) -> float:
|
||||
"""
|
||||
Estimate cost for a request.
|
||||
|
||||
Args:
|
||||
model: Model name
|
||||
prompt_tokens: Input token count
|
||||
max_completion_tokens: Maximum output tokens
|
||||
|
||||
Returns:
|
||||
Estimated cost in USD
|
||||
"""
|
||||
config = MODEL_CONFIGS.get(model)
|
||||
if not config:
|
||||
# Use a default estimate
|
||||
return (prompt_tokens + max_completion_tokens) * 0.00001
|
||||
|
||||
input_cost = (prompt_tokens / 1_000_000) * config.cost_per_1m_input
|
||||
output_cost = (max_completion_tokens / 1_000_000) * config.cost_per_1m_output
|
||||
return round(input_cost + output_cost, 6)
|
||||
|
||||
async def should_alert(
|
||||
self,
|
||||
project_id: str,
|
||||
threshold: float | None = None,
|
||||
) -> tuple[bool, float]:
|
||||
"""
|
||||
Check if cost alert should be triggered.
|
||||
|
||||
Args:
|
||||
project_id: Project ID
|
||||
threshold: Alert threshold (uses default if None)
|
||||
|
||||
Returns:
|
||||
Tuple of (should_alert, current_cost)
|
||||
"""
|
||||
thresh = threshold or self._settings.cost_alert_threshold
|
||||
|
||||
report = await self.get_project_usage(project_id, period="day")
|
||||
current_cost = report.total_cost_usd
|
||||
|
||||
return current_cost >= thresh, current_cost
|
||||
|
||||
|
||||
def calculate_cost(
|
||||
model: str,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
) -> float:
|
||||
"""
|
||||
Calculate cost for a completion.
|
||||
|
||||
Args:
|
||||
model: Model name
|
||||
prompt_tokens: Input tokens
|
||||
completion_tokens: Output tokens
|
||||
|
||||
Returns:
|
||||
Cost in USD
|
||||
"""
|
||||
config = MODEL_CONFIGS.get(model)
|
||||
if not config:
|
||||
logger.warning(f"Unknown model {model} for cost calculation")
|
||||
return 0.0
|
||||
|
||||
input_cost = (prompt_tokens / 1_000_000) * config.cost_per_1m_input
|
||||
output_cost = (completion_tokens / 1_000_000) * config.cost_per_1m_output
|
||||
return round(input_cost + output_cost, 6)
|
||||
|
||||
|
||||
# Global tracker instance (lazy initialization)
|
||||
_tracker: CostTracker | None = None
|
||||
|
||||
|
||||
def get_cost_tracker() -> CostTracker:
|
||||
"""Get the global cost tracker instance."""
|
||||
global _tracker
|
||||
if _tracker is None:
|
||||
_tracker = CostTracker()
|
||||
return _tracker
|
||||
|
||||
|
||||
async def close_cost_tracker() -> None:
|
||||
"""Close the global cost tracker."""
|
||||
global _tracker
|
||||
if _tracker:
|
||||
await _tracker.close()
|
||||
_tracker = None
|
||||
|
||||
|
||||
def reset_cost_tracker() -> None:
|
||||
"""Reset the global tracker (for testing)."""
|
||||
global _tracker
|
||||
_tracker = None
|
||||
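A short worked example of the pricing arithmetic behind calculate_cost() (and UsageStats.from_response() in models.py), using the claude-sonnet-4 rates from MODEL_CONFIGS, 3.0 USD per 1M input tokens and 15.0 USD per 1M output tokens; the token counts are made up for illustration.

# 12_000 input tokens:  (12_000 / 1_000_000) * 3.0  = 0.036 USD
#    800 output tokens: (   800 / 1_000_000) * 15.0 = 0.012 USD
from cost_tracking import calculate_cost

cost = calculate_cost("claude-sonnet-4", prompt_tokens=12_000, completion_tokens=800)
print(cost)  # 0.048, rounded to 6 decimal places

# Models missing from MODEL_CONFIGS are not priced: calculate_cost logs a
# warning and returns 0.0 instead of guessing a rate.
print(calculate_cost("not-a-model", prompt_tokens=1_000, completion_tokens=1_000))  # 0.0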
mcp-servers/llm-gateway/exceptions.py (new file, 476 lines)
@@ -0,0 +1,476 @@
|
||||
"""
|
||||
Custom exceptions for LLM Gateway MCP Server.
|
||||
|
||||
Provides structured error handling with error codes for consistent responses.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ErrorCode(str, Enum):
|
||||
"""Error codes for LLM Gateway errors."""
|
||||
|
||||
# General errors
|
||||
UNKNOWN_ERROR = "LLM_UNKNOWN_ERROR"
|
||||
INVALID_REQUEST = "LLM_INVALID_REQUEST"
|
||||
CONFIGURATION_ERROR = "LLM_CONFIGURATION_ERROR"
|
||||
|
||||
# Provider errors
|
||||
PROVIDER_ERROR = "LLM_PROVIDER_ERROR"
|
||||
PROVIDER_TIMEOUT = "LLM_PROVIDER_TIMEOUT"
|
||||
PROVIDER_RATE_LIMIT = "LLM_PROVIDER_RATE_LIMIT"
|
||||
PROVIDER_UNAVAILABLE = "LLM_PROVIDER_UNAVAILABLE"
|
||||
ALL_PROVIDERS_FAILED = "LLM_ALL_PROVIDERS_FAILED"
|
||||
|
||||
# Model errors
|
||||
INVALID_MODEL = "LLM_INVALID_MODEL"
|
||||
INVALID_MODEL_GROUP = "LLM_INVALID_MODEL_GROUP"
|
||||
MODEL_NOT_AVAILABLE = "LLM_MODEL_NOT_AVAILABLE"
|
||||
|
||||
# Circuit breaker errors
|
||||
CIRCUIT_OPEN = "LLM_CIRCUIT_OPEN"
|
||||
CIRCUIT_HALF_OPEN_EXHAUSTED = "LLM_CIRCUIT_HALF_OPEN_EXHAUSTED"
|
||||
|
||||
# Cost errors
|
||||
COST_LIMIT_EXCEEDED = "LLM_COST_LIMIT_EXCEEDED"
|
||||
BUDGET_EXHAUSTED = "LLM_BUDGET_EXHAUSTED"
|
||||
|
||||
# Rate limiting errors
|
||||
RATE_LIMIT_EXCEEDED = "LLM_RATE_LIMIT_EXCEEDED"
|
||||
|
||||
# Streaming errors
|
||||
STREAM_ERROR = "LLM_STREAM_ERROR"
|
||||
STREAM_INTERRUPTED = "LLM_STREAM_INTERRUPTED"
|
||||
|
||||
# Token errors
|
||||
TOKEN_LIMIT_EXCEEDED = "LLM_TOKEN_LIMIT_EXCEEDED"
|
||||
CONTEXT_TOO_LONG = "LLM_CONTEXT_TOO_LONG"
|
||||
|
||||
|
||||
class LLMGatewayError(Exception):
|
||||
"""Base exception for LLM Gateway errors."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
code: ErrorCode = ErrorCode.UNKNOWN_ERROR,
|
||||
details: dict[str, Any] | None = None,
|
||||
cause: Exception | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize LLM Gateway error.
|
||||
|
||||
Args:
|
||||
message: Human-readable error message
|
||||
code: Error code for programmatic handling
|
||||
details: Additional error details
|
||||
cause: Original exception that caused this error
|
||||
"""
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.code = code
|
||||
self.details = details or {}
|
||||
self.cause = cause
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert error to dictionary for JSON response."""
|
||||
result: dict[str, Any] = {
|
||||
"error": self.code.value,
|
||||
"message": self.message,
|
||||
}
|
||||
if self.details:
|
||||
result["details"] = self.details
|
||||
return result
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation."""
|
||||
return f"[{self.code.value}] {self.message}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Detailed representation."""
|
||||
return (
|
||||
f"{self.__class__.__name__}("
|
||||
f"message={self.message!r}, "
|
||||
f"code={self.code.value!r}, "
|
||||
f"details={self.details!r})"
|
||||
)
|
||||
|
||||
|
||||
class ProviderError(LLMGatewayError):
|
||||
"""Error from an LLM provider."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
provider: str,
|
||||
model: str | None = None,
|
||||
status_code: int | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
cause: Exception | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize provider error.
|
||||
|
||||
Args:
|
||||
message: Error message
|
||||
provider: Provider that failed
|
||||
model: Model that was being used
|
||||
status_code: HTTP status code if applicable
|
||||
details: Additional details
|
||||
cause: Original exception
|
||||
"""
|
||||
error_details = details or {}
|
||||
error_details["provider"] = provider
|
||||
if model:
|
||||
error_details["model"] = model
|
||||
if status_code:
|
||||
error_details["status_code"] = status_code
|
||||
|
||||
super().__init__(
|
||||
message=message,
|
||||
code=ErrorCode.PROVIDER_ERROR,
|
||||
details=error_details,
|
||||
cause=cause,
|
||||
)
|
||||
self.provider = provider
|
||||
self.model = model
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class RateLimitError(LLMGatewayError):
|
||||
"""Rate limit exceeded error."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
provider: str | None = None,
|
||||
retry_after: int | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize rate limit error.
|
||||
|
||||
Args:
|
||||
message: Error message
|
||||
provider: Provider that rate limited (None for internal limit)
|
||||
retry_after: Seconds until retry is allowed
|
||||
details: Additional details
|
||||
"""
|
||||
error_details = details or {}
|
||||
if provider:
|
||||
error_details["provider"] = provider
|
||||
if retry_after:
|
||||
error_details["retry_after_seconds"] = retry_after
|
||||
|
||||
code = (
|
||||
ErrorCode.PROVIDER_RATE_LIMIT if provider else ErrorCode.RATE_LIMIT_EXCEEDED
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
message=message,
|
||||
code=code,
|
||||
details=error_details,
|
||||
)
|
||||
self.provider = provider
|
||||
self.retry_after = retry_after
|
||||
|
||||
|
||||
class CircuitOpenError(LLMGatewayError):
|
||||
"""Circuit breaker is open, provider temporarily unavailable."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider: str,
|
||||
recovery_time: int | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize circuit open error.
|
||||
|
||||
Args:
|
||||
provider: Provider with open circuit
|
||||
recovery_time: Seconds until circuit may recover
|
||||
details: Additional details
|
||||
"""
|
||||
error_details = details or {}
|
||||
error_details["provider"] = provider
|
||||
if recovery_time:
|
||||
error_details["recovery_time_seconds"] = recovery_time
|
||||
|
||||
super().__init__(
|
||||
message=f"Circuit breaker open for provider {provider}",
|
||||
code=ErrorCode.CIRCUIT_OPEN,
|
||||
details=error_details,
|
||||
)
|
||||
self.provider = provider
|
||||
self.recovery_time = recovery_time
|
||||
|
||||
|
||||
class CostLimitExceededError(LLMGatewayError):
|
||||
"""Cost limit exceeded for project or agent."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
entity_type: str,
|
||||
entity_id: str,
|
||||
current_cost: float,
|
||||
limit: float,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize cost limit error.
|
||||
|
||||
Args:
|
||||
entity_type: 'project' or 'agent'
|
||||
entity_id: ID of the entity
|
||||
current_cost: Current accumulated cost
|
||||
limit: Cost limit that was exceeded
|
||||
details: Additional details
|
||||
"""
|
||||
error_details = details or {}
|
||||
error_details["entity_type"] = entity_type
|
||||
error_details["entity_id"] = entity_id
|
||||
error_details["current_cost_usd"] = current_cost
|
||||
error_details["limit_usd"] = limit
|
||||
|
||||
super().__init__(
|
||||
message=(
|
||||
f"Cost limit exceeded for {entity_type} {entity_id}: "
|
||||
f"${current_cost:.2f} >= ${limit:.2f}"
|
||||
),
|
||||
code=ErrorCode.COST_LIMIT_EXCEEDED,
|
||||
details=error_details,
|
||||
)
|
||||
self.entity_type = entity_type
|
||||
self.entity_id = entity_id
|
||||
self.current_cost = current_cost
|
||||
self.limit = limit
|
||||
|
||||
|
||||
class InvalidModelGroupError(LLMGatewayError):
|
||||
"""Invalid or unknown model group."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_group: str,
|
||||
available_groups: list[str] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize invalid model group error.
|
||||
|
||||
Args:
|
||||
model_group: The invalid group name
|
||||
available_groups: List of valid group names
|
||||
"""
|
||||
details: dict[str, Any] = {"requested_group": model_group}
|
||||
if available_groups:
|
||||
details["available_groups"] = available_groups
|
||||
|
||||
super().__init__(
|
||||
message=f"Invalid model group: {model_group}",
|
||||
code=ErrorCode.INVALID_MODEL_GROUP,
|
||||
details=details,
|
||||
)
|
||||
self.model_group = model_group
|
||||
self.available_groups = available_groups
|
||||
|
||||
|
||||
class InvalidModelError(LLMGatewayError):
|
||||
"""Invalid or unknown model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
reason: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize invalid model error.
|
||||
|
||||
Args:
|
||||
model: The invalid model name
|
||||
reason: Reason why it's invalid
|
||||
"""
|
||||
details: dict[str, Any] = {"requested_model": model}
|
||||
if reason:
|
||||
details["reason"] = reason
|
||||
|
||||
super().__init__(
|
||||
message=f"Invalid model: {model}" + (f" ({reason})" if reason else ""),
|
||||
code=ErrorCode.INVALID_MODEL,
|
||||
details=details,
|
||||
)
|
||||
self.model = model
|
||||
|
||||
|
||||
class ModelNotAvailableError(LLMGatewayError):
|
||||
"""Model not available (provider not configured)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
provider: str,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize model not available error.
|
||||
|
||||
Args:
|
||||
model: The unavailable model
|
||||
provider: The provider that's not configured
|
||||
"""
|
||||
super().__init__(
|
||||
message=f"Model {model} not available: {provider} provider not configured",
|
||||
code=ErrorCode.MODEL_NOT_AVAILABLE,
|
||||
details={"model": model, "provider": provider},
|
||||
)
|
||||
self.model = model
|
||||
self.provider = provider
|
||||
|
||||
|
||||
class AllProvidersFailedError(LLMGatewayError):
|
||||
"""All providers in the failover chain failed."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_group: str,
|
||||
attempted_models: list[str],
|
||||
errors: list[dict[str, Any]],
|
||||
) -> None:
|
||||
"""
|
||||
Initialize all providers failed error.
|
||||
|
||||
Args:
|
||||
model_group: The model group that was requested
|
||||
attempted_models: Models that were attempted
|
||||
errors: Errors from each attempt
|
||||
"""
|
||||
super().__init__(
|
||||
message=f"All providers failed for model group {model_group}",
|
||||
code=ErrorCode.ALL_PROVIDERS_FAILED,
|
||||
details={
|
||||
"model_group": model_group,
|
||||
"attempted_models": attempted_models,
|
||||
"errors": errors,
|
||||
},
|
||||
)
|
||||
self.model_group = model_group
|
||||
self.attempted_models = attempted_models
|
||||
self.errors = errors
|
||||
|
||||
|
||||
class StreamError(LLMGatewayError):
|
||||
"""Error during streaming response."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
chunks_received: int = 0,
|
||||
cause: Exception | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize stream error.
|
||||
|
||||
Args:
|
||||
message: Error message
|
||||
chunks_received: Number of chunks received before error
|
||||
cause: Original exception
|
||||
"""
|
||||
super().__init__(
|
||||
message=message,
|
||||
code=ErrorCode.STREAM_ERROR,
|
||||
details={"chunks_received": chunks_received},
|
||||
cause=cause,
|
||||
)
|
||||
self.chunks_received = chunks_received
|
||||
|
||||
|
||||
class TokenLimitExceededError(LLMGatewayError):
|
||||
"""Request exceeds model's token limit."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
token_count: int,
|
||||
limit: int,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize token limit error.
|
||||
|
||||
Args:
|
||||
model: Model name
|
||||
token_count: Requested token count
|
||||
limit: Model's token limit
|
||||
"""
|
||||
super().__init__(
|
||||
message=f"Token count {token_count} exceeds {model} limit of {limit}",
|
||||
code=ErrorCode.TOKEN_LIMIT_EXCEEDED,
|
||||
details={
|
||||
"model": model,
|
||||
"requested_tokens": token_count,
|
||||
"limit": limit,
|
||||
},
|
||||
)
|
||||
self.model = model
|
||||
self.token_count = token_count
|
||||
self.limit = limit
|
||||
|
||||
|
||||
class ContextTooLongError(LLMGatewayError):
|
||||
"""Input context exceeds model's context window."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
context_length: int,
|
||||
max_context: int,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize context too long error.
|
||||
|
||||
Args:
|
||||
model: Model name
|
||||
context_length: Input context length
|
||||
max_context: Model's max context window
|
||||
"""
|
||||
super().__init__(
|
||||
message=(
|
||||
f"Context length {context_length} exceeds {model} "
|
||||
f"context window of {max_context}"
|
||||
),
|
||||
code=ErrorCode.CONTEXT_TOO_LONG,
|
||||
details={
|
||||
"model": model,
|
||||
"context_length": context_length,
|
||||
"max_context": max_context,
|
||||
},
|
||||
)
|
||||
self.model = model
|
||||
self.context_length = context_length
|
||||
self.max_context = max_context
|
||||
|
||||
|
||||
class ConfigurationError(LLMGatewayError):
|
||||
"""Configuration error."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
config_key: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize configuration error.
|
||||
|
||||
Args:
|
||||
message: Error message
|
||||
config_key: Configuration key that's problematic
|
||||
"""
|
||||
details: dict[str, Any] = {}
|
||||
if config_key:
|
||||
details["config_key"] = config_key
|
||||
|
||||
super().__init__(
|
||||
message=message,
|
||||
code=ErrorCode.CONFIGURATION_ERROR,
|
||||
details=details,
|
||||
)
|
||||
self.config_key = config_key
|
||||
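A brief sketch of how these structured errors are meant to surface in a response: raise a specific subclass, catch the LLMGatewayError base, and serialize with to_dict(). The class names, fields, and error codes come from this file; the handler shape and the project ID are made up.

# Sketch: structured error handling with the exception hierarchy above.
from exceptions import CostLimitExceededError, ErrorCode, LLMGatewayError


def handle_request() -> dict:
    try:
        raise CostLimitExceededError(
            entity_type="project",
            entity_id="proj-123",   # hypothetical ID
            current_cost=105.40,
            limit=100.0,
        )
    except LLMGatewayError as err:      # every gateway error shares this base
        assert err.code is ErrorCode.COST_LIMIT_EXCEEDED
        return err.to_dict()            # {"error": "LLM_COST_LIMIT_EXCEEDED", "message": ..., "details": {...}}


print(handle_request()["error"])        # LLM_COST_LIMIT_EXCEEDED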
mcp-servers/llm-gateway/failover.py (new file, 357 lines)
@@ -0,0 +1,357 @@
"""
Circuit Breaker implementation for LLM Gateway.

Provides fault tolerance by tracking provider failures and
temporarily disabling providers that are experiencing issues.
"""

import asyncio
import logging
import time
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, TypeVar

from config import Settings, get_settings
from exceptions import CircuitOpenError

logger = logging.getLogger(__name__)

T = TypeVar("T")


class CircuitState(str, Enum):
    """Circuit breaker states."""

    CLOSED = "closed"  # Normal operation, requests pass through
    OPEN = "open"  # Failures exceeded threshold, requests blocked
    HALF_OPEN = "half_open"  # Testing if service recovered


@dataclass
class CircuitStats:
    """Statistics for a circuit breaker."""

    failures: int = 0
    successes: int = 0
    last_failure_time: float | None = None
    last_success_time: float | None = None
    state_changed_at: float = field(default_factory=time.time)
    half_open_calls: int = 0


class CircuitBreaker:
    """
    Circuit breaker for individual providers.

    States:
    - CLOSED: Normal operation. Failures increment counter.
    - OPEN: Too many failures. Requests immediately fail.
    - HALF_OPEN: Testing recovery. Limited requests allowed.

    Transitions:
    - CLOSED -> OPEN: When failures >= threshold
    - OPEN -> HALF_OPEN: After recovery_timeout
    - HALF_OPEN -> CLOSED: On success
    - HALF_OPEN -> OPEN: On failure
    """

    def __init__(
        self,
        name: str,
        failure_threshold: int = 5,
        recovery_timeout: int = 60,
        half_open_max_calls: int = 3,
    ) -> None:
        """
        Initialize circuit breaker.

        Args:
            name: Identifier for this circuit (usually provider name)
            failure_threshold: Failures before opening circuit
            recovery_timeout: Seconds before attempting recovery
            half_open_max_calls: Max calls allowed in half-open state
        """
        self.name = name
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.half_open_max_calls = half_open_max_calls

        self._state = CircuitState.CLOSED
        self._stats = CircuitStats()
        self._lock = asyncio.Lock()

    @property
    def state(self) -> CircuitState:
        """Get current circuit state (may trigger state transition)."""
        self._check_state_transition()
        return self._state

    @property
    def stats(self) -> CircuitStats:
        """Get circuit statistics."""
        return self._stats

    def _check_state_transition(self) -> None:
        """Check if state should transition based on time."""
        if self._state == CircuitState.OPEN:
            time_in_open = time.time() - self._stats.state_changed_at
            if time_in_open >= self.recovery_timeout:
                self._transition_to(CircuitState.HALF_OPEN)
                logger.info(
                    f"Circuit {self.name} transitioned to HALF_OPEN "
                    f"after {time_in_open:.1f}s"
                )

    def _transition_to(self, new_state: CircuitState) -> None:
        """Transition to a new state."""
        old_state = self._state
        self._state = new_state
        self._stats.state_changed_at = time.time()

        if new_state == CircuitState.HALF_OPEN:
            self._stats.half_open_calls = 0
        elif new_state == CircuitState.CLOSED:
            self._stats.failures = 0

        logger.debug(f"Circuit {self.name}: {old_state.value} -> {new_state.value}")

    def is_available(self) -> bool:
        """Check if circuit allows requests."""
        state = self.state  # Triggers state check
        if state == CircuitState.CLOSED:
            return True
        if state == CircuitState.HALF_OPEN:
            return self._stats.half_open_calls < self.half_open_max_calls
        return False

    def time_until_recovery(self) -> int | None:
        """Get seconds until circuit may recover (None if not open)."""
        if self._state != CircuitState.OPEN:
            return None
        elapsed = time.time() - self._stats.state_changed_at
        remaining = max(0, self.recovery_timeout - int(elapsed))
        return remaining if remaining > 0 else 0

    async def record_success(self) -> None:
        """Record a successful call."""
        async with self._lock:
            self._stats.successes += 1
            self._stats.last_success_time = time.time()

            if self._state == CircuitState.HALF_OPEN:
                # Success in half-open state closes the circuit
                self._transition_to(CircuitState.CLOSED)
                logger.info(f"Circuit {self.name} closed after successful recovery")

    async def record_failure(self, error: Exception | None = None) -> None:  # noqa: ARG002
        """
        Record a failed call.

        Args:
            error: The exception that occurred
        """
        async with self._lock:
            self._stats.failures += 1
            self._stats.last_failure_time = time.time()

            if self._state == CircuitState.HALF_OPEN:
                # Failure in half-open state opens the circuit
                self._transition_to(CircuitState.OPEN)
                logger.warning(
                    f"Circuit {self.name} reopened after failure in half-open state"
                )
            elif self._state == CircuitState.CLOSED:
                if self._stats.failures >= self.failure_threshold:
                    self._transition_to(CircuitState.OPEN)
                    logger.warning(
                        f"Circuit {self.name} opened after "
                        f"{self._stats.failures} failures"
                    )

    async def execute(
        self,
        func: Callable[..., Awaitable[T]],
        *args: Any,
        **kwargs: Any,
    ) -> T:
        """
        Execute a function with circuit breaker protection.

        Args:
            func: Async function to execute
            *args: Positional arguments
            **kwargs: Keyword arguments

        Returns:
            Function result

        Raises:
            CircuitOpenError: If circuit is open
        """
        if not self.is_available():
            raise CircuitOpenError(
                provider=self.name,
                recovery_time=self.time_until_recovery(),
            )

        async with self._lock:
            if self._state == CircuitState.HALF_OPEN:
                self._stats.half_open_calls += 1

        try:
            result = await func(*args, **kwargs)
            await self.record_success()
            return result
        except Exception as e:
            await self.record_failure(e)
            raise

    def reset(self) -> None:
        """Reset circuit to closed state."""
        self._state = CircuitState.CLOSED
        self._stats = CircuitStats()
        logger.info(f"Circuit {self.name} reset to CLOSED")

    def to_dict(self) -> dict[str, Any]:
        """Convert circuit state to dictionary."""
        return {
            "name": self.name,
            "state": self._state.value,
            "failures": self._stats.failures,
            "successes": self._stats.successes,
            "last_failure_time": self._stats.last_failure_time,
            "last_success_time": self._stats.last_success_time,
            "time_until_recovery": self.time_until_recovery(),
            "is_available": self.is_available(),
        }


class CircuitBreakerRegistry:
    """
    Registry for managing multiple circuit breakers.

    Provides centralized management of circuits for different providers/models.
    """

    def __init__(self, settings: Settings | None = None) -> None:
        """
        Initialize registry.

        Args:
            settings: Application settings (uses default if None)
        """
        self._settings = settings or get_settings()
        self._circuits: dict[str, CircuitBreaker] = {}
        self._lock = asyncio.Lock()

    async def get_circuit(self, name: str) -> CircuitBreaker:
        """
        Get or create a circuit breaker.

        Args:
            name: Circuit identifier (e.g., provider name)

        Returns:
            CircuitBreaker instance
        """
        async with self._lock:
            if name not in self._circuits:
                self._circuits[name] = CircuitBreaker(
                    name=name,
                    failure_threshold=self._settings.circuit_failure_threshold,
                    recovery_timeout=self._settings.circuit_recovery_timeout,
                    half_open_max_calls=self._settings.circuit_half_open_max_calls,
                )
            return self._circuits[name]

    def get_circuit_sync(self, name: str) -> CircuitBreaker:
        """
        Get or create a circuit breaker (sync version).

        Args:
            name: Circuit identifier

        Returns:
            CircuitBreaker instance
        """
        if name not in self._circuits:
            self._circuits[name] = CircuitBreaker(
                name=name,
                failure_threshold=self._settings.circuit_failure_threshold,
                recovery_timeout=self._settings.circuit_recovery_timeout,
                half_open_max_calls=self._settings.circuit_half_open_max_calls,
            )
        return self._circuits[name]

    async def is_available(self, name: str) -> bool:
        """
        Check if a circuit is available.

        Args:
            name: Circuit identifier

        Returns:
            True if circuit allows requests
        """
        circuit = await self.get_circuit(name)
        return circuit.is_available()

    async def record_success(self, name: str) -> None:
        """Record success for a circuit."""
        circuit = await self.get_circuit(name)
        await circuit.record_success()

    async def record_failure(self, name: str, error: Exception | None = None) -> None:
        """Record failure for a circuit."""
        circuit = await self.get_circuit(name)
        await circuit.record_failure(error)

    async def reset(self, name: str) -> None:
        """Reset a specific circuit."""
        async with self._lock:
            if name in self._circuits:
                self._circuits[name].reset()

    async def reset_all(self) -> None:
        """Reset all circuits."""
        async with self._lock:
            for circuit in self._circuits.values():
                circuit.reset()

    def get_all_states(self) -> dict[str, dict[str, Any]]:
        """Get state of all circuits."""
        return {name: circuit.to_dict() for name, circuit in self._circuits.items()}

    def get_open_circuits(self) -> list[str]:
        """Get list of circuits that are currently open."""
        return [
            name
            for name, circuit in self._circuits.items()
            if circuit.state == CircuitState.OPEN
        ]

    def get_available_circuits(self) -> list[str]:
        """Get list of circuits that are available for requests."""
        return [
            name for name, circuit in self._circuits.items() if circuit.is_available()
        ]


# Global registry instance (lazy initialization)
_registry: CircuitBreakerRegistry | None = None


def get_circuit_registry() -> CircuitBreakerRegistry:
    """Get the global circuit breaker registry."""
    global _registry
    if _registry is None:
        _registry = CircuitBreakerRegistry()
    return _registry


def reset_circuit_registry() -> None:
    """Reset the global registry (for testing)."""
    global _registry
    _registry = None
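A minimal usage sketch for the circuit breaker: fetch a circuit from the global registry and wrap a provider call with execute(). The names come from failover.py and exceptions.py; call_anthropic is a hypothetical stand-in for whatever provider call the gateway actually makes via LiteLLM.

# Sketch: registry-managed circuit breaker around a provider call.
import asyncio

from exceptions import CircuitOpenError
from failover import get_circuit_registry


async def call_anthropic(prompt: str) -> str:  # hypothetical provider call
    return f"echo: {prompt}"


async def main() -> None:
    registry = get_circuit_registry()
    circuit = await registry.get_circuit("anthropic")
    try:
        # execute() records successes/failures and opens the circuit once
        # recorded failures reach circuit_failure_threshold.
        reply = await circuit.execute(call_anthropic, "ping")
        print(reply)
    except CircuitOpenError as err:
        # Circuit is open: skip this provider until err.recovery_time elapses.
        print(f"{err.provider} unavailable for ~{err.recovery_time}s")


asyncio.run(main())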
mcp-servers/llm-gateway/models.py (new file, 442 lines)
@@ -0,0 +1,442 @@
|
||||
"""
|
||||
Data models for LLM Gateway MCP Server.
|
||||
|
||||
Defines model groups, pricing, request/response structures.
|
||||
Per ADR-004: LLM Provider Abstraction.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ModelGroup(str, Enum):
|
||||
"""Model groups for routing LLM requests."""
|
||||
|
||||
REASONING = "reasoning" # Complex analysis, architecture decisions
|
||||
CODE = "code" # Code writing and refactoring
|
||||
FAST = "fast" # Quick tasks, simple queries
|
||||
VISION = "vision" # Multimodal image analysis
|
||||
EMBEDDING = "embedding" # Vector embeddings
|
||||
COST_OPTIMIZED = "cost_optimized" # High-volume, non-critical
|
||||
SELF_HOSTED = "self_hosted" # Privacy-sensitive, air-gapped
|
||||
|
||||
# Aliases for backward compatibility with ADR-004
|
||||
HIGH_REASONING = "reasoning"
|
||||
CODE_GENERATION = "code"
|
||||
FAST_RESPONSE = "fast"
|
||||
|
||||
|
||||
class Provider(str, Enum):
|
||||
"""Supported LLM providers."""
|
||||
|
||||
ANTHROPIC = "anthropic"
|
||||
OPENAI = "openai"
|
||||
GOOGLE = "google"
|
||||
ALIBABA = "alibaba"
|
||||
DEEPSEEK = "deepseek"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig:
|
||||
"""Configuration for a specific model."""
|
||||
|
||||
name: str # Model identifier (e.g., "claude-3-opus-20240229")
|
||||
litellm_name: str # LiteLLM model string (e.g., "anthropic/claude-3-opus-20240229")
|
||||
provider: Provider
|
||||
cost_per_1m_input: float # USD per 1M input tokens
|
||||
cost_per_1m_output: float # USD per 1M output tokens
|
||||
context_window: int # Max context tokens
|
||||
max_output_tokens: int # Max output tokens
|
||||
supports_vision: bool = False
|
||||
supports_streaming: bool = True
|
||||
supports_function_calling: bool = True
|
||||
|
||||
|
||||
# Model configurations per ADR-004
|
||||
MODEL_CONFIGS: dict[str, ModelConfig] = {
|
||||
# Anthropic models
|
||||
"claude-opus-4": ModelConfig(
|
||||
name="claude-opus-4",
|
||||
litellm_name="anthropic/claude-sonnet-4-20250514", # Using sonnet-4 as opus-4 placeholder
|
||||
provider=Provider.ANTHROPIC,
|
||||
cost_per_1m_input=15.0,
|
||||
cost_per_1m_output=75.0,
|
||||
context_window=200000,
|
||||
max_output_tokens=8192,
|
||||
supports_vision=True,
|
||||
),
|
||||
"claude-sonnet-4": ModelConfig(
|
||||
name="claude-sonnet-4",
|
||||
litellm_name="anthropic/claude-sonnet-4-20250514",
|
||||
provider=Provider.ANTHROPIC,
|
||||
cost_per_1m_input=3.0,
|
||||
cost_per_1m_output=15.0,
|
||||
context_window=200000,
|
||||
max_output_tokens=8192,
|
||||
supports_vision=True,
|
||||
),
|
||||
"claude-haiku": ModelConfig(
|
||||
name="claude-haiku",
|
||||
litellm_name="anthropic/claude-3-5-haiku-20241022",
|
||||
provider=Provider.ANTHROPIC,
|
||||
cost_per_1m_input=1.0,
|
||||
cost_per_1m_output=5.0,
|
||||
context_window=200000,
|
||||
max_output_tokens=8192,
|
||||
supports_vision=True,
|
||||
),
|
||||
# OpenAI models
|
||||
"gpt-4.1": ModelConfig(
|
||||
name="gpt-4.1",
|
||||
litellm_name="openai/gpt-4.1",
|
||||
provider=Provider.OPENAI,
|
||||
cost_per_1m_input=2.0,
|
||||
cost_per_1m_output=8.0,
|
||||
context_window=1047576,
|
||||
max_output_tokens=32768,
|
||||
supports_vision=True,
|
||||
),
|
||||
"gpt-4.1-mini": ModelConfig(
|
||||
name="gpt-4.1-mini",
|
||||
litellm_name="openai/gpt-4.1-mini",
|
||||
provider=Provider.OPENAI,
|
||||
cost_per_1m_input=0.4,
|
||||
cost_per_1m_output=1.6,
|
||||
context_window=1047576,
|
||||
max_output_tokens=32768,
|
||||
supports_vision=True,
|
||||
),
|
||||
# Google models
|
||||
"gemini-2.5-pro": ModelConfig(
|
||||
name="gemini-2.5-pro",
|
||||
litellm_name="gemini/gemini-2.5-pro",
|
||||
provider=Provider.GOOGLE,
|
||||
cost_per_1m_input=1.25,
|
||||
cost_per_1m_output=10.0,
|
||||
context_window=1048576,
|
||||
max_output_tokens=65536,
|
||||
supports_vision=True,
|
||||
),
|
||||
"gemini-2.0-flash": ModelConfig(
|
||||
name="gemini-2.0-flash",
|
||||
litellm_name="gemini/gemini-2.0-flash",
|
||||
provider=Provider.GOOGLE,
|
||||
cost_per_1m_input=0.1,
|
||||
cost_per_1m_output=0.4,
|
||||
context_window=1048576,
|
||||
max_output_tokens=8192,
|
||||
supports_vision=True,
|
||||
),
|
||||
# Alibaba models
|
||||
"qwen-max": ModelConfig(
|
||||
name="qwen-max",
|
||||
litellm_name="alibaba/qwen-max",
|
||||
provider=Provider.ALIBABA,
|
||||
cost_per_1m_input=2.0,
|
||||
cost_per_1m_output=6.0,
|
||||
context_window=32768,
|
||||
max_output_tokens=8192,
|
||||
supports_vision=False,
|
||||
),
|
||||
# DeepSeek models
|
||||
"deepseek-coder": ModelConfig(
|
||||
name="deepseek-coder",
|
||||
litellm_name="deepseek/deepseek-coder",
|
||||
provider=Provider.DEEPSEEK,
|
||||
cost_per_1m_input=0.14,
|
||||
cost_per_1m_output=0.28,
|
||||
context_window=128000,
|
||||
max_output_tokens=8192,
|
||||
supports_vision=False,
|
||||
),
|
||||
"deepseek-chat": ModelConfig(
|
||||
name="deepseek-chat",
|
||||
litellm_name="deepseek/deepseek-chat",
|
||||
provider=Provider.DEEPSEEK,
|
||||
cost_per_1m_input=0.14,
|
||||
cost_per_1m_output=0.28,
|
||||
context_window=128000,
|
||||
max_output_tokens=8192,
|
||||
supports_vision=False,
|
||||
),
|
||||
# Embedding models
|
||||
"text-embedding-3-large": ModelConfig(
|
||||
name="text-embedding-3-large",
|
||||
litellm_name="openai/text-embedding-3-large",
|
||||
provider=Provider.OPENAI,
|
||||
cost_per_1m_input=0.13,
|
||||
cost_per_1m_output=0.0,
|
||||
context_window=8191,
|
||||
max_output_tokens=0,
|
||||
supports_vision=False,
|
||||
supports_streaming=False,
|
||||
supports_function_calling=False,
|
||||
),
|
||||
"voyage-3": ModelConfig(
|
||||
name="voyage-3",
|
||||
litellm_name="voyage/voyage-3",
|
||||
provider=Provider.ANTHROPIC, # Voyage via Anthropic partnership
|
||||
cost_per_1m_input=0.06,
|
||||
cost_per_1m_output=0.0,
|
||||
context_window=32000,
|
||||
max_output_tokens=0,
|
||||
supports_vision=False,
|
||||
supports_streaming=False,
|
||||
supports_function_calling=False,
|
||||
),
|
||||
}
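A small sketch of how the MODEL_CONFIGS table above can guard a request against a model's context window, pairing it with ContextTooLongError from exceptions.py. The ensure_fits helper and the token counts are illustrative and not part of this diff; token counting itself is assumed to happen elsewhere.

# Sketch: context-window guard built on MODEL_CONFIGS.
from exceptions import ContextTooLongError
from models import MODEL_CONFIGS


def ensure_fits(model: str, prompt_token_count: int) -> None:
    config = MODEL_CONFIGS[model]
    if prompt_token_count > config.context_window:
        raise ContextTooLongError(
            model=model,
            context_length=prompt_token_count,
            max_context=config.context_window,
        )


ensure_fits("claude-sonnet-4", 150_000)   # fine: window is 200000 tokens
# ensure_fits("qwen-max", 40_000)         # would raise: window is 32768 tokens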
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelGroupConfig:
|
||||
"""Configuration for a model group with failover chain."""
|
||||
|
||||
primary: str # Primary model name
|
||||
fallbacks: list[str] # Fallback model names in order
|
||||
description: str
|
||||
|
||||
def get_all_models(self) -> list[str]:
|
||||
"""Get all models in priority order."""
|
||||
return [self.primary, *self.fallbacks]
|
||||
|
||||
|
||||
# Model group configurations per ADR-004
|
||||
MODEL_GROUPS: dict[ModelGroup, ModelGroupConfig] = {
|
||||
ModelGroup.REASONING: ModelGroupConfig(
|
||||
primary="claude-opus-4",
|
||||
fallbacks=["gpt-4.1", "gemini-2.5-pro", "qwen-max"],
|
||||
description="Complex analysis, architecture decisions",
|
||||
),
|
||||
ModelGroup.CODE: ModelGroupConfig(
|
||||
primary="claude-sonnet-4",
|
||||
fallbacks=["gpt-4.1", "deepseek-coder"],
|
||||
description="Code writing and refactoring",
|
||||
),
|
||||
ModelGroup.FAST: ModelGroupConfig(
|
||||
primary="claude-haiku",
|
||||
fallbacks=["gpt-4.1-mini", "gemini-2.0-flash"],
|
||||
description="Quick tasks, simple queries",
|
||||
),
|
||||
ModelGroup.VISION: ModelGroupConfig(
|
||||
primary="claude-sonnet-4",
|
||||
fallbacks=["gpt-4.1", "gemini-2.5-pro"],
|
||||
description="Multimodal image analysis",
|
||||
),
|
||||
ModelGroup.EMBEDDING: ModelGroupConfig(
|
||||
primary="text-embedding-3-large",
|
||||
fallbacks=["voyage-3"],
|
||||
description="Vector embeddings",
|
||||
),
|
||||
ModelGroup.COST_OPTIMIZED: ModelGroupConfig(
|
||||
primary="qwen-max",
|
||||
fallbacks=["deepseek-chat"],
|
||||
description="High-volume, non-critical tasks",
|
||||
),
|
||||
ModelGroup.SELF_HOSTED: ModelGroupConfig(
|
||||
primary="deepseek-chat",
|
||||
fallbacks=["qwen-max"],
|
||||
description="Privacy-sensitive, air-gapped",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
# Agent type to model group mapping per ADR-004
|
||||
AGENT_TYPE_MODEL_PREFERENCES: dict[str, ModelGroup] = {
|
||||
"product_owner": ModelGroup.REASONING,
|
||||
"software_architect": ModelGroup.REASONING,
|
||||
"software_engineer": ModelGroup.CODE,
|
||||
"qa_engineer": ModelGroup.CODE,
|
||||
"devops_engineer": ModelGroup.FAST,
|
||||
"project_manager": ModelGroup.FAST,
|
||||
"business_analyst": ModelGroup.REASONING,
|
||||
}
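A short sketch of how the routing tables above resolve an agent type to its model group and failover chain; all names come from models.py, and the chain is what a router would try in order.

# Sketch: agent type -> model group -> failover chain -> per-model config.
from models import AGENT_TYPE_MODEL_PREFERENCES, MODEL_CONFIGS, MODEL_GROUPS

group = AGENT_TYPE_MODEL_PREFERENCES["software_engineer"]   # ModelGroup.CODE
chain = MODEL_GROUPS[group].get_all_models()
print(chain)  # ['claude-sonnet-4', 'gpt-4.1', 'deepseek-coder']

# Each entry maps to a LiteLLM model string and pricing:
for name in chain:
    cfg = MODEL_CONFIGS[name]
    print(name, cfg.litellm_name, cfg.cost_per_1m_input, cfg.cost_per_1m_output)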
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
"""A single chat message."""
|
||||
|
||||
role: str = Field(..., description="Message role: system, user, assistant, tool")
|
||||
content: str | list[dict[str, Any]] = Field(..., description="Message content")
|
||||
name: str | None = Field(default=None, description="Optional name for the message")
|
||||
tool_calls: list[dict[str, Any]] | None = Field(
|
||||
default=None, description="Tool calls if role is assistant"
|
||||
)
|
||||
tool_call_id: str | None = Field(
|
||||
default=None, description="Tool call ID if role is tool"
|
||||
)
|
||||
|
||||
|
||||
class CompletionRequest(BaseModel):
|
||||
"""Request for chat completion."""
|
||||
|
||||
project_id: str = Field(..., description="Project ID for cost attribution")
|
||||
agent_id: str = Field(..., description="Agent ID making the request")
|
||||
messages: list[ChatMessage] = Field(..., description="Chat messages")
|
||||
model_group: ModelGroup = Field(
|
||||
default=ModelGroup.REASONING, description="Model group for routing"
|
||||
)
|
||||
model_override: str | None = Field(
|
||||
default=None, description="Specific model to use (bypasses routing)"
|
||||
)
|
||||
max_tokens: int = Field(
|
||||
default=4096, ge=1, le=32768, description="Max output tokens"
|
||||
)
|
||||
temperature: float = Field(
|
||||
default=0.7, ge=0.0, le=2.0, description="Sampling temperature"
|
||||
)
|
||||
stream: bool = Field(default=False, description="Enable streaming response")
|
||||
session_id: str | None = Field(
|
||||
default=None, description="Session ID for conversation tracking"
|
||||
)
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Additional metadata"
|
||||
)
|
||||
|
||||
|
||||
class UsageStats(BaseModel):
|
||||
"""Token usage statistics."""
|
||||
|
||||
prompt_tokens: int = Field(default=0, description="Input tokens used")
|
||||
completion_tokens: int = Field(default=0, description="Output tokens generated")
|
||||
total_tokens: int = Field(default=0, description="Total tokens")
|
||||
cost_usd: float = Field(default=0.0, description="Estimated cost in USD")
|
||||
|
||||
@classmethod
|
||||
def from_response(
|
||||
cls, prompt_tokens: int, completion_tokens: int, model_config: ModelConfig
|
||||
) -> "UsageStats":
|
||||
"""Create usage stats from token counts and model config."""
|
||||
input_cost = (prompt_tokens / 1_000_000) * model_config.cost_per_1m_input
|
||||
output_cost = (completion_tokens / 1_000_000) * model_config.cost_per_1m_output
|
||||
return cls(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
cost_usd=round(input_cost + output_cost, 6),
|
||||
)
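# Illustrative example (not part of the original module): for a model priced
# at $3.00 / $15.00 per 1M tokens, 1,200 prompt tokens and 350 completion
# tokens cost (1_200 / 1e6) * 3.00 + (350 / 1e6) * 15.00 = 0.00885 USD,
# which from_response() rounds to 6 decimal places.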
|
||||
|
||||
|
||||
class CompletionResponse(BaseModel):
|
||||
"""Response from chat completion."""
|
||||
|
||||
id: str = Field(..., description="Unique response ID")
|
||||
model: str = Field(..., description="Model that generated the response")
|
||||
provider: str = Field(..., description="Provider used")
|
||||
content: str = Field(..., description="Generated content")
|
||||
finish_reason: str = Field(
|
||||
default="stop", description="Reason for completion finish"
|
||||
)
|
||||
usage: UsageStats = Field(default_factory=UsageStats, description="Token usage")
|
||||
created_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(UTC), description="Response timestamp"
|
||||
)
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Additional metadata"
|
||||
)
|
||||
|
||||
|
||||
class StreamChunk(BaseModel):
|
||||
"""A chunk from a streaming response."""
|
||||
|
||||
id: str = Field(..., description="Chunk ID")
|
||||
delta: str = Field(default="", description="Content delta")
|
||||
finish_reason: str | None = Field(default=None, description="Finish reason if done")
|
||||
usage: UsageStats | None = Field(
|
||||
default=None, description="Usage stats (only on final chunk)"
|
||||
)
|
||||
|
||||
|
||||
class EmbeddingRequest(BaseModel):
|
||||
"""Request for text embeddings."""
|
||||
|
||||
project_id: str = Field(..., description="Project ID for cost attribution")
|
||||
agent_id: str = Field(..., description="Agent ID making the request")
|
||||
texts: list[str] = Field(..., min_length=1, description="Texts to embed")
|
||||
model: str = Field(default="text-embedding-3-large", description="Embedding model")
|
||||
|
||||
|
||||
class EmbeddingResponse(BaseModel):
|
||||
"""Response from embedding generation."""
|
||||
|
||||
model: str = Field(..., description="Model used")
|
||||
embeddings: list[list[float]] = Field(..., description="Embedding vectors")
|
||||
usage: UsageStats = Field(default_factory=UsageStats, description="Token usage")
|
||||
|
||||
|
||||
@dataclass
|
||||
class CostRecord:
|
||||
"""A single cost record for tracking."""
|
||||
|
||||
project_id: str
|
||||
agent_id: str
|
||||
model: str
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
cost_usd: float
|
||||
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||
session_id: str | None = None
|
||||
request_id: str | None = None
|
||||
|
||||
|
||||
class UsageReport(BaseModel):
|
||||
"""Usage report for a project or agent."""
|
||||
|
||||
entity_id: str = Field(..., description="Project or agent ID")
|
||||
entity_type: str = Field(..., description="'project' or 'agent'")
|
||||
period: str = Field(..., description="Report period (hour, day, month)")
|
||||
period_start: datetime = Field(..., description="Period start time")
|
||||
period_end: datetime = Field(..., description="Period end time")
|
||||
total_requests: int = Field(default=0, description="Total requests")
|
||||
total_tokens: int = Field(default=0, description="Total tokens used")
|
||||
total_cost_usd: float = Field(default=0.0, description="Total cost in USD")
|
||||
by_model: dict[str, dict[str, Any]] = Field(
|
||||
default_factory=dict, description="Breakdown by model"
|
||||
)
|
||||
by_agent: dict[str, dict[str, Any]] = Field(
|
||||
default_factory=dict, description="Breakdown by agent (for project reports)"
|
||||
)
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
"""Information about an available model."""
|
||||
|
||||
name: str = Field(..., description="Model name")
|
||||
provider: str = Field(..., description="Provider name")
|
||||
cost_per_1m_input: float = Field(..., description="Input cost per 1M tokens")
|
||||
cost_per_1m_output: float = Field(..., description="Output cost per 1M tokens")
|
||||
context_window: int = Field(..., description="Max context tokens")
|
||||
max_output_tokens: int = Field(..., description="Max output tokens")
|
||||
supports_vision: bool = Field(..., description="Vision capability")
|
||||
supports_streaming: bool = Field(..., description="Streaming capability")
|
||||
supports_function_calling: bool = Field(..., description="Function calling")
|
||||
available: bool = Field(default=True, description="Provider configured")
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ModelConfig, available: bool = True) -> "ModelInfo":
|
||||
"""Create ModelInfo from ModelConfig."""
|
||||
return cls(
|
||||
name=config.name,
|
||||
provider=config.provider.value,
|
||||
cost_per_1m_input=config.cost_per_1m_input,
|
||||
cost_per_1m_output=config.cost_per_1m_output,
|
||||
context_window=config.context_window,
|
||||
max_output_tokens=config.max_output_tokens,
|
||||
supports_vision=config.supports_vision,
|
||||
supports_streaming=config.supports_streaming,
|
||||
supports_function_calling=config.supports_function_calling,
|
||||
available=available,
|
||||
)
|
||||
|
||||
|
||||
class ModelGroupInfo(BaseModel):
|
||||
"""Information about a model group."""
|
||||
|
||||
name: str = Field(..., description="Group name")
|
||||
description: str = Field(..., description="Group description")
|
||||
primary_model: str = Field(..., description="Primary model")
|
||||
fallback_models: list[str] = Field(..., description="Fallback models")
|
||||
available: bool = Field(default=True, description="At least one model available")
|
||||
328
mcp-servers/llm-gateway/providers.py
Normal file
@@ -0,0 +1,328 @@
|
||||
"""
|
||||
LiteLLM provider configuration for LLM Gateway.
|
||||
|
||||
Configures the LiteLLM Router with model lists and failover chains.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
from litellm import Router
|
||||
|
||||
from config import Settings, get_settings
|
||||
from models import (
|
||||
MODEL_CONFIGS,
|
||||
MODEL_GROUPS,
|
||||
ModelConfig,
|
||||
ModelGroup,
|
||||
Provider,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def configure_litellm(settings: Settings) -> None:
|
||||
"""
|
||||
Configure LiteLLM global settings.
|
||||
|
||||
Args:
|
||||
settings: Application settings
|
||||
"""
|
||||
# Set API keys in environment (LiteLLM reads from env)
|
||||
if settings.anthropic_api_key:
|
||||
os.environ["ANTHROPIC_API_KEY"] = settings.anthropic_api_key
|
||||
if settings.openai_api_key:
|
||||
os.environ["OPENAI_API_KEY"] = settings.openai_api_key
|
||||
if settings.google_api_key:
|
||||
os.environ["GEMINI_API_KEY"] = settings.google_api_key
|
||||
if settings.alibaba_api_key:
|
||||
os.environ["DASHSCOPE_API_KEY"] = settings.alibaba_api_key
|
||||
if settings.deepseek_api_key:
|
||||
os.environ["DEEPSEEK_API_KEY"] = settings.deepseek_api_key
|
||||
|
||||
# Configure LiteLLM settings
|
||||
litellm.drop_params = True # Drop unsupported params instead of erroring
|
||||
litellm.set_verbose = settings.debug
|
||||
|
||||
# Configure caching if enabled
|
||||
if settings.litellm_cache_enabled:
|
||||
litellm.cache = litellm.Cache(
|
||||
type="redis", # type: ignore[arg-type]
|
||||
host=_parse_redis_host(settings.redis_url),
|
||||
port=_parse_redis_port(settings.redis_url), # type: ignore[arg-type]
|
||||
ttl=settings.litellm_cache_ttl,
|
||||
)
|
||||
|
||||
|
||||
def _parse_redis_host(redis_url: str) -> str:
|
||||
"""Extract host from Redis URL."""
|
||||
# redis://host:port/db
|
||||
url = redis_url.replace("redis://", "")
|
||||
host_port = url.split("/")[0]
|
||||
return host_port.split(":")[0]
|
||||
|
||||
|
||||
def _parse_redis_port(redis_url: str) -> int:
|
||||
"""Extract port from Redis URL."""
|
||||
url = redis_url.replace("redis://", "")
|
||||
host_port = url.split("/")[0]
|
||||
parts = host_port.split(":")
|
||||
return int(parts[1]) if len(parts) > 1 else 6379
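# Illustrative example (not part of the original module): for
# "redis://redis:6379/0" the helpers above return host "redis" and port 6379;
# a URL without an explicit port such as "redis://localhost/0" falls back to
# the default port 6379. URLs with credentials or a rediss:// scheme are not
# handled by this simple parser.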
|
||||
|
||||
|
||||
def _is_provider_available(provider: Provider, settings: Settings) -> bool:
|
||||
"""Check if a provider is available (API key configured)."""
|
||||
provider_key_map = {
|
||||
Provider.ANTHROPIC: settings.anthropic_api_key,
|
||||
Provider.OPENAI: settings.openai_api_key,
|
||||
Provider.GOOGLE: settings.google_api_key,
|
||||
Provider.ALIBABA: settings.alibaba_api_key,
|
||||
Provider.DEEPSEEK: settings.deepseek_api_key or settings.deepseek_base_url,
|
||||
}
|
||||
return bool(provider_key_map.get(provider))
|
||||
|
||||
|
||||
def _build_model_entry(
|
||||
model_config: ModelConfig,
|
||||
settings: Settings,
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Build a model entry for LiteLLM Router.
|
||||
|
||||
Args:
|
||||
model_config: Model configuration
|
||||
settings: Application settings
|
||||
|
||||
Returns:
|
||||
Model entry dict or None if provider unavailable
|
||||
"""
|
||||
if not _is_provider_available(model_config.provider, settings):
|
||||
logger.debug(
|
||||
f"Skipping model {model_config.name}: "
|
||||
f"{model_config.provider.value} provider not configured"
|
||||
)
|
||||
return None
|
||||
|
||||
entry: dict[str, Any] = {
|
||||
"model_name": model_config.name,
|
||||
"litellm_params": {
|
||||
"model": model_config.litellm_name,
|
||||
"timeout": settings.litellm_timeout,
|
||||
"max_retries": settings.litellm_max_retries,
|
||||
},
|
||||
}
|
||||
|
||||
# Add custom base URL for DeepSeek self-hosted
|
||||
if model_config.provider == Provider.DEEPSEEK and settings.deepseek_base_url:
|
||||
entry["litellm_params"]["api_base"] = settings.deepseek_base_url
|
||||
|
||||
return entry
|
||||
|
||||
|
||||
def build_model_list(settings: Settings | None = None) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Build the complete model list for LiteLLM Router.
|
||||
|
||||
Args:
|
||||
settings: Application settings (uses default if None)
|
||||
|
||||
Returns:
|
||||
List of model entries for Router
|
||||
"""
|
||||
if settings is None:
|
||||
settings = get_settings()
|
||||
|
||||
model_list: list[dict[str, Any]] = []
|
||||
|
||||
for model_config in MODEL_CONFIGS.values():
|
||||
entry = _build_model_entry(model_config, settings)
|
||||
if entry:
|
||||
model_list.append(entry)
|
||||
|
||||
logger.info(f"Built model list with {len(model_list)} models")
|
||||
return model_list
|
||||
|
||||
|
||||
def build_fallback_config(settings: Settings | None = None) -> dict[str, list[str]]:
|
||||
"""
|
||||
Build fallback configuration based on model groups.
|
||||
|
||||
Args:
|
||||
settings: Application settings (uses default if None)
|
||||
|
||||
Returns:
|
||||
Dict mapping model names to their fallback chains
|
||||
"""
|
||||
if settings is None:
|
||||
settings = get_settings()
|
||||
|
||||
fallbacks: dict[str, list[str]] = {}
|
||||
|
||||
for _group, config in MODEL_GROUPS.items():
|
||||
# Get available models in this group's chain
|
||||
available_models = []
|
||||
for model_name in config.get_all_models():
|
||||
model_config = MODEL_CONFIGS.get(model_name)
|
||||
if model_config and _is_provider_available(model_config.provider, settings):
|
||||
available_models.append(model_name)
|
||||
|
||||
if len(available_models) > 1:
|
||||
# First model falls back to remaining models
|
||||
fallbacks[available_models[0]] = available_models[1:]
|
||||
|
||||
return fallbacks
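# Illustrative example (not part of the original module): with only Anthropic
# and OpenAI keys configured, the REASONING chain reduces to its available
# members, producing fallbacks["claude-opus-4"] == ["gpt-4.1"]; groups with
# fewer than two available models are omitted from the result.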
|
||||
|
||||
|
||||
def get_available_models(settings: Settings | None = None) -> dict[str, ModelConfig]:
|
||||
"""
|
||||
Get all available models (with configured providers).
|
||||
|
||||
Args:
|
||||
settings: Application settings (uses default if None)
|
||||
|
||||
Returns:
|
||||
Dict of available model configs
|
||||
"""
|
||||
if settings is None:
|
||||
settings = get_settings()
|
||||
|
||||
available: dict[str, ModelConfig] = {}
|
||||
|
||||
for name, config in MODEL_CONFIGS.items():
|
||||
if _is_provider_available(config.provider, settings):
|
||||
available[name] = config
|
||||
|
||||
return available
|
||||
|
||||
|
||||
def get_available_model_groups(
|
||||
settings: Settings | None = None,
|
||||
) -> dict[ModelGroup, list[str]]:
|
||||
"""
|
||||
Get available models for each model group.
|
||||
|
||||
Args:
|
||||
settings: Application settings (uses default if None)
|
||||
|
||||
Returns:
|
||||
Dict mapping model groups to available models
|
||||
"""
|
||||
if settings is None:
|
||||
settings = get_settings()
|
||||
|
||||
result: dict[ModelGroup, list[str]] = {}
|
||||
|
||||
for group, config in MODEL_GROUPS.items():
|
||||
available_models = []
|
||||
for model_name in config.get_all_models():
|
||||
model_config = MODEL_CONFIGS.get(model_name)
|
||||
if model_config and _is_provider_available(model_config.provider, settings):
|
||||
available_models.append(model_name)
|
||||
result[group] = available_models
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class LLMProvider:
|
||||
"""
|
||||
LLM Provider wrapper around LiteLLM Router.
|
||||
|
||||
Provides a high-level interface for LLM operations with
|
||||
automatic failover and configuration management.
|
||||
"""
|
||||
|
||||
def __init__(self, settings: Settings | None = None) -> None:
|
||||
"""
|
||||
Initialize LLM Provider.
|
||||
|
||||
Args:
|
||||
settings: Application settings (uses default if None)
|
||||
"""
|
||||
self._settings = settings or get_settings()
|
||||
self._router: Router | None = None
|
||||
self._initialized = False
|
||||
|
||||
def initialize(self) -> None:
|
||||
"""Initialize the provider and LiteLLM Router."""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
# Configure LiteLLM global settings
|
||||
configure_litellm(self._settings)
|
||||
|
||||
# Build model list
|
||||
model_list = build_model_list(self._settings)
|
||||
|
||||
if not model_list:
|
||||
logger.warning("No models available - no providers configured")
|
||||
self._initialized = True
|
||||
return
|
||||
|
||||
# Build fallback config
|
||||
fallbacks = build_fallback_config(self._settings)
|
||||
|
||||
# Create Router
|
||||
self._router = Router(
|
||||
model_list=model_list,
|
||||
fallbacks=list(fallbacks.items()) if fallbacks else None, # type: ignore[arg-type]
|
||||
routing_strategy="latency-based-routing",
|
||||
num_retries=self._settings.litellm_max_retries,
|
||||
timeout=self._settings.litellm_timeout,
|
||||
retry_after=5, # Retry after 5 seconds
|
||||
allowed_fails=2, # Fail after 2 consecutive failures
|
||||
)
|
||||
|
||||
self._initialized = True
|
||||
logger.info(
|
||||
f"LLM Provider initialized with {len(model_list)} models, "
|
||||
f"{len(fallbacks)} fallback chains"
|
||||
)
|
||||
|
||||
@property
|
||||
def router(self) -> Router | None:
|
||||
"""Get the LiteLLM Router."""
|
||||
if not self._initialized:
|
||||
self.initialize()
|
||||
return self._router
|
||||
|
||||
@property
|
||||
def is_available(self) -> bool:
|
||||
"""Check if provider is available."""
|
||||
if not self._initialized:
|
||||
self.initialize()
|
||||
return self._router is not None
|
||||
|
||||
def get_model_config(self, model_name: str) -> ModelConfig | None:
|
||||
"""Get configuration for a specific model."""
|
||||
return MODEL_CONFIGS.get(model_name)
|
||||
|
||||
def get_available_models(self) -> dict[str, ModelConfig]:
|
||||
"""Get all available models."""
|
||||
return get_available_models(self._settings)
|
||||
|
||||
def is_model_available(self, model_name: str) -> bool:
|
||||
"""Check if a specific model is available."""
|
||||
model_config = MODEL_CONFIGS.get(model_name)
|
||||
if not model_config:
|
||||
return False
|
||||
return _is_provider_available(model_config.provider, self._settings)
|
||||
|
||||
|
||||
# Global provider instance (lazy initialization)
|
||||
_provider: LLMProvider | None = None
|
||||
|
||||
|
||||
def get_provider() -> LLMProvider:
|
||||
"""Get the global LLM Provider instance."""
|
||||
global _provider
|
||||
if _provider is None:
|
||||
_provider = LLMProvider()
|
||||
return _provider
|
||||
|
||||
|
||||
def reset_provider() -> None:
|
||||
"""Reset the global provider (for testing)."""
|
||||
global _provider
|
||||
_provider = None
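# Illustrative usage sketch (not part of the original module):
#
#   provider = get_provider()
#   provider.initialize()
#   if provider.is_available:
#       router = provider.router  # LiteLLM Router with failover configured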
|
||||
@@ -4,20 +4,101 @@ version = "0.1.0"
|
||||
description = "Syndarix LLM Gateway MCP Server - Unified LLM access with failover and cost tracking"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"fastmcp>=0.1.0",
|
||||
"fastmcp>=2.0.0",
|
||||
"litellm>=1.50.0",
|
||||
"redis>=5.0.0",
|
||||
"pydantic>=2.0.0",
|
||||
"pydantic-settings>=2.0.0",
|
||||
"tiktoken>=0.7.0",
|
||||
"httpx>=0.27.0",
|
||||
"uvicorn>=0.30.0",
|
||||
"fastapi>=0.115.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=8.0.0",
|
||||
"pytest-asyncio>=0.23.0",
|
||||
"pytest-asyncio>=0.24.0",
|
||||
"pytest-cov>=5.0.0",
|
||||
"respx>=0.21.0",
|
||||
"fakeredis>=2.25.0",
|
||||
"ruff>=0.8.0",
|
||||
"mypy>=1.11.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
llm-gateway = "server:main"
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["."]
|
||||
exclude = ["tests/", "*.md", "Dockerfile"]
|
||||
|
||||
[tool.hatch.build.targets.sdist]
|
||||
include = ["*.py", "pyproject.toml"]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py312"
|
||||
line-length = 88
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"E", # pycodestyle errors
|
||||
"W", # pycodestyle warnings
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
"B", # flake8-bugbear
|
||||
"C4", # flake8-comprehensions
|
||||
"UP", # pyupgrade
|
||||
"ARG", # flake8-unused-arguments
|
||||
"SIM", # flake8-simplify
|
||||
]
|
||||
ignore = [
|
||||
"E501", # line too long (handled by formatter)
|
||||
"B008", # do not perform function calls in argument defaults
|
||||
"B904", # raise from in except (too noisy)
|
||||
]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
known-first-party = ["config", "models", "exceptions", "providers", "failover", "routing", "cost_tracking", "streaming"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
testpaths = ["tests"]
|
||||
addopts = "-v --tb=short"
|
||||
filterwarnings = [
|
||||
"ignore::DeprecationWarning",
|
||||
]
|
||||
|
||||
[tool.coverage.run]
|
||||
source = ["."]
|
||||
omit = ["tests/*", "conftest.py"]
|
||||
branch = true
|
||||
|
||||
[tool.coverage.report]
|
||||
exclude_lines = [
|
||||
"pragma: no cover",
|
||||
"def __repr__",
|
||||
"raise NotImplementedError",
|
||||
"if TYPE_CHECKING:",
|
||||
"if __name__ == .__main__.:",
|
||||
]
|
||||
fail_under = 90
|
||||
show_missing = true
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.12"
|
||||
warn_return_any = false
|
||||
warn_unused_ignores = false
|
||||
disallow_untyped_defs = true
|
||||
ignore_missing_imports = true
|
||||
plugins = ["pydantic.mypy"]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "tests.*"
|
||||
disallow_untyped_defs = false
|
||||
ignore_errors = true
|
||||
|
||||
321
mcp-servers/llm-gateway/routing.py
Normal file
@@ -0,0 +1,321 @@
|
||||
"""
|
||||
Model routing for LLM Gateway.
|
||||
|
||||
Handles model selection based on:
|
||||
- Model group configuration
|
||||
- Circuit breaker availability
|
||||
- Agent type preferences
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from config import Settings, get_settings
|
||||
from exceptions import (
|
||||
AllProvidersFailedError,
|
||||
InvalidModelError,
|
||||
InvalidModelGroupError,
|
||||
ModelNotAvailableError,
|
||||
)
|
||||
from failover import CircuitBreakerRegistry, get_circuit_registry
|
||||
from models import (
|
||||
AGENT_TYPE_MODEL_PREFERENCES,
|
||||
MODEL_CONFIGS,
|
||||
MODEL_GROUPS,
|
||||
ModelConfig,
|
||||
ModelGroup,
|
||||
)
|
||||
from providers import get_available_models
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelRouter:
|
||||
"""
|
||||
Routes requests to appropriate models based on configuration.
|
||||
|
||||
Considers:
|
||||
- Model group preferences
|
||||
- Circuit breaker states
|
||||
- Agent type defaults
|
||||
- Provider availability
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings: Settings | None = None,
|
||||
circuit_registry: CircuitBreakerRegistry | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize model router.
|
||||
|
||||
Args:
|
||||
settings: Application settings
|
||||
circuit_registry: Circuit breaker registry
|
||||
"""
|
||||
self._settings = settings or get_settings()
|
||||
self._circuit_registry = circuit_registry or get_circuit_registry()
|
||||
|
||||
def parse_model_group(self, group_str: str) -> ModelGroup:
|
||||
"""
|
||||
Parse model group from string.
|
||||
|
||||
Args:
|
||||
group_str: Group name string
|
||||
|
||||
Returns:
|
||||
ModelGroup enum value
|
||||
|
||||
Raises:
|
||||
InvalidModelGroupError: If group is unknown
|
||||
"""
|
||||
# Handle aliases
|
||||
aliases = {
|
||||
"high-reasoning": ModelGroup.REASONING,
|
||||
"high_reasoning": ModelGroup.REASONING,
|
||||
"code-generation": ModelGroup.CODE,
|
||||
"code_generation": ModelGroup.CODE,
|
||||
"fast-response": ModelGroup.FAST,
|
||||
"fast_response": ModelGroup.FAST,
|
||||
}
|
||||
|
||||
# Try direct enum value
|
||||
try:
|
||||
return ModelGroup(group_str.lower())
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Try aliases
|
||||
if group_str.lower() in aliases:
|
||||
return aliases[group_str.lower()]
|
||||
|
||||
# Unknown group
|
||||
available = [g.value for g in ModelGroup]
|
||||
raise InvalidModelGroupError(
|
||||
model_group=group_str,
|
||||
available_groups=available,
|
||||
)
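# Illustrative example (not part of the original module): both the canonical
# value "code" and the legacy alias "code-generation" resolve to
# ModelGroup.CODE, while an unknown string raises InvalidModelGroupError
# listing the valid groups.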
|
||||
|
||||
def get_model_config(self, model_name: str) -> ModelConfig:
|
||||
"""
|
||||
Get configuration for a specific model.
|
||||
|
||||
Args:
|
||||
model_name: Model name
|
||||
|
||||
Returns:
|
||||
Model configuration
|
||||
|
||||
Raises:
|
||||
InvalidModelError: If model is unknown
|
||||
"""
|
||||
config = MODEL_CONFIGS.get(model_name)
|
||||
if not config:
|
||||
raise InvalidModelError(
|
||||
model=model_name,
|
||||
reason="Unknown model",
|
||||
)
|
||||
return config
|
||||
|
||||
def get_preferred_group_for_agent(self, agent_type: str) -> ModelGroup:
|
||||
"""
|
||||
Get preferred model group for an agent type.
|
||||
|
||||
Args:
|
||||
agent_type: Agent type identifier
|
||||
|
||||
Returns:
|
||||
Preferred ModelGroup
|
||||
"""
|
||||
return AGENT_TYPE_MODEL_PREFERENCES.get(
|
||||
agent_type.lower(),
|
||||
ModelGroup.REASONING, # Default to reasoning
|
||||
)
|
||||
|
||||
async def select_model(
|
||||
self,
|
||||
model_group: ModelGroup | str,
|
||||
model_override: str | None = None,
|
||||
agent_type: str | None = None,
|
||||
) -> tuple[str, ModelConfig]:
|
||||
"""
|
||||
Select the best available model.
|
||||
|
||||
Args:
|
||||
model_group: Desired model group
|
||||
model_override: Specific model to use (bypasses group routing)
|
||||
agent_type: Agent type for preference lookup
|
||||
|
||||
Returns:
|
||||
Tuple of (model_name, model_config)
|
||||
|
||||
Raises:
|
||||
InvalidModelError: If override model is invalid
|
||||
InvalidModelGroupError: If group is invalid
|
||||
AllProvidersFailedError: If no models are available
|
||||
"""
|
||||
# Handle model override
|
||||
if model_override:
|
||||
config = MODEL_CONFIGS.get(model_override)
|
||||
if not config:
|
||||
raise InvalidModelError(
|
||||
model=model_override,
|
||||
reason="Unknown model",
|
||||
)
|
||||
# Check if model's provider is available (using router's settings)
|
||||
available_models = get_available_models(self._settings)
|
||||
if model_override not in available_models:
|
||||
raise ModelNotAvailableError(
|
||||
model=model_override,
|
||||
provider=config.provider.value,
|
||||
)
|
||||
# Check circuit breaker
|
||||
circuit = self._circuit_registry.get_circuit_sync(config.provider.value)
|
||||
if not circuit.is_available():
|
||||
raise ModelNotAvailableError(
|
||||
model=model_override,
|
||||
provider=f"{config.provider.value} (circuit open)",
|
||||
)
|
||||
return model_override, config
|
||||
|
||||
# Parse model group if string
|
||||
if isinstance(model_group, str):
|
||||
model_group = self.parse_model_group(model_group)
|
||||
|
||||
# Get agent type preference if no explicit group
|
||||
if agent_type:
|
||||
preferred = self.get_preferred_group_for_agent(agent_type)
|
||||
logger.debug(
|
||||
f"Agent type {agent_type} prefers {preferred.value}, "
|
||||
f"requested {model_group.value}"
|
||||
)
|
||||
|
||||
# Get group configuration
|
||||
group_config = MODEL_GROUPS.get(model_group)
|
||||
if not group_config:
|
||||
raise InvalidModelGroupError(
|
||||
model_group=model_group.value,
|
||||
available_groups=[g.value for g in ModelGroup],
|
||||
)
|
||||
|
||||
# Get available models
|
||||
available_models = get_available_models(self._settings)
|
||||
|
||||
# Try models in priority order
|
||||
errors: list[dict[str, Any]] = []
|
||||
attempted: list[str] = []
|
||||
|
||||
for model_name in group_config.get_all_models():
|
||||
attempted.append(model_name)
|
||||
|
||||
# Check if model provider is configured
|
||||
config = MODEL_CONFIGS.get(model_name)
|
||||
if not config:
|
||||
errors.append({"model": model_name, "error": "Unknown model"})
|
||||
continue
|
||||
|
||||
if model_name not in available_models:
|
||||
errors.append(
|
||||
{
|
||||
"model": model_name,
|
||||
"error": f"Provider {config.provider.value} not configured",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Check circuit breaker
|
||||
circuit = self._circuit_registry.get_circuit_sync(config.provider.value)
|
||||
if not circuit.is_available():
|
||||
errors.append(
|
||||
{
|
||||
"model": model_name,
|
||||
"error": f"Circuit open for {config.provider.value}",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Model is available
|
||||
logger.debug(f"Selected model {model_name} for group {model_group.value}")
|
||||
return model_name, config
|
||||
|
||||
# No models available
|
||||
raise AllProvidersFailedError(
|
||||
model_group=model_group.value,
|
||||
attempted_models=attempted,
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
async def get_available_models_for_group(
|
||||
self,
|
||||
model_group: ModelGroup | str,
|
||||
) -> list[tuple[str, ModelConfig, bool]]:
|
||||
"""
|
||||
Get all models for a group with availability status.
|
||||
|
||||
Args:
|
||||
model_group: Model group
|
||||
|
||||
Returns:
|
||||
List of (model_name, config, is_available) tuples
|
||||
"""
|
||||
# Parse model group if string
|
||||
if isinstance(model_group, str):
|
||||
model_group = self.parse_model_group(model_group)
|
||||
|
||||
group_config = MODEL_GROUPS.get(model_group)
|
||||
if not group_config:
|
||||
return []
|
||||
|
||||
available_models = get_available_models(self._settings)
|
||||
result: list[tuple[str, ModelConfig, bool]] = []
|
||||
|
||||
for model_name in group_config.get_all_models():
|
||||
config = MODEL_CONFIGS.get(model_name)
|
||||
if not config:
|
||||
continue
|
||||
|
||||
is_available = model_name in available_models
|
||||
if is_available:
|
||||
# Also check circuit breaker
|
||||
circuit = self._circuit_registry.get_circuit_sync(config.provider.value)
|
||||
is_available = circuit.is_available()
|
||||
|
||||
result.append((model_name, config, is_available))
|
||||
|
||||
return result
|
||||
|
||||
def get_all_model_groups(self) -> dict[str, dict[str, Any]]:
|
||||
"""
|
||||
Get information about all model groups.
|
||||
|
||||
Returns:
|
||||
Dict of group info
|
||||
"""
|
||||
result: dict[str, dict[str, Any]] = {}
|
||||
|
||||
for group, config in MODEL_GROUPS.items():
|
||||
result[group.value] = {
|
||||
"description": config.description,
|
||||
"primary": config.primary,
|
||||
"fallbacks": config.fallbacks,
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# Global router instance (lazy initialization)
|
||||
_router: ModelRouter | None = None
|
||||
|
||||
|
||||
def get_model_router() -> ModelRouter:
|
||||
"""Get the global model router instance."""
|
||||
global _router
|
||||
if _router is None:
|
||||
_router = ModelRouter()
|
||||
return _router
|
||||
|
||||
|
||||
def reset_model_router() -> None:
|
||||
"""Reset the global router (for testing)."""
|
||||
global _router
|
||||
_router = None
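# Illustrative usage sketch (not part of the original module); the "code"
# group string is assumed to be the enum value for ModelGroup.CODE:
#
#   router = get_model_router()
#   name, cfg = await router.select_model(model_group="code")
#   # -> ("claude-sonnet-4", ModelConfig(...)) when Anthropic is configured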
|
||||
@@ -4,36 +4,559 @@ Syndarix LLM Gateway MCP Server.
|
||||
Provides unified LLM access with:
|
||||
- Multi-provider support (Claude, GPT, Gemini, Qwen, DeepSeek)
|
||||
- Automatic failover chains
|
||||
- Cost tracking via LiteLLM callbacks
|
||||
- Model group routing (high-reasoning, code-generation, fast-response, cost-optimized)
|
||||
- Cost tracking via Redis
|
||||
- Model group routing (reasoning, code, fast, vision, embedding)
|
||||
- Circuit breaker protection
|
||||
|
||||
Per ADR-004: LLM Provider Abstraction.
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
import tiktoken
|
||||
from fastapi import FastAPI
|
||||
from fastmcp import FastMCP
|
||||
|
||||
# Create MCP server
|
||||
mcp = FastMCP(
|
||||
"syndarix-llm-gateway",
|
||||
description="Unified LLM access with failover and cost tracking",
|
||||
from config import get_settings
|
||||
from cost_tracking import calculate_cost, get_cost_tracker
|
||||
from exceptions import (
|
||||
AllProvidersFailedError,
|
||||
CircuitOpenError,
|
||||
CostLimitExceededError,
|
||||
InvalidModelError,
|
||||
InvalidModelGroupError,
|
||||
LLMGatewayError,
|
||||
ModelNotAvailableError,
|
||||
)
|
||||
from failover import get_circuit_registry
|
||||
from models import (
|
||||
MODEL_CONFIGS,
|
||||
MODEL_GROUPS,
|
||||
ModelGroup,
|
||||
)
|
||||
from providers import get_provider
|
||||
from routing import get_model_router
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Create FastMCP server
|
||||
mcp = FastMCP("syndarix-llm-gateway")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
|
||||
"""Application lifespan handler."""
|
||||
settings = get_settings()
|
||||
logger.info(f"Starting LLM Gateway on {settings.host}:{settings.port}")
|
||||
|
||||
# Initialize provider
|
||||
provider = get_provider()
|
||||
provider.initialize()
|
||||
|
||||
yield
|
||||
|
||||
# Cleanup
|
||||
from cost_tracking import close_cost_tracker
|
||||
|
||||
await close_cost_tracker()
|
||||
logger.info("LLM Gateway shutdown complete")
|
||||
|
||||
|
||||
# Create FastAPI app that wraps FastMCP
|
||||
app = FastAPI(
|
||||
title="Syndarix LLM Gateway",
|
||||
description="MCP Server for unified LLM access",
|
||||
version="0.1.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# Configuration
|
||||
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")
|
||||
DATABASE_URL = os.getenv("DATABASE_URL")
|
||||
|
||||
# Health endpoint
|
||||
@app.get("/health")
|
||||
async def health_check() -> dict[str, Any]:
|
||||
"""Health check endpoint."""
|
||||
settings = get_settings()
|
||||
provider = get_provider()
|
||||
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "llm-gateway",
|
||||
"providers_configured": settings.get_available_providers(),
|
||||
"provider_available": provider.is_available,
|
||||
}
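# Illustrative example (not part of the original module): a healthy gateway
# with two providers configured responds with something like
#   {"status": "healthy", "service": "llm-gateway",
#    "providers_configured": ["anthropic", "openai"], "provider_available": true}
# (the exact shape of providers_configured depends on
# Settings.get_available_providers()).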
|
||||
|
||||
|
||||
# Tool discovery endpoint (for MCP client compatibility)
|
||||
@app.get("/mcp/tools")
|
||||
async def list_tools() -> dict[str, Any]:
|
||||
"""List available MCP tools."""
|
||||
return {
|
||||
"tools": [
|
||||
{
|
||||
"name": "chat_completion",
|
||||
"description": "Generate a chat completion using the specified model group",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"project_id": {"type": "string", "description": "Project ID"},
|
||||
"agent_id": {"type": "string", "description": "Agent ID"},
|
||||
"messages": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"role": {"type": "string"},
|
||||
"content": {"type": "string"},
|
||||
},
|
||||
"required": ["role", "content"],
|
||||
},
|
||||
},
|
||||
"model_group": {
|
||||
"type": "string",
|
||||
"enum": [g.value for g in ModelGroup],
|
||||
"default": "reasoning",
|
||||
},
|
||||
"max_tokens": {"type": "integer", "default": 4096},
|
||||
"temperature": {"type": "number", "default": 0.7},
|
||||
"stream": {"type": "boolean", "default": False},
|
||||
},
|
||||
"required": ["project_id", "agent_id", "messages"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "list_models",
|
||||
"description": "List available models and model groups",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"project_id": {"type": "string"},
|
||||
"agent_id": {"type": "string"},
|
||||
"model_group": {"type": "string"},
|
||||
},
|
||||
"required": ["project_id", "agent_id"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "get_usage",
|
||||
"description": "Get usage statistics for a project or agent",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"project_id": {"type": "string"},
|
||||
"agent_id": {"type": "string"},
|
||||
"period": {
|
||||
"type": "string",
|
||||
"enum": ["hour", "day", "month"],
|
||||
"default": "day",
|
||||
},
|
||||
},
|
||||
"required": ["project_id", "agent_id"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "count_tokens",
|
||||
"description": "Count tokens in text",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"project_id": {"type": "string"},
|
||||
"agent_id": {"type": "string"},
|
||||
"text": {"type": "string"},
|
||||
"model": {"type": "string"},
|
||||
},
|
||||
"required": ["project_id", "agent_id", "text"],
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
# JSON-RPC endpoint (for MCP client compatibility)
|
||||
@app.post("/mcp")
|
||||
async def jsonrpc_handler(request: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Handle JSON-RPC 2.0 requests for MCP tools."""
|
||||
# Validate JSON-RPC structure
|
||||
if request.get("jsonrpc") != "2.0":
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"error": {"code": -32600, "message": "Invalid JSON-RPC version"},
|
||||
"id": request.get("id"),
|
||||
}
|
||||
|
||||
method = request.get("method")
|
||||
params = request.get("params", {})
|
||||
request_id = request.get("id")
|
||||
|
||||
# Handle tool calls
|
||||
if method == "tools/call":
|
||||
tool_name = params.get("name")
|
||||
tool_args = params.get("arguments", {})
|
||||
|
||||
try:
|
||||
if tool_name == "chat_completion":
|
||||
result = await _impl_chat_completion(**tool_args)
|
||||
elif tool_name == "list_models":
|
||||
result = await _impl_list_models(**tool_args)
|
||||
elif tool_name == "get_usage":
|
||||
result = await _impl_get_usage(**tool_args)
|
||||
elif tool_name == "count_tokens":
|
||||
result = await _impl_count_tokens(**tool_args)
|
||||
else:
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"error": {"code": -32601, "message": f"Unknown tool: {tool_name}"},
|
||||
"id": request_id,
|
||||
}
|
||||
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"result": {"content": [{"type": "text", "text": str(result)}]},
|
||||
"id": request_id,
|
||||
}
|
||||
|
||||
except LLMGatewayError as e:
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"error": {"code": -32000, "message": str(e), "data": e.to_dict()},
|
||||
"id": request_id,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.exception(f"Error executing tool {tool_name}")
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"error": {"code": -32603, "message": str(e)},
|
||||
"id": request_id,
|
||||
}
|
||||
|
||||
# Handle tool listing
|
||||
elif method == "tools/list":
|
||||
tools_response = await list_tools()
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"result": tools_response,
|
||||
"id": request_id,
|
||||
}
|
||||
|
||||
else:
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"error": {"code": -32601, "message": f"Unknown method: {method}"},
|
||||
"id": request_id,
|
||||
}
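# Illustrative example (not part of the original module): a minimal
# tools/call request handled by jsonrpc_handler above:
#
#   {"jsonrpc": "2.0", "id": 1, "method": "tools/call",
#    "params": {"name": "count_tokens",
#               "arguments": {"project_id": "p1", "agent_id": "a1",
#                             "text": "hello world"}}}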
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Core Implementation Functions
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def _impl_chat_completion(
|
||||
project_id: str,
|
||||
agent_id: str,
|
||||
messages: list[dict[str, Any]],
|
||||
model_group: str = "reasoning",
|
||||
model_override: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
stream: bool = False,
|
||||
session_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Core implementation for chat completion."""
|
||||
settings = get_settings()
|
||||
router = get_model_router()
|
||||
tracker = get_cost_tracker()
|
||||
circuit_registry = get_circuit_registry()
|
||||
|
||||
# Check budget before making request
|
||||
if settings.cost_tracking_enabled:
|
||||
within_budget, current_cost, limit = await tracker.check_budget(project_id)
|
||||
if not within_budget:
|
||||
raise CostLimitExceededError(
|
||||
entity_type="project",
|
||||
entity_id=project_id,
|
||||
current_cost=current_cost,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
# Select model
|
||||
try:
|
||||
model_name, model_config = await router.select_model(
|
||||
model_group=model_group,
|
||||
model_override=model_override,
|
||||
)
|
||||
except (InvalidModelGroupError, InvalidModelError, AllProvidersFailedError):
|
||||
raise
|
||||
except ModelNotAvailableError:
|
||||
raise
|
||||
|
||||
# Get provider
|
||||
provider = get_provider()
|
||||
if not provider.router:
|
||||
raise AllProvidersFailedError(
|
||||
model_group=model_group,
|
||||
attempted_models=[model_name],
|
||||
errors=[{"error": "No providers configured"}],
|
||||
)
|
||||
|
||||
# Generate request ID
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
# Get circuit breaker for this provider
|
||||
circuit = circuit_registry.get_circuit_sync(model_config.provider.value)
|
||||
|
||||
# Make completion request
|
||||
if stream:
|
||||
# Return streaming response info
|
||||
# Actual streaming would be handled by a separate endpoint
|
||||
return {
|
||||
"status": "streaming_not_supported_via_tool",
|
||||
"message": "Use /stream endpoint for streaming responses",
|
||||
"request_id": request_id,
|
||||
}
|
||||
|
||||
# Non-streaming completion
|
||||
response = await provider.router.acompletion(
|
||||
model=model_name,
|
||||
messages=messages, # type: ignore[arg-type]
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
# Record success
|
||||
await circuit.record_success()
|
||||
|
||||
# Extract response data
|
||||
content = response.choices[0].message.content or "" # type: ignore[union-attr]
|
||||
finish_reason = response.choices[0].finish_reason or "stop"
|
||||
|
||||
# Get usage stats
|
||||
prompt_tokens = response.usage.prompt_tokens if response.usage else 0 # type: ignore[attr-defined]
|
||||
completion_tokens = response.usage.completion_tokens if response.usage else 0 # type: ignore[attr-defined]
|
||||
|
||||
# Calculate cost
|
||||
cost_usd = calculate_cost(model_name, prompt_tokens, completion_tokens)
|
||||
|
||||
# Record usage
|
||||
await tracker.record_usage(
|
||||
project_id=project_id,
|
||||
agent_id=agent_id,
|
||||
model=model_name,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
cost_usd=cost_usd,
|
||||
session_id=session_id,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
return {
|
||||
"id": request_id,
|
||||
"model": model_name,
|
||||
"provider": model_config.provider.value,
|
||||
"content": content,
|
||||
"finish_reason": finish_reason,
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": prompt_tokens + completion_tokens,
|
||||
"cost_usd": cost_usd,
|
||||
},
|
||||
}
|
||||
|
||||
except CircuitOpenError:
|
||||
raise
|
||||
except Exception as e:
|
||||
# Record failure
|
||||
circuit = circuit_registry.get_circuit_sync(model_config.provider.value)
|
||||
await circuit.record_failure(e)
|
||||
|
||||
logger.error(f"Completion failed: {e}")
|
||||
raise AllProvidersFailedError(
|
||||
model_group=model_group,
|
||||
attempted_models=[model_name],
|
||||
errors=[{"model": model_name, "error": str(e)}],
|
||||
)
|
||||
|
||||
|
||||
async def _impl_list_models(
|
||||
project_id: str,
|
||||
agent_id: str,
|
||||
model_group: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Core implementation for list_models."""
|
||||
settings = get_settings()
|
||||
provider = get_provider()
|
||||
router = get_model_router()
|
||||
|
||||
# Get available providers
|
||||
available_providers = settings.get_available_providers()
|
||||
|
||||
result: dict[str, Any] = {
|
||||
"project_id": project_id,
|
||||
"agent_id": agent_id,
|
||||
"available_providers": available_providers,
|
||||
}
|
||||
|
||||
if model_group:
|
||||
# List models for specific group
|
||||
try:
|
||||
parsed_group = router.parse_model_group(model_group)
|
||||
models = await router.get_available_models_for_group(parsed_group)
|
||||
|
||||
result["model_group"] = model_group
|
||||
result["models"] = [
|
||||
{
|
||||
"name": name,
|
||||
"provider": config.provider.value,
|
||||
"available": available,
|
||||
"cost_per_1m_input": config.cost_per_1m_input,
|
||||
"cost_per_1m_output": config.cost_per_1m_output,
|
||||
}
|
||||
for name, config, available in models
|
||||
]
|
||||
except InvalidModelGroupError as e:
|
||||
result["error"] = e.to_dict()
|
||||
else:
|
||||
# List all model groups
|
||||
groups: dict[str, Any] = {}
|
||||
for group in ModelGroup:
|
||||
group_config = MODEL_GROUPS.get(group)
|
||||
if group_config:
|
||||
models = await router.get_available_models_for_group(group)
|
||||
available_count = sum(1 for _, _, avail in models if avail)
|
||||
groups[group.value] = {
|
||||
"description": group_config.description,
|
||||
"primary": group_config.primary,
|
||||
"fallbacks": group_config.fallbacks,
|
||||
"available_models": available_count,
|
||||
"total_models": len(models),
|
||||
}
|
||||
result["model_groups"] = groups
|
||||
|
||||
# List all models
|
||||
all_models: list[dict[str, Any]] = []
|
||||
available_models = provider.get_available_models()
|
||||
for name, config in MODEL_CONFIGS.items():
|
||||
all_models.append(
|
||||
{
|
||||
"name": name,
|
||||
"provider": config.provider.value,
|
||||
"available": name in available_models,
|
||||
"cost_per_1m_input": config.cost_per_1m_input,
|
||||
"cost_per_1m_output": config.cost_per_1m_output,
|
||||
"context_window": config.context_window,
|
||||
"max_output_tokens": config.max_output_tokens,
|
||||
"supports_vision": config.supports_vision,
|
||||
"supports_streaming": config.supports_streaming,
|
||||
}
|
||||
)
|
||||
result["models"] = all_models
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def _impl_get_usage(
|
||||
project_id: str,
|
||||
agent_id: str,
|
||||
period: str = "day",
|
||||
) -> dict[str, Any]:
|
||||
"""Core implementation for get_usage."""
|
||||
tracker = get_cost_tracker()
|
||||
|
||||
# Get project usage
|
||||
project_report = await tracker.get_project_usage(project_id, period=period)
|
||||
|
||||
# Get agent usage
|
||||
agent_report = await tracker.get_agent_usage(agent_id, period=period)
|
||||
|
||||
return {
|
||||
"project_id": project_id,
|
||||
"agent_id": agent_id,
|
||||
"period": period,
|
||||
"project_usage": {
|
||||
"total_requests": project_report.total_requests,
|
||||
"total_tokens": project_report.total_tokens,
|
||||
"total_cost_usd": project_report.total_cost_usd,
|
||||
"by_model": project_report.by_model,
|
||||
"period_start": project_report.period_start.isoformat(),
|
||||
"period_end": project_report.period_end.isoformat(),
|
||||
},
|
||||
"agent_usage": {
|
||||
"total_requests": agent_report.total_requests,
|
||||
"total_tokens": agent_report.total_tokens,
|
||||
"total_cost_usd": agent_report.total_cost_usd,
|
||||
"by_model": agent_report.by_model,
|
||||
"period_start": agent_report.period_start.isoformat(),
|
||||
"period_end": agent_report.period_end.isoformat(),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
async def _impl_count_tokens(
|
||||
project_id: str,
|
||||
agent_id: str,
|
||||
text: str,
|
||||
model: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Core implementation for count_tokens."""
|
||||
# Use tiktoken for token counting
|
||||
# Default to cl100k_base (the GPT-4 tokenizer); other providers use their own
# tokenizers, so counts for non-OpenAI models are approximate
|
||||
try:
|
||||
if model and model.startswith("gpt"):
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
else:
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
token_count = len(encoding.encode(text))
|
||||
except Exception as e:
|
||||
logger.warning(f"Token counting failed: {e}, using estimate")
|
||||
# Fallback: rough estimate of ~4 chars per token
|
||||
token_count = len(text) // 4
|
||||
|
||||
# Estimate costs for different models
|
||||
cost_estimates: dict[str, float] = {}
|
||||
for model_name, config in MODEL_CONFIGS.items():
|
||||
if config.cost_per_1m_input > 0:
|
||||
cost = (token_count / 1_000_000) * config.cost_per_1m_input
|
||||
cost_estimates[model_name] = round(cost, 6)
|
||||
|
||||
return {
|
||||
"project_id": project_id,
|
||||
"agent_id": agent_id,
|
||||
"token_count": token_count,
|
||||
"text_length": len(text),
|
||||
"encoding": "cl100k_base",
|
||||
"cost_estimates": cost_estimates,
|
||||
}
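# Illustrative example (not part of the original module): when tiktoken fails,
# the fallback estimate is len(text) // 4, so a 2,000-character prompt is
# reported as roughly 500 tokens.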
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# MCP Tools (wrappers around core implementations)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def chat_completion(
|
||||
project_id: str,
|
||||
agent_id: str,
|
||||
messages: list[dict],
|
||||
model_group: str = "high-reasoning",
|
||||
messages: list[dict[str, Any]],
|
||||
model_group: str = "reasoning",
|
||||
model_override: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
) -> dict:
|
||||
stream: bool = False,
|
||||
session_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate a chat completion using the specified model group.
|
||||
|
||||
@@ -41,108 +564,116 @@ async def chat_completion(
|
||||
project_id: UUID of the project (required for cost attribution)
|
||||
agent_id: UUID of the agent instance making the request
|
||||
messages: List of message dicts with 'role' and 'content'
|
||||
model_group: Model routing group (high-reasoning, code-generation, fast-response, cost-optimized, self-hosted)
|
||||
model_group: Model routing group (reasoning, code, fast, vision, embedding)
|
||||
model_override: Specific model to use (bypasses group routing)
|
||||
max_tokens: Maximum tokens to generate
|
||||
temperature: Sampling temperature (0.0-2.0)
|
||||
stream: Enable streaming response
|
||||
session_id: Optional session ID for tracking
|
||||
|
||||
Returns:
|
||||
Completion response with content and usage statistics
|
||||
"""
|
||||
# TODO: Implement with LiteLLM
|
||||
# 1. Map model_group to primary model + fallbacks
|
||||
# 2. Check project budget before making request
|
||||
# 3. Make completion request with failover
|
||||
# 4. Log usage via callback
|
||||
# 5. Return response
|
||||
return {
|
||||
"status": "not_implemented",
|
||||
"project_id": project_id,
|
||||
"agent_id": agent_id,
|
||||
"model_group": model_group,
|
||||
}
|
||||
return await _impl_chat_completion(
|
||||
project_id=project_id,
|
||||
agent_id=agent_id,
|
||||
messages=messages,
|
||||
model_group=model_group,
|
||||
model_override=model_override,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
stream=stream,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_embeddings(
|
||||
async def list_models(
|
||||
project_id: str,
|
||||
texts: list[str],
|
||||
model: str = "text-embedding-3-small",
|
||||
) -> dict:
|
||||
agent_id: str,
|
||||
model_group: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate embeddings for the given texts.
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project (required for cost attribution)
|
||||
texts: List of texts to embed
|
||||
model: Embedding model to use
|
||||
|
||||
Returns:
|
||||
List of embedding vectors
|
||||
"""
|
||||
# TODO: Implement with LiteLLM embeddings
|
||||
return {
|
||||
"status": "not_implemented",
|
||||
"project_id": project_id,
|
||||
"text_count": len(texts),
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_budget_status(project_id: str) -> dict:
|
||||
"""
|
||||
Get current budget status for a project.
|
||||
List available models and model groups.
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
agent_id: UUID of the agent instance
|
||||
model_group: Optional specific group to list
|
||||
|
||||
Returns:
|
||||
Budget status with usage, limits, and percentage
|
||||
Dictionary of available models and groups
|
||||
"""
|
||||
# TODO: Implement budget check from Redis
|
||||
return {
|
||||
"status": "not_implemented",
|
||||
"project_id": project_id,
|
||||
}
|
||||
return await _impl_list_models(
|
||||
project_id=project_id,
|
||||
agent_id=agent_id,
|
||||
model_group=model_group,
|
||||
)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def list_available_models() -> dict:
|
||||
async def get_usage(
|
||||
project_id: str,
|
||||
agent_id: str,
|
||||
period: str = "day",
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
List all available models and their capabilities.
|
||||
Get usage statistics for a project or agent.
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
agent_id: UUID of the agent
|
||||
period: Time period (hour, day, month)
|
||||
|
||||
Returns:
|
||||
Dictionary of model groups and available models
|
||||
Usage statistics including tokens and costs
|
||||
"""
|
||||
return {
|
||||
"model_groups": {
|
||||
"high-reasoning": {
|
||||
"primary": "claude-opus-4-5",
|
||||
"fallbacks": ["gpt-5.1-codex-max", "gemini-3-pro"],
|
||||
"description": "Complex analysis, architecture decisions",
|
||||
},
|
||||
"code-generation": {
|
||||
"primary": "gpt-5.1-codex-max",
|
||||
"fallbacks": ["claude-opus-4-5", "deepseek-v3.2"],
|
||||
"description": "Code writing and refactoring",
|
||||
},
|
||||
"fast-response": {
|
||||
"primary": "gemini-3-flash",
|
||||
"fallbacks": ["qwen3-235b", "deepseek-v3.2"],
|
||||
"description": "Quick tasks, simple queries",
|
||||
},
|
||||
"cost-optimized": {
|
||||
"primary": "qwen3-235b",
|
||||
"fallbacks": ["deepseek-v3.2"],
|
||||
"description": "High-volume, non-critical tasks",
|
||||
},
|
||||
"self-hosted": {
|
||||
"primary": "deepseek-v3.2",
|
||||
"fallbacks": ["qwen3-235b"],
|
||||
"description": "Privacy-sensitive, air-gapped",
|
||||
},
|
||||
}
|
||||
}
|
||||
return await _impl_get_usage(
|
||||
project_id=project_id,
|
||||
agent_id=agent_id,
|
||||
period=period,
|
||||
)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def count_tokens(
|
||||
project_id: str,
|
||||
agent_id: str,
|
||||
text: str,
|
||||
model: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Count tokens in text.
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
agent_id: UUID of the agent
|
||||
text: Text to count tokens in
|
||||
model: Optional model for tokenizer selection
|
||||
|
||||
Returns:
|
||||
Token count and estimation details
|
||||
"""
|
||||
return await _impl_count_tokens(
|
||||
project_id=project_id,
|
||||
agent_id=agent_id,
|
||||
text=text,
|
||||
model=model,
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Run the server."""
|
||||
import uvicorn
|
||||
|
||||
settings = get_settings()
|
||||
uvicorn.run(
|
||||
"server:app",
|
||||
host=settings.host,
|
||||
port=settings.port,
|
||||
reload=settings.debug,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mcp.run()
|
||||
main()
|
||||
|
||||
346
mcp-servers/llm-gateway/streaming.py
Normal file
@@ -0,0 +1,346 @@
|
||||
"""
|
||||
Streaming support for LLM Gateway.
|
||||
|
||||
Provides async streaming wrappers for LiteLLM responses.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from models import StreamChunk, UsageStats
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamAccumulator:
|
||||
"""
|
||||
Accumulates streaming chunks for cost calculation.
|
||||
|
||||
Tracks:
|
||||
- Full content for final response
|
||||
- Token counts from chunks
|
||||
- Timing information
|
||||
"""
|
||||
|
||||
def __init__(self, request_id: str | None = None) -> None:
|
||||
"""
|
||||
Initialize accumulator.
|
||||
|
||||
Args:
|
||||
request_id: Optional request ID for tracking
|
||||
"""
|
||||
self.request_id = request_id or str(uuid.uuid4())
|
||||
self.content_parts: list[str] = []
|
||||
self.chunks_received = 0
|
||||
self.prompt_tokens = 0
|
||||
self.completion_tokens = 0
|
||||
self.model: str | None = None
|
||||
self.finish_reason: str | None = None
|
||||
self._started_at: float | None = None
|
||||
self._finished_at: float | None = None
|
||||
|
||||
@property
|
||||
def content(self) -> str:
|
||||
"""Get accumulated content."""
|
||||
return "".join(self.content_parts)
|
||||
|
||||
@property
|
||||
def total_tokens(self) -> int:
|
||||
"""Get total token count."""
|
||||
return self.prompt_tokens + self.completion_tokens
|
||||
|
||||
@property
|
||||
def duration_seconds(self) -> float | None:
|
||||
"""Get stream duration in seconds."""
|
||||
if self._started_at is None or self._finished_at is None:
|
||||
return None
|
||||
return self._finished_at - self._started_at
|
||||
|
||||
def start(self) -> None:
|
||||
"""Mark stream start."""
|
||||
import time
|
||||
|
||||
self._started_at = time.time()
|
||||
|
||||
def finish(self) -> None:
|
||||
"""Mark stream finish."""
|
||||
import time
|
||||
|
||||
self._finished_at = time.time()
|
||||
|
||||
def add_chunk(
|
||||
self,
|
||||
delta: str,
|
||||
finish_reason: str | None = None,
|
||||
model: str | None = None,
|
||||
usage: dict[str, int] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Add a chunk to the accumulator.
|
||||
|
||||
Args:
|
||||
delta: Content delta
|
||||
finish_reason: Finish reason if this is the final chunk
|
||||
model: Model name
|
||||
usage: Usage stats if provided
|
||||
"""
|
||||
if delta:
|
||||
self.content_parts.append(delta)
|
||||
self.chunks_received += 1
|
||||
|
||||
if finish_reason:
|
||||
self.finish_reason = finish_reason
|
||||
|
||||
if model:
|
||||
self.model = model
|
||||
|
||||
if usage:
|
||||
self.prompt_tokens = usage.get("prompt_tokens", self.prompt_tokens)
|
||||
self.completion_tokens = usage.get(
|
||||
"completion_tokens", self.completion_tokens
|
||||
)
|
||||
|
||||
def get_usage_stats(self, cost_usd: float = 0.0) -> UsageStats:
|
||||
"""Get usage statistics."""
|
||||
return UsageStats(
|
||||
prompt_tokens=self.prompt_tokens,
|
||||
completion_tokens=self.completion_tokens,
|
||||
total_tokens=self.total_tokens,
|
||||
cost_usd=cost_usd,
|
||||
)
|
||||
|
||||
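A minimal usage sketch for StreamAccumulator; the deltas and token counts below are made up for illustration:

acc = StreamAccumulator(request_id="req-1")
acc.start()
acc.add_chunk(delta="Hello, ")
acc.add_chunk(
    delta="world!",
    finish_reason="stop",
    model="claude-opus-4",
    usage={"prompt_tokens": 12, "completion_tokens": 4},
)
acc.finish()

assert acc.content == "Hello, world!"
assert acc.total_tokens == 16
stats = acc.get_usage_stats(cost_usd=0.0012)  # UsageStats for the whole stream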
|
||||
async def wrap_litellm_stream(
|
||||
stream: AsyncIterator[Any],
|
||||
accumulator: StreamAccumulator | None = None,
|
||||
) -> AsyncIterator[StreamChunk]:
|
||||
"""
|
||||
Wrap a LiteLLM stream into StreamChunk objects.
|
||||
|
||||
Args:
|
||||
stream: LiteLLM async stream
|
||||
accumulator: Optional accumulator for tracking
|
||||
|
||||
Yields:
|
||||
StreamChunk objects
|
||||
"""
|
||||
if accumulator:
|
||||
accumulator.start()
|
||||
|
||||
chunk_id = 0
|
||||
try:
|
||||
async for chunk in stream:
|
||||
chunk_id += 1
|
||||
|
||||
# Extract data from LiteLLM chunk
|
||||
delta = ""
|
||||
finish_reason = None
|
||||
usage = None
|
||||
model = None
|
||||
|
||||
# Handle different chunk formats
|
||||
if hasattr(chunk, "choices") and chunk.choices:
|
||||
choice = chunk.choices[0]
|
||||
if hasattr(choice, "delta"):
|
||||
delta = getattr(choice.delta, "content", "") or ""
|
||||
finish_reason = getattr(choice, "finish_reason", None)
|
||||
|
||||
if hasattr(chunk, "model"):
|
||||
model = chunk.model
|
||||
|
||||
if hasattr(chunk, "usage") and chunk.usage:
|
||||
usage = {
|
||||
"prompt_tokens": getattr(chunk.usage, "prompt_tokens", 0),
|
||||
"completion_tokens": getattr(chunk.usage, "completion_tokens", 0),
|
||||
}
|
||||
|
||||
# Update accumulator
|
||||
if accumulator:
|
||||
accumulator.add_chunk(
|
||||
delta=delta,
|
||||
finish_reason=finish_reason,
|
||||
model=model,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
# Create StreamChunk
|
||||
stream_chunk = StreamChunk(
|
||||
id=f"{accumulator.request_id if accumulator else 'stream'}-{chunk_id}",
|
||||
delta=delta,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
|
||||
# Add usage on final chunk
|
||||
if finish_reason and accumulator:
|
||||
stream_chunk.usage = accumulator.get_usage_stats()
|
||||
|
||||
yield stream_chunk
|
||||
|
||||
finally:
|
||||
if accumulator:
|
||||
accumulator.finish()
|
||||
|
||||
|
||||
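A self-contained sketch of driving wrap_litellm_stream with a hand-rolled fake stream; the SimpleNamespace objects stand in for LiteLLM chunks and are not part of this commit:

import asyncio
from types import SimpleNamespace

async def fake_stream():
    # Two fake chunks shaped like LiteLLM deltas: content first, then the stop chunk.
    yield SimpleNamespace(
        choices=[SimpleNamespace(delta=SimpleNamespace(content="Hi"), finish_reason=None)],
        model="claude-opus-4",
    )
    yield SimpleNamespace(
        choices=[SimpleNamespace(delta=SimpleNamespace(content="!"), finish_reason="stop")],
        model="claude-opus-4",
        usage=SimpleNamespace(prompt_tokens=3, completion_tokens=2),
    )

async def demo() -> None:
    acc = StreamAccumulator()
    async for chunk in wrap_litellm_stream(fake_stream(), accumulator=acc):
        print(chunk.delta, chunk.finish_reason)
    print(acc.content, acc.total_tokens)  # "Hi!" 5

asyncio.run(demo())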
def format_sse_chunk(chunk: StreamChunk) -> str:
|
||||
"""
|
||||
Format a StreamChunk as SSE data.
|
||||
|
||||
Args:
|
||||
chunk: StreamChunk to format
|
||||
|
||||
Returns:
|
||||
SSE-formatted string
|
||||
"""
|
||||
data: dict[str, Any] = {
|
||||
"id": chunk.id,
|
||||
"delta": chunk.delta,
|
||||
}
|
||||
if chunk.finish_reason:
|
||||
data["finish_reason"] = chunk.finish_reason
|
||||
if chunk.usage:
|
||||
data["usage"] = chunk.usage.model_dump()
|
||||
|
||||
return f"data: {json.dumps(data)}\n\n"
|
||||
|
||||
|
||||
def format_sse_done() -> str:
|
||||
"""Format SSE done message."""
|
||||
return "data: [DONE]\n\n"
|
||||
|
||||
|
||||
def format_sse_error(error: str, code: str | None = None) -> str:
|
||||
"""
|
||||
Format an error as SSE data.
|
||||
|
||||
Args:
|
||||
error: Error message
|
||||
code: Error code
|
||||
|
||||
Returns:
|
||||
SSE-formatted error string
|
||||
"""
|
||||
data = {"error": error}
|
||||
if code:
|
||||
data["code"] = code
|
||||
return f"data: {json.dumps(data)}\n\n"
|
||||
|
||||
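For reference, the wire format these SSE helpers emit; field values are illustrative:

chunk = StreamChunk(id="req-1-1", delta="Hello", finish_reason=None)
print(format_sse_chunk(chunk), end="")   # data: {"id": "req-1-1", "delta": "Hello"}\n\n
print(format_sse_done(), end="")         # data: [DONE]\n\n
print(format_sse_error("upstream timeout", code="LLM_PROVIDER_ERROR"), end="")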
|
||||
class StreamBuffer:
|
||||
"""
|
||||
Buffer for streaming responses with backpressure handling.
|
||||
|
||||
Useful when producing chunks faster than they can be consumed.
|
||||
"""
|
||||
|
||||
def __init__(self, max_size: int = 100) -> None:
|
||||
"""
|
||||
Initialize buffer.
|
||||
|
||||
Args:
|
||||
max_size: Maximum buffer size
|
||||
"""
|
||||
self._queue: asyncio.Queue[StreamChunk | None] = asyncio.Queue(maxsize=max_size)
|
||||
self._done = False
|
||||
self._error: Exception | None = None
|
||||
|
||||
async def put(self, chunk: StreamChunk) -> None:
|
||||
"""
|
||||
Put a chunk in the buffer.
|
||||
|
||||
Args:
|
||||
chunk: Chunk to buffer
|
||||
"""
|
||||
if self._done:
|
||||
raise RuntimeError("Buffer is closed")
|
||||
await self._queue.put(chunk)
|
||||
|
||||
async def done(self) -> None:
|
||||
"""Signal that streaming is complete."""
|
||||
self._done = True
|
||||
await self._queue.put(None)
|
||||
|
||||
async def error(self, err: Exception) -> None:
|
||||
"""Signal an error occurred."""
|
||||
self._error = err
|
||||
self._done = True
|
||||
await self._queue.put(None)
|
||||
|
||||
async def __aiter__(self) -> AsyncIterator[StreamChunk]:
|
||||
"""Iterate over buffered chunks."""
|
||||
while True:
|
||||
chunk = await self._queue.get()
|
||||
if chunk is None:
|
||||
if self._error:
|
||||
raise self._error
|
||||
return
|
||||
yield chunk
|
||||
|
||||
|
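A producer/consumer sketch for StreamBuffer; chunk contents are illustrative:

import asyncio

async def produce(buf: StreamBuffer) -> None:
    for i in range(3):
        await buf.put(StreamChunk(id=f"demo-{i}", delta=f"part {i} ", finish_reason=None))
    await buf.done()

async def consume(buf: StreamBuffer) -> None:
    async for chunk in buf:
        print(chunk.delta)

async def demo() -> None:
    buf = StreamBuffer(max_size=10)
    await asyncio.gather(produce(buf), consume(buf))

asyncio.run(demo())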
||||
async def stream_to_string(
|
||||
stream: AsyncIterator[StreamChunk],
|
||||
) -> tuple[str, UsageStats | None]:
|
||||
"""
|
||||
Consume a stream and return full content.
|
||||
|
||||
Args:
|
||||
stream: Stream to consume
|
||||
|
||||
Returns:
|
||||
Tuple of (content, usage_stats)
|
||||
"""
|
||||
parts: list[str] = []
|
||||
usage: UsageStats | None = None
|
||||
|
||||
async for chunk in stream:
|
||||
if chunk.delta:
|
||||
parts.append(chunk.delta)
|
||||
if chunk.usage:
|
||||
usage = chunk.usage
|
||||
|
||||
return "".join(parts), usage
|
||||
|
||||
|
||||
async def merge_streams(
|
||||
*streams: AsyncIterator[StreamChunk],
|
||||
) -> AsyncIterator[StreamChunk]:
|
||||
"""
|
||||
Merge multiple streams into one.
|
||||
|
||||
Useful for parallel requests where results should be combined.
|
||||
|
||||
Args:
|
||||
*streams: Streams to merge
|
||||
|
||||
Yields:
|
||||
Chunks from all streams in arrival order
|
||||
"""
|
||||
pending: set[asyncio.Task[tuple[int, StreamChunk | None]]] = set()
|
||||
|
||||
async def next_chunk(
|
||||
idx: int, stream: AsyncIterator[StreamChunk]
|
||||
) -> tuple[int, StreamChunk | None]:
|
||||
try:
|
||||
return idx, await stream.__anext__()
|
||||
except StopAsyncIteration:
|
||||
return idx, None
|
||||
|
||||
# Start initial tasks
|
||||
active_streams = list(streams)
|
||||
for idx, stream in enumerate(active_streams):
|
||||
task = asyncio.create_task(next_chunk(idx, stream))
|
||||
pending.add(task)
|
||||
|
||||
while pending:
|
||||
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
|
||||
|
||||
for task in done:
|
||||
idx, chunk = task.result()
|
||||
if chunk is not None:
|
||||
yield chunk
|
||||
# Schedule next chunk from this stream
|
||||
new_task = asyncio.create_task(next_chunk(idx, active_streams[idx]))
|
||||
pending.add(new_task)
|
||||
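To close out the module, a sketch combining merge_streams and stream_to_string; the toy generators below are illustrative only:

import asyncio

async def toy(prefix: str):
    for i in range(2):
        yield StreamChunk(id=f"{prefix}-{i}", delta=f"{prefix}{i} ", finish_reason=None)

async def demo() -> None:
    merged = merge_streams(toy("a"), toy("b"))
    text, usage = await stream_to_string(merged)
    print(text)   # chunks in arrival order, e.g. "a0 b0 a1 b1 "
    print(usage)  # None here: the toy chunks never carry usage stats

asyncio.run(demo())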
1
mcp-servers/llm-gateway/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for LLM Gateway MCP Server."""
|
||||
204
mcp-servers/llm-gateway/tests/conftest.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""
|
||||
Pytest fixtures for LLM Gateway tests.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import fakeredis.aioredis
|
||||
import pytest
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from config import Settings
|
||||
from cost_tracking import CostTracker, reset_cost_tracker
|
||||
from failover import CircuitBreakerRegistry, reset_circuit_registry
|
||||
from providers import LLMProvider, reset_provider
|
||||
from routing import ModelRouter, reset_model_router
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_settings() -> Settings:
|
||||
"""Create test settings with mock API keys."""
|
||||
return Settings(
|
||||
host="127.0.0.1",
|
||||
port=8001,
|
||||
debug=True,
|
||||
redis_url="redis://localhost:6379/0",
|
||||
anthropic_api_key="test-anthropic-key",
|
||||
openai_api_key="test-openai-key",
|
||||
google_api_key="test-google-key",
|
||||
litellm_timeout=30,
|
||||
litellm_max_retries=2,
|
||||
litellm_cache_enabled=False,
|
||||
cost_tracking_enabled=True,
|
||||
circuit_failure_threshold=3,
|
||||
circuit_recovery_timeout=10,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def settings_no_providers() -> Settings:
|
||||
"""Create settings with no providers configured."""
|
||||
return Settings(
|
||||
host="127.0.0.1",
|
||||
port=8001,
|
||||
debug=False,
|
||||
redis_url="redis://localhost:6379/0",
|
||||
anthropic_api_key=None,
|
||||
openai_api_key=None,
|
||||
google_api_key=None,
|
||||
alibaba_api_key=None,
|
||||
deepseek_api_key=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_redis() -> fakeredis.aioredis.FakeRedis:
|
||||
"""Create a fake Redis instance for testing."""
|
||||
return fakeredis.aioredis.FakeRedis(decode_responses=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def cost_tracker(
|
||||
fake_redis: fakeredis.aioredis.FakeRedis,
|
||||
test_settings: Settings,
|
||||
) -> AsyncIterator[CostTracker]:
|
||||
"""Create a cost tracker with fake Redis."""
|
||||
reset_cost_tracker()
|
||||
tracker = CostTracker(redis_client=fake_redis, settings=test_settings)
|
||||
yield tracker
|
||||
await tracker.close()
|
||||
reset_cost_tracker()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def circuit_registry(test_settings: Settings) -> Iterator[CircuitBreakerRegistry]:
|
||||
"""Create a circuit breaker registry for testing."""
|
||||
reset_circuit_registry()
|
||||
registry = CircuitBreakerRegistry(settings=test_settings)
|
||||
yield registry
|
||||
reset_circuit_registry()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_router(
|
||||
test_settings: Settings,
|
||||
circuit_registry: CircuitBreakerRegistry,
|
||||
) -> Iterator[ModelRouter]:
|
||||
"""Create a model router for testing."""
|
||||
reset_model_router()
|
||||
router = ModelRouter(settings=test_settings, circuit_registry=circuit_registry)
|
||||
yield router
|
||||
reset_model_router()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm_provider(test_settings: Settings) -> Iterator[LLMProvider]:
|
||||
"""Create an LLM provider for testing."""
|
||||
reset_provider()
|
||||
provider = LLMProvider(settings=test_settings)
|
||||
yield provider
|
||||
reset_provider()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_litellm_response() -> dict[str, Any]:
|
||||
"""Create a mock LiteLLM response."""
|
||||
return {
|
||||
"id": "test-response-id",
|
||||
"model": "claude-opus-4",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "This is a test response.",
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 20,
|
||||
"total_tokens": 30,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_completion_response() -> MagicMock:
|
||||
"""Create a mock completion response object."""
|
||||
response = MagicMock()
|
||||
response.id = "test-response-id"
|
||||
response.model = "claude-opus-4"
|
||||
|
||||
choice = MagicMock()
|
||||
choice.index = 0
|
||||
choice.message = MagicMock()
|
||||
choice.message.content = "This is a test response."
|
||||
choice.finish_reason = "stop"
|
||||
response.choices = [choice]
|
||||
|
||||
response.usage = MagicMock()
|
||||
response.usage.prompt_tokens = 10
|
||||
response.usage.completion_tokens = 20
|
||||
response.usage.total_tokens = 30
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_messages() -> list[dict[str, str]]:
|
||||
"""Sample chat messages for testing."""
|
||||
return [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_project_id() -> str:
|
||||
"""Sample project ID."""
|
||||
return "proj-12345678-1234-1234-1234-123456789abc"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_agent_id() -> str:
|
||||
"""Sample agent ID."""
|
||||
return "agent-87654321-4321-4321-4321-cba987654321"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_session_id() -> str:
|
||||
"""Sample session ID."""
|
||||
return "session-11111111-2222-3333-4444-555555555555"
|
||||
|
||||
|
||||
# Reset all global state after each test
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_globals() -> Iterator[None]:
|
||||
"""Reset all global state after each test."""
|
||||
yield
|
||||
reset_cost_tracker()
|
||||
reset_circuit_registry()
|
||||
reset_model_router()
|
||||
reset_provider()
|
||||
|
||||
|
||||
# Mock environment variables for tests
|
||||
# Note: Not autouse=True to avoid affecting default value tests
|
||||
@pytest.fixture
|
||||
def mock_env_vars() -> Iterator[None]:
|
||||
"""Set test environment variables."""
|
||||
env_vars = {
|
||||
"LLM_GATEWAY_HOST": "127.0.0.1",
|
||||
"LLM_GATEWAY_PORT": "8001",
|
||||
"LLM_GATEWAY_DEBUG": "true",
|
||||
}
|
||||
with patch.dict(os.environ, env_vars, clear=False):
|
||||
yield
|
||||
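A small aside on why FakeRedis works as a drop-in for the fixtures above: it exposes the async Redis API backed by in-process memory, so the tests need no running server:

import asyncio
import fakeredis.aioredis

async def demo() -> None:
    r = fakeredis.aioredis.FakeRedis(decode_responses=True)
    await r.set("llm_gateway:demo", "1")
    print(await r.get("llm_gateway:demo"))  # "1", without a real Redis

asyncio.run(demo())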
200
mcp-servers/llm-gateway/tests/test_config.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""
|
||||
Tests for config module.
|
||||
"""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from config import Settings, get_settings
|
||||
|
||||
|
||||
class TestSettings:
|
||||
"""Tests for Settings class."""
|
||||
|
||||
def test_default_values(self) -> None:
|
||||
"""Test default configuration values."""
|
||||
settings = Settings()
|
||||
|
||||
assert settings.host == "0.0.0.0"
|
||||
assert settings.port == 8001
|
||||
assert settings.debug is False
|
||||
assert settings.redis_url == "redis://localhost:6379/0"
|
||||
assert settings.litellm_timeout == 120
|
||||
assert settings.circuit_failure_threshold == 5
|
||||
|
||||
def test_custom_values(self) -> None:
|
||||
"""Test custom configuration values."""
|
||||
settings = Settings(
|
||||
host="127.0.0.1",
|
||||
port=9000,
|
||||
debug=True,
|
||||
redis_url="redis://custom:6380/1",
|
||||
litellm_timeout=60,
|
||||
)
|
||||
|
||||
assert settings.host == "127.0.0.1"
|
||||
assert settings.port == 9000
|
||||
assert settings.debug is True
|
||||
assert settings.redis_url == "redis://custom:6380/1"
|
||||
assert settings.litellm_timeout == 60
|
||||
|
||||
def test_port_validation_valid(self) -> None:
|
||||
"""Test valid port numbers."""
|
||||
settings = Settings(port=1)
|
||||
assert settings.port == 1
|
||||
|
||||
settings = Settings(port=65535)
|
||||
assert settings.port == 65535
|
||||
|
||||
settings = Settings(port=8080)
|
||||
assert settings.port == 8080
|
||||
|
||||
def test_port_validation_invalid(self) -> None:
|
||||
"""Test invalid port numbers."""
|
||||
with pytest.raises(ValueError, match="Port must be between"):
|
||||
Settings(port=0)
|
||||
|
||||
with pytest.raises(ValueError, match="Port must be between"):
|
||||
Settings(port=65536)
|
||||
|
||||
with pytest.raises(ValueError, match="Port must be between"):
|
||||
Settings(port=-1)
|
||||
|
||||
def test_ttl_validation_valid(self) -> None:
|
||||
"""Test valid TTL values."""
|
||||
settings = Settings(redis_ttl_hours=1)
|
||||
assert settings.redis_ttl_hours == 1
|
||||
|
||||
settings = Settings(redis_ttl_hours=168) # 1 week
|
||||
assert settings.redis_ttl_hours == 168
|
||||
|
||||
def test_ttl_validation_invalid(self) -> None:
|
||||
"""Test invalid TTL values."""
|
||||
with pytest.raises(ValueError, match="Redis TTL must be positive"):
|
||||
Settings(redis_ttl_hours=0)
|
||||
|
||||
with pytest.raises(ValueError, match="Redis TTL must be positive"):
|
||||
Settings(redis_ttl_hours=-1)
|
||||
|
||||
def test_failure_threshold_validation(self) -> None:
|
||||
"""Test circuit failure threshold validation."""
|
||||
settings = Settings(circuit_failure_threshold=1)
|
||||
assert settings.circuit_failure_threshold == 1
|
||||
|
||||
settings = Settings(circuit_failure_threshold=100)
|
||||
assert settings.circuit_failure_threshold == 100
|
||||
|
||||
with pytest.raises(ValueError, match="Failure threshold must be between"):
|
||||
Settings(circuit_failure_threshold=0)
|
||||
|
||||
with pytest.raises(ValueError, match="Failure threshold must be between"):
|
||||
Settings(circuit_failure_threshold=101)
|
||||
|
||||
def test_timeout_validation(self) -> None:
|
||||
"""Test timeout validation."""
|
||||
settings = Settings(litellm_timeout=1)
|
||||
assert settings.litellm_timeout == 1
|
||||
|
||||
settings = Settings(litellm_timeout=600)
|
||||
assert settings.litellm_timeout == 600
|
||||
|
||||
with pytest.raises(ValueError, match="Timeout must be between"):
|
||||
Settings(litellm_timeout=0)
|
||||
|
||||
with pytest.raises(ValueError, match="Timeout must be between"):
|
||||
Settings(litellm_timeout=601)
|
||||
|
||||
def test_get_available_providers_none(self) -> None:
|
||||
"""Test getting available providers with none configured."""
|
||||
settings = Settings()
|
||||
providers = settings.get_available_providers()
|
||||
assert providers == []
|
||||
|
||||
def test_get_available_providers_some(self) -> None:
|
||||
"""Test getting available providers with some configured."""
|
||||
settings = Settings(
|
||||
anthropic_api_key="test-key",
|
||||
openai_api_key="test-key",
|
||||
)
|
||||
providers = settings.get_available_providers()
|
||||
|
||||
assert "anthropic" in providers
|
||||
assert "openai" in providers
|
||||
assert "google" not in providers
|
||||
assert len(providers) == 2
|
||||
|
||||
def test_get_available_providers_all(self) -> None:
|
||||
"""Test getting available providers with all configured."""
|
||||
settings = Settings(
|
||||
anthropic_api_key="test-key",
|
||||
openai_api_key="test-key",
|
||||
google_api_key="test-key",
|
||||
alibaba_api_key="test-key",
|
||||
deepseek_api_key="test-key",
|
||||
)
|
||||
providers = settings.get_available_providers()
|
||||
|
||||
assert len(providers) == 5
|
||||
assert "anthropic" in providers
|
||||
assert "openai" in providers
|
||||
assert "google" in providers
|
||||
assert "alibaba" in providers
|
||||
assert "deepseek" in providers
|
||||
|
||||
def test_has_any_provider_false(self) -> None:
|
||||
"""Test has_any_provider when none configured."""
|
||||
settings = Settings()
|
||||
assert settings.has_any_provider() is False
|
||||
|
||||
def test_has_any_provider_true(self) -> None:
|
||||
"""Test has_any_provider when at least one configured."""
|
||||
settings = Settings(anthropic_api_key="test-key")
|
||||
assert settings.has_any_provider() is True
|
||||
|
||||
def test_deepseek_base_url_counts_as_provider(self) -> None:
|
||||
"""Test that DeepSeek base URL alone counts as provider."""
|
||||
settings = Settings(deepseek_base_url="http://localhost:8000")
|
||||
providers = settings.get_available_providers()
|
||||
assert "deepseek" in providers
|
||||
|
||||
|
||||
class TestGetSettings:
|
||||
"""Tests for get_settings function."""
|
||||
|
||||
def test_get_settings_returns_settings(self) -> None:
|
||||
"""Test that get_settings returns a Settings instance."""
|
||||
# Clear the cache first
|
||||
get_settings.cache_clear()
|
||||
|
||||
settings = get_settings()
|
||||
assert isinstance(settings, Settings)
|
||||
|
||||
def test_get_settings_is_cached(self) -> None:
|
||||
"""Test that get_settings returns cached instance."""
|
||||
get_settings.cache_clear()
|
||||
|
||||
settings1 = get_settings()
|
||||
settings2 = get_settings()
|
||||
|
||||
assert settings1 is settings2
|
||||
|
||||
def test_env_var_override(self) -> None:
|
||||
"""Test that environment variables override defaults."""
|
||||
get_settings.cache_clear()
|
||||
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"LLM_GATEWAY_HOST": "192.168.1.1",
|
||||
"LLM_GATEWAY_PORT": "9999",
|
||||
"LLM_GATEWAY_DEBUG": "true",
|
||||
},
|
||||
):
|
||||
get_settings.cache_clear()
|
||||
settings = get_settings()
|
||||
|
||||
assert settings.host == "192.168.1.1"
|
||||
assert settings.port == 9999
|
||||
assert settings.debug is True
|
||||
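The validators these tests exercise live in config.py and are not quoted in this excerpt; as a hedged sketch only, a pydantic v2 validator of roughly this shape would produce the "Port must be between" message matched above:

from pydantic import BaseModel, field_validator

class PortExample(BaseModel):
    port: int = 8001

    @field_validator("port")
    @classmethod
    def validate_port(cls, v: int) -> int:
        if not 1 <= v <= 65535:
            raise ValueError("Port must be between 1 and 65535")
        return v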
405
mcp-servers/llm-gateway/tests/test_cost_tracking.py
Normal file
@@ -0,0 +1,405 @@
|
||||
"""
|
||||
Tests for cost_tracking module.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import fakeredis.aioredis
|
||||
import pytest
|
||||
|
||||
from config import Settings
|
||||
from cost_tracking import (
|
||||
CostTracker,
|
||||
calculate_cost,
|
||||
close_cost_tracker,
|
||||
get_cost_tracker,
|
||||
reset_cost_tracker,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tracker_settings() -> Settings:
|
||||
"""Settings for cost tracker tests."""
|
||||
return Settings(
|
||||
redis_url="redis://localhost:6379/0",
|
||||
redis_prefix="test_llm_gateway",
|
||||
cost_tracking_enabled=True,
|
||||
cost_alert_threshold=100.0,
|
||||
default_budget_limit=1000.0,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_redis() -> fakeredis.aioredis.FakeRedis:
|
||||
"""Create fake Redis for testing."""
|
||||
return fakeredis.aioredis.FakeRedis(decode_responses=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tracker(
|
||||
fake_redis: fakeredis.aioredis.FakeRedis,
|
||||
tracker_settings: Settings,
|
||||
) -> CostTracker:
|
||||
"""Create cost tracker with fake Redis."""
|
||||
return CostTracker(redis_client=fake_redis, settings=tracker_settings)
|
||||
|
||||
|
||||
class TestCalculateCost:
|
||||
"""Tests for calculate_cost function."""
|
||||
|
||||
def test_calculate_cost_known_model(self) -> None:
|
||||
"""Test calculating cost for known model."""
|
||||
# claude-opus-4: $15/1M input, $75/1M output
|
||||
cost = calculate_cost(
|
||||
model="claude-opus-4",
|
||||
prompt_tokens=1000,
|
||||
completion_tokens=500,
|
||||
)
|
||||
|
||||
# 1000/1M * 15 + 500/1M * 75 = 0.015 + 0.0375 = 0.0525
|
||||
assert cost == pytest.approx(0.0525, rel=0.001)
|
||||
|
||||
def test_calculate_cost_unknown_model(self) -> None:
|
||||
"""Test calculating cost for unknown model."""
|
||||
cost = calculate_cost(
|
||||
model="unknown-model",
|
||||
prompt_tokens=1000,
|
||||
completion_tokens=500,
|
||||
)
|
||||
|
||||
assert cost == 0.0
|
||||
|
||||
def test_calculate_cost_zero_tokens(self) -> None:
|
||||
"""Test calculating cost with zero tokens."""
|
||||
cost = calculate_cost(
|
||||
model="claude-opus-4",
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
)
|
||||
|
||||
assert cost == 0.0
|
||||
|
||||
def test_calculate_cost_large_token_counts(self) -> None:
|
||||
"""Test calculating cost with large token counts."""
|
||||
cost = calculate_cost(
|
||||
model="claude-opus-4",
|
||||
prompt_tokens=1_000_000,
|
||||
completion_tokens=500_000,
|
||||
)
|
||||
|
||||
# 1M * 15/1M + 500K * 75/1M = 15 + 37.5 = 52.5
|
||||
assert cost == pytest.approx(52.5, rel=0.001)
|
||||
|
||||
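The arithmetic these assertions rely on, written out; the $15/$75 per-million prices come from the comments in the tests themselves:

PROMPT_PRICE_PER_M = 15.0       # claude-opus-4 input, USD per 1M tokens
COMPLETION_PRICE_PER_M = 75.0   # claude-opus-4 output, USD per 1M tokens

def expected_cost(prompt_tokens: int, completion_tokens: int) -> float:
    return (
        prompt_tokens / 1_000_000 * PROMPT_PRICE_PER_M
        + completion_tokens / 1_000_000 * COMPLETION_PRICE_PER_M
    )

assert abs(expected_cost(1_000, 500) - 0.0525) < 1e-9
assert abs(expected_cost(1_000_000, 500_000) - 52.5) < 1e-9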
|
||||
class TestCostTracker:
|
||||
"""Tests for CostTracker class."""
|
||||
|
||||
def test_record_usage(self, tracker: CostTracker) -> None:
|
||||
"""Test recording usage."""
|
||||
asyncio.run(
|
||||
tracker.record_usage(
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
model="claude-opus-4",
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
cost_usd=0.01,
|
||||
)
|
||||
)
|
||||
|
||||
# Verify by getting usage report
|
||||
report = asyncio.run(tracker.get_project_usage("proj-123", period="day"))
|
||||
|
||||
assert report.total_requests == 1
|
||||
assert report.total_cost_usd == pytest.approx(0.01, rel=0.01)
|
||||
|
||||
def test_record_usage_disabled(self, tracker_settings: Settings) -> None:
|
||||
"""Test recording is skipped when disabled."""
|
||||
settings = Settings(
|
||||
**{
|
||||
**tracker_settings.model_dump(),
|
||||
"cost_tracking_enabled": False,
|
||||
}
|
||||
)
|
||||
fake_redis = fakeredis.aioredis.FakeRedis(decode_responses=True)
|
||||
disabled_tracker = CostTracker(redis_client=fake_redis, settings=settings)
|
||||
|
||||
# This should not raise and should not record
|
||||
asyncio.run(
|
||||
disabled_tracker.record_usage(
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
model="claude-opus-4",
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
cost_usd=0.01,
|
||||
)
|
||||
)
|
||||
|
||||
# Usage should be empty
|
||||
report = asyncio.run(
|
||||
disabled_tracker.get_project_usage("proj-123", period="day")
|
||||
)
|
||||
assert report.total_requests == 0
|
||||
|
||||
def test_record_usage_with_session(self, tracker: CostTracker) -> None:
|
||||
"""Test recording usage with session ID."""
|
||||
asyncio.run(
|
||||
tracker.record_usage(
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
model="claude-opus-4",
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
cost_usd=0.01,
|
||||
session_id="session-789",
|
||||
)
|
||||
)
|
||||
|
||||
# Verify session usage
|
||||
session_usage = asyncio.run(tracker.get_session_usage("session-789"))
|
||||
|
||||
assert session_usage["session_id"] == "session-789"
|
||||
assert session_usage["total_cost_usd"] == pytest.approx(0.01, rel=0.01)
|
||||
|
||||
def test_get_project_usage_empty(self, tracker: CostTracker) -> None:
|
||||
"""Test getting usage for project with no data."""
|
||||
report = asyncio.run(
|
||||
tracker.get_project_usage("nonexistent-project", period="day")
|
||||
)
|
||||
|
||||
assert report.entity_id == "nonexistent-project"
|
||||
assert report.entity_type == "project"
|
||||
assert report.total_requests == 0
|
||||
assert report.total_cost_usd == 0.0
|
||||
|
||||
def test_get_project_usage_multiple_models(self, tracker: CostTracker) -> None:
|
||||
"""Test usage tracking across multiple models."""
|
||||
# Record usage for different models
|
||||
asyncio.run(
|
||||
tracker.record_usage(
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
model="claude-opus-4",
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
cost_usd=0.01,
|
||||
)
|
||||
)
|
||||
asyncio.run(
|
||||
tracker.record_usage(
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
model="gpt-4.1",
|
||||
prompt_tokens=200,
|
||||
completion_tokens=100,
|
||||
cost_usd=0.02,
|
||||
)
|
||||
)
|
||||
|
||||
report = asyncio.run(tracker.get_project_usage("proj-123", period="day"))
|
||||
|
||||
assert report.total_requests == 2
|
||||
assert len(report.by_model) == 2
|
||||
assert "claude-opus-4" in report.by_model
|
||||
assert "gpt-4.1" in report.by_model
|
||||
|
||||
def test_get_agent_usage(self, tracker: CostTracker) -> None:
|
||||
"""Test getting agent usage."""
|
||||
asyncio.run(
|
||||
tracker.record_usage(
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
model="claude-opus-4",
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
cost_usd=0.01,
|
||||
)
|
||||
)
|
||||
|
||||
report = asyncio.run(tracker.get_agent_usage("agent-456", period="day"))
|
||||
|
||||
assert report.entity_id == "agent-456"
|
||||
assert report.entity_type == "agent"
|
||||
assert report.total_requests == 1
|
||||
|
||||
def test_usage_periods(self, tracker: CostTracker) -> None:
|
||||
"""Test different usage periods."""
|
||||
asyncio.run(
|
||||
tracker.record_usage(
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
model="claude-opus-4",
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
cost_usd=0.01,
|
||||
)
|
||||
)
|
||||
|
||||
# Check different periods
|
||||
hour_report = asyncio.run(tracker.get_project_usage("proj-123", period="hour"))
|
||||
day_report = asyncio.run(tracker.get_project_usage("proj-123", period="day"))
|
||||
month_report = asyncio.run(
|
||||
tracker.get_project_usage("proj-123", period="month")
|
||||
)
|
||||
|
||||
assert hour_report.period == "hour"
|
||||
assert day_report.period == "day"
|
||||
assert month_report.period == "month"
|
||||
|
||||
# All should have the same data
|
||||
assert hour_report.total_requests == 1
|
||||
assert day_report.total_requests == 1
|
||||
assert month_report.total_requests == 1
|
||||
|
||||
def test_check_budget_within(self, tracker: CostTracker) -> None:
|
||||
"""Test budget check when within limit."""
|
||||
# Record some usage
|
||||
asyncio.run(
|
||||
tracker.record_usage(
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
model="claude-opus-4",
|
||||
prompt_tokens=1000,
|
||||
completion_tokens=500,
|
||||
cost_usd=50.0,
|
||||
)
|
||||
)
|
||||
|
||||
within, current, limit = asyncio.run(
|
||||
tracker.check_budget("proj-123", budget_limit=100.0)
|
||||
)
|
||||
|
||||
assert within is True
|
||||
assert current == pytest.approx(50.0, rel=0.01)
|
||||
assert limit == 100.0
|
||||
|
||||
def test_check_budget_exceeded(self, tracker: CostTracker) -> None:
|
||||
"""Test budget check when exceeded."""
|
||||
asyncio.run(
|
||||
tracker.record_usage(
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
model="claude-opus-4",
|
||||
prompt_tokens=1000,
|
||||
completion_tokens=500,
|
||||
cost_usd=150.0,
|
||||
)
|
||||
)
|
||||
|
||||
within, current, limit = asyncio.run(
|
||||
tracker.check_budget("proj-123", budget_limit=100.0)
|
||||
)
|
||||
|
||||
assert within is False
|
||||
assert current >= limit
|
||||
|
||||
def test_check_budget_default_limit(self, tracker: CostTracker) -> None:
|
||||
"""Test budget check with default limit."""
|
||||
within, current, limit = asyncio.run(tracker.check_budget("proj-123"))
|
||||
|
||||
assert limit == 1000.0 # Default from settings
|
||||
|
||||
def test_estimate_request_cost_known_model(self, tracker: CostTracker) -> None:
|
||||
"""Test estimating cost for known model."""
|
||||
cost = asyncio.run(
|
||||
tracker.estimate_request_cost(
|
||||
model="claude-opus-4",
|
||||
prompt_tokens=1000,
|
||||
max_completion_tokens=500,
|
||||
)
|
||||
)
|
||||
|
||||
# 1000/1M * 15 + 500/1M * 75 = 0.015 + 0.0375 = 0.0525
|
||||
assert cost == pytest.approx(0.0525, rel=0.01)
|
||||
|
||||
def test_estimate_request_cost_unknown_model(self, tracker: CostTracker) -> None:
|
||||
"""Test estimating cost for unknown model."""
|
||||
cost = asyncio.run(
|
||||
tracker.estimate_request_cost(
|
||||
model="unknown-model",
|
||||
prompt_tokens=1000,
|
||||
max_completion_tokens=500,
|
||||
)
|
||||
)
|
||||
|
||||
# Uses fallback estimate
|
||||
assert cost > 0
|
||||
|
||||
def test_should_alert_below_threshold(self, tracker: CostTracker) -> None:
|
||||
"""Test alert check when below threshold."""
|
||||
asyncio.run(
|
||||
tracker.record_usage(
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
model="claude-opus-4",
|
||||
prompt_tokens=1000,
|
||||
completion_tokens=500,
|
||||
cost_usd=50.0,
|
||||
)
|
||||
)
|
||||
|
||||
should_alert, current = asyncio.run(
|
||||
tracker.should_alert("proj-123", threshold=100.0)
|
||||
)
|
||||
|
||||
assert should_alert is False
|
||||
assert current == pytest.approx(50.0, rel=0.01)
|
||||
|
||||
def test_should_alert_above_threshold(self, tracker: CostTracker) -> None:
|
||||
"""Test alert check when above threshold."""
|
||||
asyncio.run(
|
||||
tracker.record_usage(
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
model="claude-opus-4",
|
||||
prompt_tokens=1000,
|
||||
completion_tokens=500,
|
||||
cost_usd=150.0,
|
||||
)
|
||||
)
|
||||
|
||||
should_alert, current = asyncio.run(
|
||||
tracker.should_alert("proj-123", threshold=100.0)
|
||||
)
|
||||
|
||||
assert should_alert is True
|
||||
|
||||
def test_close(self, tracker: CostTracker) -> None:
|
||||
"""Test closing tracker."""
|
||||
asyncio.run(tracker.close())
|
||||
assert tracker._redis is None
|
||||
|
||||
|
||||
class TestGlobalTracker:
|
||||
"""Tests for global tracker functions."""
|
||||
|
||||
def test_get_cost_tracker(self) -> None:
|
||||
"""Test getting global tracker."""
|
||||
reset_cost_tracker()
|
||||
tracker = get_cost_tracker()
|
||||
assert isinstance(tracker, CostTracker)
|
||||
|
||||
def test_get_cost_tracker_singleton(self) -> None:
|
||||
"""Test tracker is singleton."""
|
||||
reset_cost_tracker()
|
||||
tracker1 = get_cost_tracker()
|
||||
tracker2 = get_cost_tracker()
|
||||
assert tracker1 is tracker2
|
||||
|
||||
def test_reset_cost_tracker(self) -> None:
|
||||
"""Test resetting global tracker."""
|
||||
reset_cost_tracker()
|
||||
tracker1 = get_cost_tracker()
|
||||
reset_cost_tracker()
|
||||
tracker2 = get_cost_tracker()
|
||||
assert tracker1 is not tracker2
|
||||
|
||||
def test_close_cost_tracker(self) -> None:
|
||||
"""Test closing global tracker."""
|
||||
reset_cost_tracker()
|
||||
_ = get_cost_tracker()
|
||||
asyncio.run(close_cost_tracker())
|
||||
# Getting again should create a new one
|
||||
tracker2 = get_cost_tracker()
|
||||
assert tracker2 is not None
|
||||
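A plausible caller-side guard built on the check_budget contract exercised above; this wrapper is illustrative and not part of the commit:

from exceptions import CostLimitExceededError

async def ensure_within_budget(tracker: CostTracker, project_id: str, limit: float) -> None:
    within, current, budget = await tracker.check_budget(project_id, budget_limit=limit)
    if not within:
        raise CostLimitExceededError(
            entity_type="project",
            entity_id=project_id,
            current_cost=current,
            limit=budget,
        )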
376
mcp-servers/llm-gateway/tests/test_exceptions.py
Normal file
@@ -0,0 +1,376 @@
|
||||
"""
|
||||
Tests for exceptions module.
|
||||
"""
|
||||
|
||||
from exceptions import (
|
||||
AllProvidersFailedError,
|
||||
CircuitOpenError,
|
||||
ConfigurationError,
|
||||
ContextTooLongError,
|
||||
CostLimitExceededError,
|
||||
ErrorCode,
|
||||
InvalidModelError,
|
||||
InvalidModelGroupError,
|
||||
LLMGatewayError,
|
||||
ModelNotAvailableError,
|
||||
ProviderError,
|
||||
RateLimitError,
|
||||
StreamError,
|
||||
TokenLimitExceededError,
|
||||
)
|
||||
|
||||
|
||||
class TestErrorCode:
|
||||
"""Tests for ErrorCode enum."""
|
||||
|
||||
def test_error_code_values(self) -> None:
|
||||
"""Test error code values."""
|
||||
assert ErrorCode.UNKNOWN_ERROR.value == "LLM_UNKNOWN_ERROR"
|
||||
assert ErrorCode.PROVIDER_ERROR.value == "LLM_PROVIDER_ERROR"
|
||||
assert ErrorCode.CIRCUIT_OPEN.value == "LLM_CIRCUIT_OPEN"
|
||||
assert ErrorCode.COST_LIMIT_EXCEEDED.value == "LLM_COST_LIMIT_EXCEEDED"
|
||||
|
||||
|
||||
class TestLLMGatewayError:
|
||||
"""Tests for LLMGatewayError base class."""
|
||||
|
||||
def test_basic_error(self) -> None:
|
||||
"""Test basic error creation."""
|
||||
error = LLMGatewayError("Something went wrong")
|
||||
assert str(error) == "[LLM_UNKNOWN_ERROR] Something went wrong"
|
||||
assert error.message == "Something went wrong"
|
||||
assert error.code == ErrorCode.UNKNOWN_ERROR
|
||||
assert error.details == {}
|
||||
assert error.cause is None
|
||||
|
||||
def test_error_with_code(self) -> None:
|
||||
"""Test error with custom code."""
|
||||
error = LLMGatewayError(
|
||||
"Provider failed",
|
||||
code=ErrorCode.PROVIDER_ERROR,
|
||||
)
|
||||
assert error.code == ErrorCode.PROVIDER_ERROR
|
||||
|
||||
def test_error_with_details(self) -> None:
|
||||
"""Test error with details."""
|
||||
error = LLMGatewayError(
|
||||
"Error",
|
||||
details={"key": "value"},
|
||||
)
|
||||
assert error.details == {"key": "value"}
|
||||
|
||||
def test_error_with_cause(self) -> None:
|
||||
"""Test error with cause exception."""
|
||||
cause = ValueError("Original error")
|
||||
error = LLMGatewayError("Wrapped error", cause=cause)
|
||||
assert error.cause is cause
|
||||
|
||||
def test_to_dict(self) -> None:
|
||||
"""Test converting error to dict."""
|
||||
error = LLMGatewayError(
|
||||
"Test error",
|
||||
code=ErrorCode.INVALID_REQUEST,
|
||||
details={"field": "value"},
|
||||
)
|
||||
result = error.to_dict()
|
||||
|
||||
assert result["error"] == "LLM_INVALID_REQUEST"
|
||||
assert result["message"] == "Test error"
|
||||
assert result["details"] == {"field": "value"}
|
||||
|
||||
def test_to_dict_no_details(self) -> None:
|
||||
"""Test to_dict without details."""
|
||||
error = LLMGatewayError("Test error")
|
||||
result = error.to_dict()
|
||||
|
||||
assert "details" not in result
|
||||
|
||||
def test_repr(self) -> None:
|
||||
"""Test error repr."""
|
||||
error = LLMGatewayError("Test", details={"key": "val"})
|
||||
repr_str = repr(error)
|
||||
|
||||
assert "LLMGatewayError" in repr_str
|
||||
assert "Test" in repr_str
|
||||
|
||||
|
||||
class TestProviderError:
|
||||
"""Tests for ProviderError."""
|
||||
|
||||
def test_basic_provider_error(self) -> None:
|
||||
"""Test basic provider error."""
|
||||
error = ProviderError(
|
||||
message="API call failed",
|
||||
provider="anthropic",
|
||||
)
|
||||
|
||||
assert error.provider == "anthropic"
|
||||
assert error.model is None
|
||||
assert error.status_code is None
|
||||
assert error.code == ErrorCode.PROVIDER_ERROR
|
||||
assert "provider" in error.details
|
||||
|
||||
def test_provider_error_with_model(self) -> None:
|
||||
"""Test provider error with model info."""
|
||||
error = ProviderError(
|
||||
message="Model not found",
|
||||
provider="openai",
|
||||
model="gpt-5",
|
||||
status_code=404,
|
||||
)
|
||||
|
||||
assert error.model == "gpt-5"
|
||||
assert error.status_code == 404
|
||||
assert error.details["model"] == "gpt-5"
|
||||
assert error.details["status_code"] == 404
|
||||
|
||||
def test_provider_error_with_cause(self) -> None:
|
||||
"""Test provider error with cause."""
|
||||
cause = ConnectionError("Network down")
|
||||
error = ProviderError(
|
||||
message="Connection failed",
|
||||
provider="google",
|
||||
cause=cause,
|
||||
)
|
||||
|
||||
assert error.cause is cause
|
||||
|
||||
|
||||
class TestRateLimitError:
|
||||
"""Tests for RateLimitError."""
|
||||
|
||||
def test_internal_rate_limit(self) -> None:
|
||||
"""Test internal rate limit error."""
|
||||
error = RateLimitError(
|
||||
message="Too many requests",
|
||||
retry_after=60,
|
||||
)
|
||||
|
||||
assert error.code == ErrorCode.RATE_LIMIT_EXCEEDED
|
||||
assert error.provider is None
|
||||
assert error.retry_after == 60
|
||||
assert error.details["retry_after_seconds"] == 60
|
||||
|
||||
def test_provider_rate_limit(self) -> None:
|
||||
"""Test provider rate limit error."""
|
||||
error = RateLimitError(
|
||||
message="OpenAI rate limit",
|
||||
provider="openai",
|
||||
retry_after=30,
|
||||
)
|
||||
|
||||
assert error.code == ErrorCode.PROVIDER_RATE_LIMIT
|
||||
assert error.provider == "openai"
|
||||
assert error.details["provider"] == "openai"
|
||||
|
||||
|
||||
class TestCircuitOpenError:
|
||||
"""Tests for CircuitOpenError."""
|
||||
|
||||
def test_circuit_open_error(self) -> None:
|
||||
"""Test circuit open error."""
|
||||
error = CircuitOpenError(
|
||||
provider="anthropic",
|
||||
recovery_time=45,
|
||||
)
|
||||
|
||||
assert error.provider == "anthropic"
|
||||
assert error.recovery_time == 45
|
||||
assert error.code == ErrorCode.CIRCUIT_OPEN
|
||||
assert "Circuit breaker open" in error.message
|
||||
assert error.details["recovery_time_seconds"] == 45
|
||||
|
||||
def test_circuit_open_no_recovery_time(self) -> None:
|
||||
"""Test circuit open without recovery time."""
|
||||
error = CircuitOpenError(provider="openai")
|
||||
|
||||
assert error.recovery_time is None
|
||||
assert "recovery_time_seconds" not in error.details
|
||||
|
||||
|
||||
class TestCostLimitExceededError:
|
||||
"""Tests for CostLimitExceededError."""
|
||||
|
||||
def test_project_cost_limit(self) -> None:
|
||||
"""Test project cost limit error."""
|
||||
error = CostLimitExceededError(
|
||||
entity_type="project",
|
||||
entity_id="proj-123",
|
||||
current_cost=150.0,
|
||||
limit=100.0,
|
||||
)
|
||||
|
||||
assert error.entity_type == "project"
|
||||
assert error.entity_id == "proj-123"
|
||||
assert error.current_cost == 150.0
|
||||
assert error.limit == 100.0
|
||||
assert error.code == ErrorCode.COST_LIMIT_EXCEEDED
|
||||
assert "$150.00" in error.message
|
||||
assert "$100.00" in error.message
|
||||
|
||||
def test_agent_cost_limit(self) -> None:
|
||||
"""Test agent cost limit error."""
|
||||
error = CostLimitExceededError(
|
||||
entity_type="agent",
|
||||
entity_id="agent-456",
|
||||
current_cost=50.0,
|
||||
limit=25.0,
|
||||
)
|
||||
|
||||
assert error.entity_type == "agent"
|
||||
assert error.details["entity_type"] == "agent"
|
||||
|
||||
|
||||
class TestInvalidModelGroupError:
|
||||
"""Tests for InvalidModelGroupError."""
|
||||
|
||||
def test_invalid_group_error(self) -> None:
|
||||
"""Test invalid model group error."""
|
||||
error = InvalidModelGroupError(
|
||||
model_group="invalid_group",
|
||||
available_groups=["reasoning", "code", "fast"],
|
||||
)
|
||||
|
||||
assert error.model_group == "invalid_group"
|
||||
assert error.available_groups == ["reasoning", "code", "fast"]
|
||||
assert error.code == ErrorCode.INVALID_MODEL_GROUP
|
||||
assert "invalid_group" in error.message
|
||||
|
||||
def test_invalid_group_no_available(self) -> None:
|
||||
"""Test invalid group without available list."""
|
||||
error = InvalidModelGroupError(model_group="unknown")
|
||||
|
||||
assert error.available_groups is None
|
||||
assert "available_groups" not in error.details
|
||||
|
||||
|
||||
class TestInvalidModelError:
|
||||
"""Tests for InvalidModelError."""
|
||||
|
||||
def test_invalid_model_error(self) -> None:
|
||||
"""Test invalid model error."""
|
||||
error = InvalidModelError(
|
||||
model="gpt-99",
|
||||
reason="Model does not exist",
|
||||
)
|
||||
|
||||
assert error.model == "gpt-99"
|
||||
assert error.code == ErrorCode.INVALID_MODEL
|
||||
assert "gpt-99" in error.message
|
||||
assert "Model does not exist" in error.message
|
||||
|
||||
def test_invalid_model_no_reason(self) -> None:
|
||||
"""Test invalid model without reason."""
|
||||
error = InvalidModelError(model="unknown-model")
|
||||
|
||||
assert "reason" not in error.details
|
||||
|
||||
|
||||
class TestModelNotAvailableError:
|
||||
"""Tests for ModelNotAvailableError."""
|
||||
|
||||
def test_model_not_available(self) -> None:
|
||||
"""Test model not available error."""
|
||||
error = ModelNotAvailableError(
|
||||
model="claude-opus-4",
|
||||
provider="anthropic",
|
||||
)
|
||||
|
||||
assert error.model == "claude-opus-4"
|
||||
assert error.provider == "anthropic"
|
||||
assert error.code == ErrorCode.MODEL_NOT_AVAILABLE
|
||||
assert "not configured" in error.message
|
||||
|
||||
|
||||
class TestAllProvidersFailedError:
|
||||
"""Tests for AllProvidersFailedError."""
|
||||
|
||||
def test_all_providers_failed(self) -> None:
|
||||
"""Test all providers failed error."""
|
||||
errors = [
|
||||
{"model": "claude-opus-4", "error": "Rate limited"},
|
||||
{"model": "gpt-4.1", "error": "Timeout"},
|
||||
]
|
||||
error = AllProvidersFailedError(
|
||||
model_group="reasoning",
|
||||
attempted_models=["claude-opus-4", "gpt-4.1"],
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
assert error.model_group == "reasoning"
|
||||
assert error.attempted_models == ["claude-opus-4", "gpt-4.1"]
|
||||
assert error.errors == errors
|
||||
assert error.code == ErrorCode.ALL_PROVIDERS_FAILED
|
||||
|
||||
|
||||
class TestStreamError:
|
||||
"""Tests for StreamError."""
|
||||
|
||||
def test_stream_error(self) -> None:
|
||||
"""Test stream error."""
|
||||
cause = OSError("Connection reset")
|
||||
error = StreamError(
|
||||
message="Stream interrupted",
|
||||
chunks_received=10,
|
||||
cause=cause,
|
||||
)
|
||||
|
||||
assert error.chunks_received == 10
|
||||
assert error.cause is cause
|
||||
assert error.code == ErrorCode.STREAM_ERROR
|
||||
|
||||
|
||||
class TestTokenLimitExceededError:
|
||||
"""Tests for TokenLimitExceededError."""
|
||||
|
||||
def test_token_limit_exceeded(self) -> None:
|
||||
"""Test token limit exceeded error."""
|
||||
error = TokenLimitExceededError(
|
||||
model="claude-haiku",
|
||||
token_count=10000,
|
||||
limit=8192,
|
||||
)
|
||||
|
||||
assert error.model == "claude-haiku"
|
||||
assert error.token_count == 10000
|
||||
assert error.limit == 8192
|
||||
assert error.code == ErrorCode.TOKEN_LIMIT_EXCEEDED
|
||||
|
||||
|
||||
class TestContextTooLongError:
|
||||
"""Tests for ContextTooLongError."""
|
||||
|
||||
def test_context_too_long(self) -> None:
|
||||
"""Test context too long error."""
|
||||
error = ContextTooLongError(
|
||||
model="gpt-4.1-mini",
|
||||
context_length=150000,
|
||||
max_context=100000,
|
||||
)
|
||||
|
||||
assert error.model == "gpt-4.1-mini"
|
||||
assert error.context_length == 150000
|
||||
assert error.max_context == 100000
|
||||
assert error.code == ErrorCode.CONTEXT_TOO_LONG
|
||||
|
||||
|
||||
class TestConfigurationError:
|
||||
"""Tests for ConfigurationError."""
|
||||
|
||||
def test_configuration_error(self) -> None:
|
||||
"""Test configuration error."""
|
||||
error = ConfigurationError(
|
||||
message="Missing API key",
|
||||
config_key="ANTHROPIC_API_KEY",
|
||||
)
|
||||
|
||||
assert error.config_key == "ANTHROPIC_API_KEY"
|
||||
assert error.code == ErrorCode.CONFIGURATION_ERROR
|
||||
assert error.details["config_key"] == "ANTHROPIC_API_KEY"
|
||||
|
||||
def test_configuration_error_no_key(self) -> None:
|
||||
"""Test configuration error without key."""
|
||||
error = ConfigurationError(message="Invalid configuration")
|
||||
|
||||
assert error.config_key is None
|
||||
assert "config_key" not in error.details
|
||||
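One way a transport layer might surface these errors using the to_dict() contract tested above; the status-code mapping here is an illustrative choice, not something defined in this commit:

from exceptions import LLMGatewayError, RateLimitError

def error_payload(exc: LLMGatewayError) -> tuple[int, dict]:
    status = 429 if isinstance(exc, RateLimitError) else 502  # illustrative mapping
    return status, exc.to_dict()

print(error_payload(RateLimitError(message="Too many requests", retry_after=60)))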
407
mcp-servers/llm-gateway/tests/test_failover.py
Normal file
@@ -0,0 +1,407 @@
|
||||
"""
|
||||
Tests for failover module (circuit breaker).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from config import Settings
|
||||
from exceptions import CircuitOpenError
|
||||
from failover import (
|
||||
CircuitBreaker,
|
||||
CircuitBreakerRegistry,
|
||||
CircuitState,
|
||||
CircuitStats,
|
||||
get_circuit_registry,
|
||||
reset_circuit_registry,
|
||||
)
|
||||
|
||||
|
||||
class TestCircuitState:
|
||||
"""Tests for CircuitState enum."""
|
||||
|
||||
def test_circuit_states(self) -> None:
|
||||
"""Test circuit state values."""
|
||||
assert CircuitState.CLOSED.value == "closed"
|
||||
assert CircuitState.OPEN.value == "open"
|
||||
assert CircuitState.HALF_OPEN.value == "half_open"
|
||||
|
||||
|
||||
class TestCircuitStats:
|
||||
"""Tests for CircuitStats dataclass."""
|
||||
|
||||
def test_default_stats(self) -> None:
|
||||
"""Test default stats values."""
|
||||
stats = CircuitStats()
|
||||
assert stats.failures == 0
|
||||
assert stats.successes == 0
|
||||
assert stats.last_failure_time is None
|
||||
assert stats.last_success_time is None
|
||||
assert stats.half_open_calls == 0
|
||||
|
||||
|
||||
class TestCircuitBreaker:
|
||||
"""Tests for CircuitBreaker class."""
|
||||
|
||||
def test_initial_state(self) -> None:
|
||||
"""Test circuit breaker initial state."""
|
||||
cb = CircuitBreaker(name="test", failure_threshold=5)
|
||||
|
||||
assert cb.name == "test"
|
||||
assert cb.state == CircuitState.CLOSED
|
||||
assert cb.failure_threshold == 5
|
||||
assert cb.is_available() is True
|
||||
|
||||
def test_state_remains_closed_below_threshold(self) -> None:
|
||||
"""Test circuit stays closed below failure threshold."""
|
||||
cb = CircuitBreaker(name="test", failure_threshold=3)
|
||||
|
||||
# Record 2 failures (below threshold)
|
||||
asyncio.run(cb.record_failure())
|
||||
asyncio.run(cb.record_failure())
|
||||
|
||||
assert cb.state == CircuitState.CLOSED
|
||||
assert cb.stats.failures == 2
|
||||
assert cb.is_available() is True
|
||||
|
||||
def test_state_opens_at_threshold(self) -> None:
|
||||
"""Test circuit opens at failure threshold."""
|
||||
cb = CircuitBreaker(name="test", failure_threshold=3)
|
||||
|
||||
# Record 3 failures (at threshold)
|
||||
asyncio.run(cb.record_failure())
|
||||
asyncio.run(cb.record_failure())
|
||||
asyncio.run(cb.record_failure())
|
||||
|
||||
assert cb.state == CircuitState.OPEN
|
||||
assert cb.is_available() is False
|
||||
|
||||
def test_success_resets_in_closed(self) -> None:
|
||||
"""Test success in closed state records properly."""
|
||||
cb = CircuitBreaker(name="test", failure_threshold=3)
|
||||
|
||||
asyncio.run(cb.record_failure())
|
||||
asyncio.run(cb.record_success())
|
||||
|
||||
assert cb.state == CircuitState.CLOSED
|
||||
assert cb.stats.successes == 1
|
||||
assert cb.stats.last_success_time is not None
|
||||
|
||||
def test_half_open_transition(self) -> None:
|
||||
"""Test transition to half-open after recovery timeout."""
|
||||
cb = CircuitBreaker(
|
||||
name="test",
|
||||
failure_threshold=1,
|
||||
recovery_timeout=1, # 1 second
|
||||
)
|
||||
|
||||
# Open the circuit
|
||||
asyncio.run(cb.record_failure())
|
||||
assert cb.state == CircuitState.OPEN
|
||||
|
||||
# Wait for recovery timeout
|
||||
time.sleep(1.1)
|
||||
|
||||
# State should transition to half-open
|
||||
assert cb.state == CircuitState.HALF_OPEN
|
||||
assert cb.is_available() is True
|
||||
|
||||
def test_half_open_success_closes(self) -> None:
|
||||
"""Test success in half-open closes circuit."""
|
||||
cb = CircuitBreaker(
|
||||
name="test",
|
||||
failure_threshold=1,
|
||||
recovery_timeout=0, # Immediate recovery for testing
|
||||
)
|
||||
|
||||
# Open and transition to half-open
|
||||
asyncio.run(cb.record_failure())
|
||||
time.sleep(0.1)
|
||||
_ = cb.state # Trigger state check
|
||||
|
||||
assert cb.state == CircuitState.HALF_OPEN
|
||||
|
||||
# Success should close
|
||||
asyncio.run(cb.record_success())
|
||||
assert cb.state == CircuitState.CLOSED
|
||||
|
||||
def test_half_open_failure_reopens(self) -> None:
|
||||
"""Test failure in half-open reopens circuit."""
|
||||
cb = CircuitBreaker(
|
||||
name="test",
|
||||
failure_threshold=1,
|
||||
recovery_timeout=0.05, # Small but non-zero for reliable timing
|
||||
)
|
||||
|
||||
# Open and transition to half-open
|
||||
asyncio.run(cb.record_failure())
|
||||
assert cb.state == CircuitState.OPEN
|
||||
|
||||
# Wait for recovery timeout
|
||||
time.sleep(0.1)
|
||||
assert cb.state == CircuitState.HALF_OPEN
|
||||
|
||||
# Failure should reopen
|
||||
asyncio.run(cb.record_failure())
|
||||
assert cb.state == CircuitState.OPEN
|
||||
|
||||
def test_half_open_call_limit(self) -> None:
|
||||
"""Test half-open call limit."""
|
||||
cb = CircuitBreaker(
|
||||
name="test",
|
||||
failure_threshold=1,
|
||||
recovery_timeout=0,
|
||||
half_open_max_calls=2,
|
||||
)
|
||||
|
||||
# Open and transition to half-open
|
||||
asyncio.run(cb.record_failure())
|
||||
time.sleep(0.1)
|
||||
_ = cb.state
|
||||
|
||||
assert cb.is_available() is True
|
||||
|
||||
# Simulate calls in half-open
|
||||
cb._stats.half_open_calls = 1
|
||||
assert cb.is_available() is True
|
||||
|
||||
cb._stats.half_open_calls = 2
|
||||
assert cb.is_available() is False
|
||||
|
||||
def test_time_until_recovery(self) -> None:
|
||||
"""Test time until recovery calculation."""
|
||||
cb = CircuitBreaker(
|
||||
name="test",
|
||||
failure_threshold=1,
|
||||
recovery_timeout=60,
|
||||
)
|
||||
|
||||
# Closed circuit has no recovery time
|
||||
assert cb.time_until_recovery() is None
|
||||
|
||||
# Open circuit
|
||||
asyncio.run(cb.record_failure())
|
||||
assert cb.state == CircuitState.OPEN
|
||||
|
||||
# Should have recovery time
|
||||
remaining = cb.time_until_recovery()
|
||||
assert remaining is not None
|
||||
assert 0 <= remaining <= 60
|
||||
|
||||
def test_execute_success(self) -> None:
|
||||
"""Test execute with successful function."""
|
||||
cb = CircuitBreaker(name="test", failure_threshold=3)
|
||||
|
||||
async def success_func() -> str:
|
||||
return "success"
|
||||
|
||||
result = asyncio.run(cb.execute(success_func))
|
||||
assert result == "success"
|
||||
assert cb.stats.successes == 1
|
||||
|
||||
def test_execute_failure(self) -> None:
|
||||
"""Test execute with failing function."""
|
||||
cb = CircuitBreaker(name="test", failure_threshold=3)
|
||||
|
||||
async def fail_func() -> None:
|
||||
raise ValueError("Error")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
asyncio.run(cb.execute(fail_func))
|
||||
|
||||
assert cb.stats.failures == 1
|
||||
|
||||
def test_execute_when_open(self) -> None:
|
||||
"""Test execute raises when circuit is open."""
|
||||
cb = CircuitBreaker(name="test", failure_threshold=1)
|
||||
|
||||
# Open the circuit
|
||||
asyncio.run(cb.record_failure())
|
||||
assert cb.state == CircuitState.OPEN
|
||||
|
||||
async def success_func() -> str:
|
||||
return "success"
|
||||
|
||||
with pytest.raises(CircuitOpenError) as exc_info:
|
||||
asyncio.run(cb.execute(success_func))
|
||||
|
||||
assert exc_info.value.provider == "test"
|
||||
|
||||
def test_reset(self) -> None:
|
||||
"""Test circuit reset."""
|
||||
cb = CircuitBreaker(name="test", failure_threshold=1)
|
||||
|
||||
# Open the circuit
|
||||
asyncio.run(cb.record_failure())
|
||||
assert cb.state == CircuitState.OPEN
|
||||
|
||||
# Reset
|
||||
cb.reset()
|
||||
assert cb.state == CircuitState.CLOSED
|
||||
assert cb.stats.failures == 0
|
||||
assert cb.stats.successes == 0
|
||||
|
||||
def test_to_dict(self) -> None:
|
||||
"""Test converting circuit to dict."""
|
||||
cb = CircuitBreaker(name="test", failure_threshold=3)
|
||||
asyncio.run(cb.record_failure())
|
||||
asyncio.run(cb.record_success())
|
||||
|
||||
result = cb.to_dict()
|
||||
|
||||
assert result["name"] == "test"
|
||||
assert result["state"] == "closed"
|
||||
assert result["failures"] == 1
|
||||
assert result["successes"] == 1
|
||||
assert result["is_available"] is True
|
||||
|
||||
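Summarizing the state machine these tests pin down (comments only; no new behavior):

# CLOSED     -- failures reach failure_threshold -->        OPEN
# OPEN       -- recovery_timeout elapses -->                HALF_OPEN
# HALF_OPEN  -- record_success() -->                        CLOSED
# HALF_OPEN  -- record_failure() -->                        OPEN
# HALF_OPEN  -- half_open_calls >= half_open_max_calls -->  is_available() returns False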
|
||||
class TestCircuitBreakerRegistry:
|
||||
"""Tests for CircuitBreakerRegistry class."""
|
||||
|
||||
def test_get_circuit_creates_new(self) -> None:
|
||||
"""Test getting a new circuit."""
|
||||
settings = Settings(circuit_failure_threshold=5)
|
||||
registry = CircuitBreakerRegistry(settings=settings)
|
||||
|
||||
circuit = asyncio.run(registry.get_circuit("anthropic"))
|
||||
|
||||
assert circuit.name == "anthropic"
|
||||
assert circuit.failure_threshold == 5
|
||||
|
||||
def test_get_circuit_returns_same(self) -> None:
|
||||
"""Test getting same circuit twice."""
|
||||
registry = CircuitBreakerRegistry()
|
||||
|
||||
circuit1 = asyncio.run(registry.get_circuit("openai"))
|
||||
circuit2 = asyncio.run(registry.get_circuit("openai"))
|
||||
|
||||
assert circuit1 is circuit2
|
||||
|
||||
def test_get_circuit_sync(self) -> None:
|
||||
"""Test sync circuit getter."""
|
||||
registry = CircuitBreakerRegistry()
|
||||
|
||||
circuit = registry.get_circuit_sync("google")
|
||||
assert circuit.name == "google"
|
||||
|
||||
def test_is_available(self) -> None:
|
||||
"""Test checking if circuit is available."""
|
||||
registry = CircuitBreakerRegistry()
|
||||
|
||||
assert asyncio.run(registry.is_available("test")) is True
|
||||
|
||||
# Open the circuit
|
||||
circuit = asyncio.run(registry.get_circuit("test"))
|
||||
for _ in range(5):
|
||||
asyncio.run(circuit.record_failure())
|
||||
|
||||
assert asyncio.run(registry.is_available("test")) is False
|
||||
|
||||
def test_record_success(self) -> None:
|
||||
"""Test recording success through registry."""
|
||||
registry = CircuitBreakerRegistry()
|
||||
|
||||
asyncio.run(registry.record_success("test"))
|
||||
|
||||
circuit = asyncio.run(registry.get_circuit("test"))
|
||||
assert circuit.stats.successes == 1
|
||||
|
||||
def test_record_failure(self) -> None:
|
||||
"""Test recording failure through registry."""
|
||||
registry = CircuitBreakerRegistry()
|
||||
|
||||
asyncio.run(registry.record_failure("test"))
|
||||
|
||||
circuit = asyncio.run(registry.get_circuit("test"))
|
||||
assert circuit.stats.failures == 1
|
||||
|
||||
def test_reset(self) -> None:
|
||||
"""Test resetting a specific circuit."""
|
||||
registry = CircuitBreakerRegistry()
|
||||
|
||||
# Create and fail a circuit
|
||||
asyncio.run(registry.record_failure("test"))
|
||||
asyncio.run(registry.reset("test"))
|
||||
|
||||
circuit = asyncio.run(registry.get_circuit("test"))
|
||||
assert circuit.stats.failures == 0
|
||||
|
||||
def test_reset_all(self) -> None:
|
||||
"""Test resetting all circuits."""
|
||||
registry = CircuitBreakerRegistry()
|
||||
|
||||
# Create multiple circuits with failures
|
||||
asyncio.run(registry.record_failure("circuit1"))
|
||||
asyncio.run(registry.record_failure("circuit2"))
|
||||
|
||||
asyncio.run(registry.reset_all())
|
||||
|
||||
circuit1 = asyncio.run(registry.get_circuit("circuit1"))
|
||||
circuit2 = asyncio.run(registry.get_circuit("circuit2"))
|
||||
assert circuit1.stats.failures == 0
|
||||
assert circuit2.stats.failures == 0
|
||||
|
||||
def test_get_all_states(self) -> None:
|
||||
"""Test getting all circuit states."""
|
||||
registry = CircuitBreakerRegistry()
|
||||
|
||||
asyncio.run(registry.get_circuit("circuit1"))
|
||||
asyncio.run(registry.get_circuit("circuit2"))
|
||||
|
||||
states = registry.get_all_states()
|
||||
|
||||
assert "circuit1" in states
|
||||
assert "circuit2" in states
|
||||
assert states["circuit1"]["state"] == "closed"
|
||||
|
||||
def test_get_open_circuits(self) -> None:
|
||||
"""Test getting open circuits."""
|
||||
settings = Settings(circuit_failure_threshold=1)
|
||||
registry = CircuitBreakerRegistry(settings=settings)
|
||||
|
||||
asyncio.run(registry.get_circuit("healthy"))
|
||||
asyncio.run(registry.record_failure("failing"))
|
||||
|
||||
open_circuits = registry.get_open_circuits()
|
||||
assert "failing" in open_circuits
|
||||
assert "healthy" not in open_circuits
|
||||
|
||||
def test_get_available_circuits(self) -> None:
|
||||
"""Test getting available circuits."""
|
||||
settings = Settings(circuit_failure_threshold=1)
|
||||
registry = CircuitBreakerRegistry(settings=settings)
|
||||
|
||||
asyncio.run(registry.get_circuit("healthy"))
|
||||
asyncio.run(registry.record_failure("failing"))
|
||||
|
||||
available = registry.get_available_circuits()
|
||||
assert "healthy" in available
|
||||
assert "failing" not in available
|
||||
|
||||
|
||||
class TestGlobalRegistry:
|
||||
"""Tests for global registry functions."""
|
||||
|
||||
def test_get_circuit_registry(self) -> None:
|
||||
"""Test getting global registry."""
|
||||
reset_circuit_registry()
|
||||
registry = get_circuit_registry()
|
||||
assert isinstance(registry, CircuitBreakerRegistry)
|
||||
|
||||
def test_get_circuit_registry_singleton(self) -> None:
|
||||
"""Test registry is singleton."""
|
||||
reset_circuit_registry()
|
||||
registry1 = get_circuit_registry()
|
||||
registry2 = get_circuit_registry()
|
||||
assert registry1 is registry2
|
||||
|
||||
def test_reset_circuit_registry(self) -> None:
|
||||
"""Test resetting global registry."""
|
||||
reset_circuit_registry()
|
||||
registry1 = get_circuit_registry()
|
||||
reset_circuit_registry()
|
||||
registry2 = get_circuit_registry()
|
||||
assert registry1 is not registry2
|
||||
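The tests above drive the circuit breaker one call at a time through asyncio.run. As a rough illustration of how the gateway is expected to consume the registry around a provider call, here is a minimal sketch; it only uses the registry methods exercised above (is_available, record_failure, record_success), while call_provider is a hypothetical placeholder, not part of the repository:

# Hedged sketch only: `call_provider` is a hypothetical coroutine standing in for a real provider call.
from failover import get_circuit_registry


async def guarded_completion(provider_name: str, call_provider):
    """Run a provider call behind its circuit breaker."""
    registry = get_circuit_registry()

    # Refuse immediately if the provider's circuit is open.
    if not await registry.is_available(provider_name):
        raise RuntimeError(f"{provider_name} circuit is open")

    try:
        result = await call_provider()
    except Exception:
        # Failures push the circuit toward OPEN once the threshold is reached.
        await registry.record_failure(provider_name)
        raise

    await registry.record_success(provider_name)
    return result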
411
mcp-servers/llm-gateway/tests/test_models.py
Normal file
@@ -0,0 +1,411 @@
"""
Tests for models module.
"""

from datetime import UTC, datetime

import pytest

from models import (
    AGENT_TYPE_MODEL_PREFERENCES,
    MODEL_CONFIGS,
    MODEL_GROUPS,
    ChatMessage,
    CompletionRequest,
    CompletionResponse,
    CostRecord,
    EmbeddingRequest,
    ModelConfig,
    ModelGroup,
    ModelGroupConfig,
    ModelGroupInfo,
    ModelInfo,
    Provider,
    StreamChunk,
    UsageReport,
    UsageStats,
)


class TestModelGroup:
    """Tests for ModelGroup enum."""

    def test_model_group_values(self) -> None:
        """Test model group enum values."""
        assert ModelGroup.REASONING.value == "reasoning"
        assert ModelGroup.CODE.value == "code"
        assert ModelGroup.FAST.value == "fast"
        assert ModelGroup.VISION.value == "vision"
        assert ModelGroup.EMBEDDING.value == "embedding"
        assert ModelGroup.COST_OPTIMIZED.value == "cost_optimized"
        assert ModelGroup.SELF_HOSTED.value == "self_hosted"

    def test_model_group_from_string(self) -> None:
        """Test creating ModelGroup from string."""
        assert ModelGroup("reasoning") == ModelGroup.REASONING
        assert ModelGroup("code") == ModelGroup.CODE
        assert ModelGroup("fast") == ModelGroup.FAST

    def test_model_group_invalid(self) -> None:
        """Test invalid model group value."""
        with pytest.raises(ValueError):
            ModelGroup("invalid_group")


class TestProvider:
    """Tests for Provider enum."""

    def test_provider_values(self) -> None:
        """Test provider enum values."""
        assert Provider.ANTHROPIC.value == "anthropic"
        assert Provider.OPENAI.value == "openai"
        assert Provider.GOOGLE.value == "google"
        assert Provider.ALIBABA.value == "alibaba"
        assert Provider.DEEPSEEK.value == "deepseek"


class TestModelConfig:
    """Tests for ModelConfig dataclass."""

    def test_model_config_creation(self) -> None:
        """Test creating a ModelConfig."""
        config = ModelConfig(
            name="test-model",
            litellm_name="provider/test-model",
            provider=Provider.ANTHROPIC,
            cost_per_1m_input=10.0,
            cost_per_1m_output=30.0,
            context_window=100000,
            max_output_tokens=4096,
            supports_vision=True,
        )

        assert config.name == "test-model"
        assert config.provider == Provider.ANTHROPIC
        assert config.cost_per_1m_input == 10.0
        assert config.supports_vision is True
        assert config.supports_streaming is True  # default

    def test_model_configs_exist(self) -> None:
        """Test that model configs are defined."""
        assert len(MODEL_CONFIGS) > 0
        assert "claude-opus-4" in MODEL_CONFIGS
        assert "gpt-4.1" in MODEL_CONFIGS
        assert "gemini-2.5-pro" in MODEL_CONFIGS


class TestModelGroupConfig:
    """Tests for ModelGroupConfig dataclass."""

    def test_model_group_config_creation(self) -> None:
        """Test creating a ModelGroupConfig."""
        config = ModelGroupConfig(
            primary="model-a",
            fallbacks=["model-b", "model-c"],
            description="Test group",
        )

        assert config.primary == "model-a"
        assert config.fallbacks == ["model-b", "model-c"]
        assert config.description == "Test group"

    def test_get_all_models(self) -> None:
        """Test getting all models in order."""
        config = ModelGroupConfig(
            primary="model-a",
            fallbacks=["model-b", "model-c"],
            description="Test group",
        )

        models = config.get_all_models()
        assert models == ["model-a", "model-b", "model-c"]

    def test_model_groups_exist(self) -> None:
        """Test that model groups are defined."""
        assert len(MODEL_GROUPS) > 0
        assert ModelGroup.REASONING in MODEL_GROUPS
        assert ModelGroup.CODE in MODEL_GROUPS
        assert ModelGroup.FAST in MODEL_GROUPS


class TestAgentTypePreferences:
    """Tests for agent type model preferences."""

    def test_agent_preferences_exist(self) -> None:
        """Test that agent preferences are defined."""
        assert len(AGENT_TYPE_MODEL_PREFERENCES) > 0
        assert "product_owner" in AGENT_TYPE_MODEL_PREFERENCES
        assert "software_engineer" in AGENT_TYPE_MODEL_PREFERENCES

    def test_agent_preference_values(self) -> None:
        """Test agent preference values."""
        assert AGENT_TYPE_MODEL_PREFERENCES["product_owner"] == ModelGroup.REASONING
        assert AGENT_TYPE_MODEL_PREFERENCES["software_engineer"] == ModelGroup.CODE
        assert AGENT_TYPE_MODEL_PREFERENCES["devops_engineer"] == ModelGroup.FAST


class TestChatMessage:
    """Tests for ChatMessage model."""

    def test_chat_message_creation(self) -> None:
        """Test creating a ChatMessage."""
        msg = ChatMessage(role="user", content="Hello")
        assert msg.role == "user"
        assert msg.content == "Hello"
        assert msg.name is None
        assert msg.tool_calls is None

    def test_chat_message_with_optional(self) -> None:
        """Test ChatMessage with optional fields."""
        msg = ChatMessage(
            role="assistant",
            content="Response",
            name="assistant_1",
            tool_calls=[{"id": "call_1", "function": {"name": "test"}}],
        )
        assert msg.name == "assistant_1"
        assert msg.tool_calls is not None

    def test_chat_message_list_content(self) -> None:
        """Test ChatMessage with list content (for images)."""
        msg = ChatMessage(
            role="user",
            content=[
                {"type": "text", "text": "What's in this image?"},
                {
                    "type": "image_url",
                    "image_url": {"url": "http://example.com/img.jpg"},
                },
            ],
        )
        assert isinstance(msg.content, list)


class TestCompletionRequest:
    """Tests for CompletionRequest model."""

    def test_completion_request_minimal(self) -> None:
        """Test minimal CompletionRequest."""
        req = CompletionRequest(
            project_id="proj-123",
            agent_id="agent-456",
            messages=[ChatMessage(role="user", content="Hi")],
        )

        assert req.project_id == "proj-123"
        assert req.agent_id == "agent-456"
        assert len(req.messages) == 1
        assert req.model_group == ModelGroup.REASONING  # default
        assert req.max_tokens == 4096  # default
        assert req.temperature == 0.7  # default

    def test_completion_request_full(self) -> None:
        """Test full CompletionRequest."""
        req = CompletionRequest(
            project_id="proj-123",
            agent_id="agent-456",
            messages=[ChatMessage(role="user", content="Hi")],
            model_group=ModelGroup.CODE,
            model_override="claude-sonnet-4",
            max_tokens=8192,
            temperature=0.5,
            stream=True,
            session_id="session-789",
            metadata={"key": "value"},
        )

        assert req.model_group == ModelGroup.CODE
        assert req.model_override == "claude-sonnet-4"
        assert req.max_tokens == 8192
        assert req.stream is True

    def test_completion_request_validation(self) -> None:
        """Test CompletionRequest validation."""
        with pytest.raises(ValueError):
            CompletionRequest(
                project_id="proj-123",
                agent_id="agent-456",
                messages=[ChatMessage(role="user", content="Hi")],
                max_tokens=0,  # Invalid
            )

        with pytest.raises(ValueError):
            CompletionRequest(
                project_id="proj-123",
                agent_id="agent-456",
                messages=[ChatMessage(role="user", content="Hi")],
                temperature=-0.1,  # Invalid
            )


class TestUsageStats:
    """Tests for UsageStats model."""

    def test_usage_stats_default(self) -> None:
        """Test default UsageStats."""
        stats = UsageStats()
        assert stats.prompt_tokens == 0
        assert stats.completion_tokens == 0
        assert stats.total_tokens == 0
        assert stats.cost_usd == 0.0

    def test_usage_stats_custom(self) -> None:
        """Test custom UsageStats."""
        stats = UsageStats(
            prompt_tokens=100,
            completion_tokens=50,
            total_tokens=150,
            cost_usd=0.001,
        )
        assert stats.prompt_tokens == 100
        assert stats.total_tokens == 150

    def test_usage_stats_from_response(self) -> None:
        """Test creating UsageStats from response."""
        config = MODEL_CONFIGS["claude-opus-4"]
        stats = UsageStats.from_response(
            prompt_tokens=1000,
            completion_tokens=500,
            model_config=config,
        )

        assert stats.prompt_tokens == 1000
        assert stats.completion_tokens == 500
        assert stats.total_tokens == 1500
        # 1000/1M * 15 + 500/1M * 75 = 0.015 + 0.0375 = 0.0525
        assert stats.cost_usd == pytest.approx(0.0525, rel=0.01)


class TestCompletionResponse:
    """Tests for CompletionResponse model."""

    def test_completion_response_creation(self) -> None:
        """Test creating a CompletionResponse."""
        response = CompletionResponse(
            id="resp-123",
            model="claude-opus-4",
            provider="anthropic",
            content="Hello, world!",
        )

        assert response.id == "resp-123"
        assert response.model == "claude-opus-4"
        assert response.provider == "anthropic"
        assert response.content == "Hello, world!"
        assert response.finish_reason == "stop"


class TestStreamChunk:
    """Tests for StreamChunk model."""

    def test_stream_chunk_creation(self) -> None:
        """Test creating a StreamChunk."""
        chunk = StreamChunk(id="chunk-1", delta="Hello")
        assert chunk.id == "chunk-1"
        assert chunk.delta == "Hello"
        assert chunk.finish_reason is None

    def test_stream_chunk_final(self) -> None:
        """Test final StreamChunk."""
        chunk = StreamChunk(
            id="chunk-last",
            delta="",
            finish_reason="stop",
            usage=UsageStats(prompt_tokens=10, completion_tokens=5, total_tokens=15),
        )
        assert chunk.finish_reason == "stop"
        assert chunk.usage is not None


class TestEmbeddingRequest:
    """Tests for EmbeddingRequest model."""

    def test_embedding_request_creation(self) -> None:
        """Test creating an EmbeddingRequest."""
        req = EmbeddingRequest(
            project_id="proj-123",
            agent_id="agent-456",
            texts=["Hello", "World"],
        )

        assert req.project_id == "proj-123"
        assert len(req.texts) == 2
        assert req.model == "text-embedding-3-large"  # default

    def test_embedding_request_validation(self) -> None:
        """Test EmbeddingRequest validation."""
        with pytest.raises(ValueError):
            EmbeddingRequest(
                project_id="proj-123",
                agent_id="agent-456",
                texts=[],  # Invalid - must have at least 1
            )


class TestCostRecord:
    """Tests for CostRecord dataclass."""

    def test_cost_record_creation(self) -> None:
        """Test creating a CostRecord."""
        record = CostRecord(
            project_id="proj-123",
            agent_id="agent-456",
            model="claude-opus-4",
            prompt_tokens=100,
            completion_tokens=50,
            cost_usd=0.01,
        )

        assert record.project_id == "proj-123"
        assert record.cost_usd == 0.01
        assert record.timestamp is not None


class TestUsageReport:
    """Tests for UsageReport model."""

    def test_usage_report_creation(self) -> None:
        """Test creating a UsageReport."""
        now = datetime.now(UTC)
        report = UsageReport(
            entity_id="proj-123",
            entity_type="project",
            period="day",
            period_start=now,
            period_end=now,
        )

        assert report.entity_id == "proj-123"
        assert report.entity_type == "project"
        assert report.total_requests == 0
        assert report.total_cost_usd == 0.0


class TestModelInfo:
    """Tests for ModelInfo model."""

    def test_model_info_from_config(self) -> None:
        """Test creating ModelInfo from ModelConfig."""
        config = MODEL_CONFIGS["claude-opus-4"]
        info = ModelInfo.from_config(config, available=True)

        assert info.name == "claude-opus-4"
        assert info.provider == "anthropic"
        assert info.available is True
        assert info.supports_vision is True


class TestModelGroupInfo:
    """Tests for ModelGroupInfo model."""

    def test_model_group_info_creation(self) -> None:
        """Test creating ModelGroupInfo."""
        info = ModelGroupInfo(
            name="reasoning",
            description="Complex analysis",
            primary_model="claude-opus-4",
            fallback_models=["gpt-4.1"],
        )

        assert info.name == "reasoning"
        assert len(info.fallback_models) == 1
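The test_usage_stats_from_response case leans on per-million-token pricing. A quick worked check of the arithmetic it asserts, assuming (as the inline comment implies) that claude-opus-4 is priced at $15 per 1M input tokens and $75 per 1M output tokens:

# Worked example of the cost formula the test asserts; prices are assumptions taken from the comment.
prompt_tokens, completion_tokens = 1000, 500
cost_per_1m_input, cost_per_1m_output = 15.0, 75.0

cost_usd = (
    prompt_tokens / 1_000_000 * cost_per_1m_input
    + completion_tokens / 1_000_000 * cost_per_1m_output
)
assert round(cost_usd, 4) == 0.0525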
308
mcp-servers/llm-gateway/tests/test_providers.py
Normal file
@@ -0,0 +1,308 @@
"""
Tests for providers module.
"""

import os
from unittest.mock import patch

import pytest

from config import Settings
from models import MODEL_CONFIGS, ModelGroup, Provider
from providers import (
    LLMProvider,
    build_fallback_config,
    build_model_list,
    configure_litellm,
    get_available_model_groups,
    get_available_models,
    get_provider,
    reset_provider,
)


@pytest.fixture
def full_settings() -> Settings:
    """Settings with all providers configured."""
    return Settings(
        anthropic_api_key="test-anthropic-key",
        openai_api_key="test-openai-key",
        google_api_key="test-google-key",
        alibaba_api_key="test-alibaba-key",
        deepseek_api_key="test-deepseek-key",
        litellm_timeout=60,
        litellm_cache_enabled=False,
    )


@pytest.fixture
def partial_settings() -> Settings:
    """Settings with only some providers configured."""
    return Settings(
        anthropic_api_key="test-anthropic-key",
        openai_api_key=None,
        google_api_key=None,
        alibaba_api_key=None,
        deepseek_api_key=None,
    )


@pytest.fixture
def empty_settings() -> Settings:
    """Settings with no providers configured."""
    return Settings(
        anthropic_api_key=None,
        openai_api_key=None,
        google_api_key=None,
        alibaba_api_key=None,
        deepseek_api_key=None,
    )


class TestConfigureLiteLLM:
    """Tests for configure_litellm function."""

    def test_sets_api_keys(self, full_settings: Settings) -> None:
        """Test that API keys are set in environment."""
        with patch.dict(os.environ, {}, clear=True):
            configure_litellm(full_settings)

            assert os.environ.get("ANTHROPIC_API_KEY") == "test-anthropic-key"
            assert os.environ.get("OPENAI_API_KEY") == "test-openai-key"
            assert os.environ.get("GEMINI_API_KEY") == "test-google-key"

    def test_skips_none_keys(self, partial_settings: Settings) -> None:
        """Test that None keys are not set."""
        with patch.dict(os.environ, {}, clear=True):
            configure_litellm(partial_settings)

            assert os.environ.get("ANTHROPIC_API_KEY") == "test-anthropic-key"
            assert "OPENAI_API_KEY" not in os.environ


class TestBuildModelList:
    """Tests for build_model_list function."""

    def test_build_with_all_providers(self, full_settings: Settings) -> None:
        """Test building model list with all providers."""
        model_list = build_model_list(full_settings)

        assert len(model_list) > 0

        # Check structure
        for entry in model_list:
            assert "model_name" in entry
            assert "litellm_params" in entry
            assert "model" in entry["litellm_params"]
            assert "timeout" in entry["litellm_params"]

    def test_build_with_partial_providers(self, partial_settings: Settings) -> None:
        """Test building model list with partial providers."""
        model_list = build_model_list(partial_settings)

        # Should only include Anthropic models
        providers = set()
        for entry in model_list:
            model_name = entry["model_name"]
            config = MODEL_CONFIGS.get(model_name)
            if config:
                providers.add(config.provider)

        assert Provider.ANTHROPIC in providers
        assert Provider.OPENAI not in providers

    def test_build_with_no_providers(self, empty_settings: Settings) -> None:
        """Test building model list with no providers."""
        model_list = build_model_list(empty_settings)

        assert len(model_list) == 0

    def test_build_includes_timeout(self, full_settings: Settings) -> None:
        """Test that model entries include timeout."""
        model_list = build_model_list(full_settings)

        for entry in model_list:
            assert entry["litellm_params"]["timeout"] == 60


class TestBuildFallbackConfig:
    """Tests for build_fallback_config function."""

    def test_build_fallbacks_full(self, full_settings: Settings) -> None:
        """Test building fallback config with all providers."""
        fallbacks = build_fallback_config(full_settings)

        assert len(fallbacks) > 0

        # Primary models should have fallbacks
        for _primary, chain in fallbacks.items():
            assert isinstance(chain, list)
            assert len(chain) > 0

    def test_build_fallbacks_partial(self, partial_settings: Settings) -> None:
        """Test building fallback config with partial providers."""
        fallbacks = build_fallback_config(partial_settings)

        # With only Anthropic, there should be no fallbacks
        # (fallbacks require at least 2 available models)
        for primary, chain in fallbacks.items():
            # All models in chain should be from Anthropic
            for model in [primary] + chain:
                config = MODEL_CONFIGS.get(model)
                if config:
                    assert config.provider == Provider.ANTHROPIC


class TestGetAvailableModels:
    """Tests for get_available_models function."""

    def test_get_available_full(self, full_settings: Settings) -> None:
        """Test getting available models with all providers."""
        models = get_available_models(full_settings)

        assert len(models) > 0
        assert "claude-opus-4" in models
        assert "gpt-4.1" in models

    def test_get_available_partial(self, partial_settings: Settings) -> None:
        """Test getting available models with partial providers."""
        models = get_available_models(partial_settings)

        assert "claude-opus-4" in models
        assert "gpt-4.1" not in models

    def test_get_available_empty(self, empty_settings: Settings) -> None:
        """Test getting available models with no providers."""
        models = get_available_models(empty_settings)

        assert len(models) == 0


class TestGetAvailableModelGroups:
    """Tests for get_available_model_groups function."""

    def test_get_groups_full(self, full_settings: Settings) -> None:
        """Test getting groups with all providers."""
        groups = get_available_model_groups(full_settings)

        assert len(groups) == len(ModelGroup)
        assert ModelGroup.REASONING in groups
        assert len(groups[ModelGroup.REASONING]) > 0

    def test_get_groups_partial(self, partial_settings: Settings) -> None:
        """Test getting groups with partial providers."""
        groups = get_available_model_groups(partial_settings)

        # Only Anthropic models should be available
        for _group, models in groups.items():
            for model in models:
                config = MODEL_CONFIGS.get(model)
                if config:
                    assert config.provider == Provider.ANTHROPIC


class TestLLMProvider:
    """Tests for LLMProvider class."""

    def test_initialization(self, full_settings: Settings) -> None:
        """Test provider initialization."""
        provider = LLMProvider(settings=full_settings)

        assert provider._initialized is False
        assert provider._router is None

    def test_initialize(self, full_settings: Settings) -> None:
        """Test provider initialize."""
        with patch("providers.Router") as mock_router:
            provider = LLMProvider(settings=full_settings)
            provider.initialize()

            assert provider._initialized is True
            mock_router.assert_called_once()

    def test_initialize_idempotent(self, full_settings: Settings) -> None:
        """Test that initialize is idempotent."""
        with patch("providers.Router") as mock_router:
            provider = LLMProvider(settings=full_settings)
            provider.initialize()
            provider.initialize()

            # Should only be called once
            assert mock_router.call_count == 1

    def test_initialize_no_providers(self, empty_settings: Settings) -> None:
        """Test initialization with no providers."""
        provider = LLMProvider(settings=empty_settings)
        provider.initialize()

        assert provider._initialized is True
        assert provider._router is None

    def test_router_property(self, full_settings: Settings) -> None:
        """Test router property triggers initialization."""
        with patch("providers.Router"):
            provider = LLMProvider(settings=full_settings)
            _ = provider.router

            assert provider._initialized is True

    def test_is_available(self, full_settings: Settings) -> None:
        """Test is_available property."""
        with patch("providers.Router"):
            provider = LLMProvider(settings=full_settings)
            assert provider.is_available is True

    def test_is_not_available(self, empty_settings: Settings) -> None:
        """Test is_available when no providers."""
        provider = LLMProvider(settings=empty_settings)
        assert provider.is_available is False

    def test_get_model_config(self, full_settings: Settings) -> None:
        """Test getting model config."""
        provider = LLMProvider(settings=full_settings)

        config = provider.get_model_config("claude-opus-4")
        assert config is not None
        assert config.name == "claude-opus-4"

        assert provider.get_model_config("nonexistent") is None

    def test_get_available_models(self, full_settings: Settings) -> None:
        """Test getting available models."""
        provider = LLMProvider(settings=full_settings)
        models = provider.get_available_models()

        assert "claude-opus-4" in models
        assert "gpt-4.1" in models

    def test_is_model_available(self, full_settings: Settings) -> None:
        """Test checking model availability."""
        provider = LLMProvider(settings=full_settings)

        assert provider.is_model_available("claude-opus-4") is True
        assert provider.is_model_available("nonexistent") is False


class TestGlobalProvider:
    """Tests for global provider functions."""

    def test_get_provider(self) -> None:
        """Test getting global provider."""
        reset_provider()
        provider = get_provider()
        assert isinstance(provider, LLMProvider)

    def test_get_provider_singleton(self) -> None:
        """Test provider is singleton."""
        reset_provider()
        provider1 = get_provider()
        provider2 = get_provider()
        assert provider1 is provider2

    def test_reset_provider(self) -> None:
        """Test resetting global provider."""
        reset_provider()
        provider1 = get_provider()
        reset_provider()
        provider2 = get_provider()
        assert provider1 is not provider2
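For reference, the entry shape that TestBuildModelList checks follows LiteLLM's router model_list convention (a model_name plus litellm_params). A single entry might look roughly like this; the concrete field values here are illustrative assumptions consistent with the assertions above, not output copied from build_model_list:

# Illustrative model_list entry only; values are assumptions, not repository output.
example_entry = {
    "model_name": "claude-opus-4",
    "litellm_params": {
        "model": "anthropic/claude-opus-4",  # provider-prefixed LiteLLM model name
        "timeout": 60,                       # mirrors settings.litellm_timeout
    },
}
assert "model" in example_entry["litellm_params"]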
244
mcp-servers/llm-gateway/tests/test_routing.py
Normal file
@@ -0,0 +1,244 @@
"""
Tests for routing module.
"""

import asyncio

import pytest

from config import Settings
from exceptions import (
    AllProvidersFailedError,
    InvalidModelError,
    InvalidModelGroupError,
    ModelNotAvailableError,
)
from failover import CircuitBreakerRegistry, reset_circuit_registry
from models import ModelGroup
from providers import reset_provider
from routing import ModelRouter, get_model_router, reset_model_router


@pytest.fixture
def router_settings() -> Settings:
    """Settings for routing tests."""
    return Settings(
        anthropic_api_key="test-key",
        openai_api_key="test-key",
        google_api_key="test-key",
    )


@pytest.fixture
def router(router_settings: Settings) -> ModelRouter:
    """Create model router for testing."""
    reset_circuit_registry()
    reset_model_router()
    reset_provider()
    registry = CircuitBreakerRegistry(settings=router_settings)
    return ModelRouter(settings=router_settings, circuit_registry=registry)


class TestModelRouter:
    """Tests for ModelRouter class."""

    def test_parse_model_group_valid(self, router: ModelRouter) -> None:
        """Test parsing valid model groups."""
        assert router.parse_model_group("reasoning") == ModelGroup.REASONING
        assert router.parse_model_group("code") == ModelGroup.CODE
        assert router.parse_model_group("fast") == ModelGroup.FAST
        assert router.parse_model_group("REASONING") == ModelGroup.REASONING

    def test_parse_model_group_aliases(self, router: ModelRouter) -> None:
        """Test parsing model group aliases."""
        assert router.parse_model_group("high-reasoning") == ModelGroup.REASONING
        assert router.parse_model_group("high_reasoning") == ModelGroup.REASONING
        assert router.parse_model_group("code-generation") == ModelGroup.CODE
        assert router.parse_model_group("fast-response") == ModelGroup.FAST

    def test_parse_model_group_invalid(self, router: ModelRouter) -> None:
        """Test parsing invalid model group."""
        with pytest.raises(InvalidModelGroupError) as exc_info:
            router.parse_model_group("invalid_group")

        assert exc_info.value.model_group == "invalid_group"
        assert exc_info.value.available_groups is not None

    def test_get_model_config_valid(self, router: ModelRouter) -> None:
        """Test getting valid model config."""
        config = router.get_model_config("claude-opus-4")
        assert config.name == "claude-opus-4"
        assert config.provider.value == "anthropic"

    def test_get_model_config_invalid(self, router: ModelRouter) -> None:
        """Test getting invalid model config."""
        with pytest.raises(InvalidModelError) as exc_info:
            router.get_model_config("nonexistent-model")

        assert exc_info.value.model == "nonexistent-model"

    def test_get_preferred_group_for_agent(self, router: ModelRouter) -> None:
        """Test getting preferred group for agent types."""
        assert (
            router.get_preferred_group_for_agent("product_owner")
            == ModelGroup.REASONING
        )
        assert (
            router.get_preferred_group_for_agent("software_engineer") == ModelGroup.CODE
        )
        assert (
            router.get_preferred_group_for_agent("devops_engineer") == ModelGroup.FAST
        )

    def test_get_preferred_group_unknown_agent(self, router: ModelRouter) -> None:
        """Test getting preferred group for unknown agent."""
        # Should default to REASONING
        assert (
            router.get_preferred_group_for_agent("unknown_type") == ModelGroup.REASONING
        )

    def test_select_model_by_group(self, router: ModelRouter) -> None:
        """Test selecting model by group."""
        model_name, config = asyncio.run(
            router.select_model(model_group=ModelGroup.REASONING)
        )

        assert model_name == "claude-opus-4"
        assert config.provider.value == "anthropic"

    def test_select_model_by_group_string(self, router: ModelRouter) -> None:
        """Test selecting model by group string."""
        model_name, config = asyncio.run(router.select_model(model_group="code"))

        assert model_name == "claude-sonnet-4"

    def test_select_model_with_override(self, router: ModelRouter) -> None:
        """Test selecting specific model override."""
        model_name, config = asyncio.run(
            router.select_model(
                model_group="reasoning",
                model_override="gpt-4.1",
            )
        )

        assert model_name == "gpt-4.1"
        assert config.provider.value == "openai"

    def test_select_model_override_invalid(self, router: ModelRouter) -> None:
        """Test selecting invalid model override."""
        with pytest.raises(InvalidModelError):
            asyncio.run(
                router.select_model(
                    model_group="reasoning",
                    model_override="nonexistent-model",
                )
            )

    def test_select_model_override_unavailable(self, router: ModelRouter) -> None:  # noqa: ARG002
        """Test selecting unavailable model override."""
        # Create router without Alibaba key
        settings = Settings(
            anthropic_api_key="test-key",
            alibaba_api_key=None,
        )
        registry = CircuitBreakerRegistry(settings=settings)
        limited_router = ModelRouter(settings=settings, circuit_registry=registry)

        with pytest.raises(ModelNotAvailableError):
            asyncio.run(
                limited_router.select_model(
                    model_group="reasoning",
                    model_override="qwen-max",
                )
            )

    def test_select_model_fallback_on_circuit_open(
        self,
        router: ModelRouter,
    ) -> None:
        """Test fallback when primary circuit is open."""
        # Open circuit for anthropic
        circuit = router._circuit_registry.get_circuit_sync("anthropic")
        for _ in range(5):
            asyncio.run(circuit.record_failure())

        # Should fall back to OpenAI
        model_name, config = asyncio.run(
            router.select_model(model_group=ModelGroup.REASONING)
        )

        assert model_name == "gpt-4.1"
        assert config.provider.value == "openai"

    def test_select_model_all_unavailable(self) -> None:
        """Test when all providers are unavailable."""
        settings = Settings(
            anthropic_api_key=None,
            openai_api_key=None,
            google_api_key=None,
        )
        registry = CircuitBreakerRegistry(settings=settings)
        limited_router = ModelRouter(settings=settings, circuit_registry=registry)

        with pytest.raises(AllProvidersFailedError) as exc_info:
            asyncio.run(limited_router.select_model(model_group=ModelGroup.REASONING))

        assert exc_info.value.model_group == "reasoning"
        assert len(exc_info.value.attempted_models) > 0

    def test_get_available_models_for_group(self, router: ModelRouter) -> None:
        """Test getting available models for a group."""
        models = asyncio.run(
            router.get_available_models_for_group(ModelGroup.REASONING)
        )

        assert len(models) > 0
        # Should be (name, config, available) tuples
        for name, config, _available in models:
            assert isinstance(name, str)
            assert config is not None

    def test_get_available_models_for_group_string(self, router: ModelRouter) -> None:
        """Test getting available models with string group."""
        models = asyncio.run(router.get_available_models_for_group("code"))

        assert len(models) > 0

    def test_get_available_models_invalid_group(self, router: ModelRouter) -> None:
        """Test getting models for invalid group."""
        with pytest.raises(InvalidModelGroupError):
            asyncio.run(router.get_available_models_for_group("invalid"))

    def test_get_all_model_groups(self, router: ModelRouter) -> None:
        """Test getting all model groups info."""
        groups = router.get_all_model_groups()

        assert len(groups) == len(ModelGroup)
        assert "reasoning" in groups
        assert "code" in groups
        assert groups["reasoning"]["primary"] == "claude-opus-4"


class TestGlobalRouter:
    """Tests for global router functions."""

    def test_get_model_router(self) -> None:
        """Test getting global router."""
        reset_model_router()
        router = get_model_router()
        assert isinstance(router, ModelRouter)

    def test_get_model_router_singleton(self) -> None:
        """Test router is singleton."""
        reset_model_router()
        router1 = get_model_router()
        router2 = get_model_router()
        assert router1 is router2

    def test_reset_model_router(self) -> None:
        """Test resetting global router."""
        reset_model_router()
        router1 = get_model_router()
        reset_model_router()
        router2 = get_model_router()
        assert router1 is not router2
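Taken together, the routing tests imply a selection order of primary model first, then the group's fallbacks, skipping any model whose provider circuit is open, and raising once everything has been attempted. A minimal sketch of that loop, assuming the MODEL_GROUPS / MODEL_CONFIGS tables and registry methods used above (the AllProvidersFailedError constructor kwargs are an assumption; only its attributes are confirmed by the tests):

# Hedged sketch of the selection behaviour these tests assert; not the actual routing.py implementation.
from exceptions import AllProvidersFailedError
from failover import get_circuit_registry
from models import MODEL_CONFIGS, MODEL_GROUPS, ModelGroup


async def pick_model(group: ModelGroup) -> str:
    registry = get_circuit_registry()
    attempted: list[str] = []

    # Primary first, then fallbacks, in the order get_all_models() returns them.
    for name in MODEL_GROUPS[group].get_all_models():
        attempted.append(name)
        config = MODEL_CONFIGS[name]
        # Skip models whose provider circuit is currently open.
        if await registry.is_available(config.provider.value):
            return name

    # Constructor arguments assumed for illustration.
    raise AllProvidersFailedError(model_group=group.value, attempted_models=attempted)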
415
mcp-servers/llm-gateway/tests/test_server.py
Normal file
@@ -0,0 +1,415 @@
"""
Tests for server module.
"""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi.testclient import TestClient

from config import Settings
from models import ModelGroup


@pytest.fixture
def test_settings() -> Settings:
    """Test settings with mock API keys."""
    return Settings(
        anthropic_api_key="test-anthropic-key",
        openai_api_key="test-openai-key",
        google_api_key="test-google-key",
        cost_tracking_enabled=False,  # Disable for most tests
        litellm_cache_enabled=False,
    )


@pytest.fixture
def test_client(test_settings: Settings) -> TestClient:
    """Create test client with mocked dependencies."""
    with (
        patch("server.get_settings", return_value=test_settings),
        patch("server.get_provider") as mock_provider,
    ):
        mock_provider.return_value = MagicMock()
        mock_provider.return_value.is_available = True
        mock_provider.return_value.router = MagicMock()
        mock_provider.return_value.get_available_models.return_value = {}

        from server import app

        return TestClient(app)


class TestHealthEndpoint:
    """Tests for health check endpoint."""

    def test_health_check(self, test_client: TestClient) -> None:
        """Test health check returns healthy status."""
        with patch("server.get_settings") as mock_settings:
            mock_settings.return_value = Settings(
                anthropic_api_key="test-key",
            )
            with patch("server.get_provider") as mock_provider:
                mock_provider.return_value = MagicMock()
                mock_provider.return_value.is_available = True

                response = test_client.get("/health")

                assert response.status_code == 200
                data = response.json()
                assert data["status"] == "healthy"
                assert data["service"] == "llm-gateway"


class TestToolDiscoveryEndpoint:
    """Tests for tool discovery endpoint."""

    def test_list_tools(self, test_client: TestClient) -> None:
        """Test listing available tools."""
        response = test_client.get("/mcp/tools")

        assert response.status_code == 200
        data = response.json()
        assert "tools" in data
        assert len(data["tools"]) == 4  # 4 tools defined

        tool_names = [t["name"] for t in data["tools"]]
        assert "chat_completion" in tool_names
        assert "list_models" in tool_names
        assert "get_usage" in tool_names
        assert "count_tokens" in tool_names

    def test_tool_has_schema(self, test_client: TestClient) -> None:
        """Test that tools have input schemas."""
        response = test_client.get("/mcp/tools")
        data = response.json()

        for tool in data["tools"]:
            assert "inputSchema" in tool
            assert "type" in tool["inputSchema"]
            assert tool["inputSchema"]["type"] == "object"


class TestJSONRPCEndpoint:
    """Tests for JSON-RPC endpoint."""

    def test_invalid_jsonrpc_version(self, test_client: TestClient) -> None:
        """Test invalid JSON-RPC version."""
        response = test_client.post(
            "/mcp",
            json={
                "jsonrpc": "1.0",  # Invalid
                "method": "tools/list",
                "id": 1,
            },
        )

        assert response.status_code == 200
        data = response.json()
        assert "error" in data
        assert data["error"]["code"] == -32600

    def test_tools_list(self, test_client: TestClient) -> None:
        """Test tools/list method."""
        response = test_client.post(
            "/mcp",
            json={
                "jsonrpc": "2.0",
                "method": "tools/list",
                "id": 1,
            },
        )

        assert response.status_code == 200
        data = response.json()
        assert "result" in data
        assert "tools" in data["result"]

    def test_unknown_method(self, test_client: TestClient) -> None:
        """Test unknown method."""
        response = test_client.post(
            "/mcp",
            json={
                "jsonrpc": "2.0",
                "method": "unknown/method",
                "id": 1,
            },
        )

        assert response.status_code == 200
        data = response.json()
        assert "error" in data
        assert data["error"]["code"] == -32601

    def test_unknown_tool(self, test_client: TestClient) -> None:
        """Test unknown tool."""
        response = test_client.post(
            "/mcp",
            json={
                "jsonrpc": "2.0",
                "method": "tools/call",
                "params": {
                    "name": "unknown_tool",
                    "arguments": {},
                },
                "id": 1,
            },
        )

        assert response.status_code == 200
        data = response.json()
        assert "error" in data
        assert "Unknown tool" in data["error"]["message"]


class TestCountTokensTool:
    """Tests for count_tokens tool."""

    def test_count_tokens(self, test_client: TestClient) -> None:
        """Test counting tokens."""
        response = test_client.post(
            "/mcp",
            json={
                "jsonrpc": "2.0",
                "method": "tools/call",
                "params": {
                    "name": "count_tokens",
                    "arguments": {
                        "project_id": "proj-123",
                        "agent_id": "agent-456",
                        "text": "Hello, world!",
                    },
                },
                "id": 1,
            },
        )

        assert response.status_code == 200
        data = response.json()
        assert "result" in data

    def test_count_tokens_with_model(self, test_client: TestClient) -> None:
        """Test counting tokens with specific model."""
        response = test_client.post(
            "/mcp",
            json={
                "jsonrpc": "2.0",
                "method": "tools/call",
                "params": {
                    "name": "count_tokens",
                    "arguments": {
                        "project_id": "proj-123",
                        "agent_id": "agent-456",
                        "text": "Hello, world!",
                        "model": "gpt-4",
                    },
                },
                "id": 1,
            },
        )

        assert response.status_code == 200


class TestListModelsTool:
    """Tests for list_models tool."""

    def test_list_all_models(self, test_client: TestClient) -> None:
        """Test listing all models."""
        with patch("server.get_model_router") as mock_router:
            mock_router.return_value = MagicMock()
            mock_router.return_value.get_available_models_for_group = AsyncMock(
                return_value=[]
            )

            response = test_client.post(
                "/mcp",
                json={
                    "jsonrpc": "2.0",
                    "method": "tools/call",
                    "params": {
                        "name": "list_models",
                        "arguments": {
                            "project_id": "proj-123",
                            "agent_id": "agent-456",
                        },
                    },
                    "id": 1,
                },
            )

            assert response.status_code == 200

    def test_list_models_by_group(self, test_client: TestClient) -> None:
        """Test listing models by group."""
        with patch("server.get_model_router") as mock_router:
            mock_router.return_value = MagicMock()
            mock_router.return_value.parse_model_group.return_value = (
                ModelGroup.REASONING
            )
            mock_router.return_value.get_available_models_for_group = AsyncMock(
                return_value=[]
            )

            response = test_client.post(
                "/mcp",
                json={
                    "jsonrpc": "2.0",
                    "method": "tools/call",
                    "params": {
                        "name": "list_models",
                        "arguments": {
                            "project_id": "proj-123",
                            "agent_id": "agent-456",
                            "model_group": "reasoning",
                        },
                    },
                    "id": 1,
                },
            )

            assert response.status_code == 200


class TestGetUsageTool:
    """Tests for get_usage tool."""

    def test_get_usage(self, test_client: TestClient) -> None:
        """Test getting usage."""
        with patch("server.get_cost_tracker") as mock_tracker:
            mock_report = MagicMock()
            mock_report.total_requests = 10
            mock_report.total_tokens = 1000
            mock_report.total_cost_usd = 0.50
            mock_report.by_model = {}
            mock_report.period_start.isoformat.return_value = "2024-01-01T00:00:00"
            mock_report.period_end.isoformat.return_value = "2024-01-02T00:00:00"

            mock_tracker.return_value = MagicMock()
            mock_tracker.return_value.get_project_usage = AsyncMock(
                return_value=mock_report
            )
            mock_tracker.return_value.get_agent_usage = AsyncMock(
                return_value=mock_report
            )

            response = test_client.post(
                "/mcp",
                json={
                    "jsonrpc": "2.0",
                    "method": "tools/call",
                    "params": {
                        "name": "get_usage",
                        "arguments": {
                            "project_id": "proj-123",
                            "agent_id": "agent-456",
                            "period": "day",
                        },
                    },
                    "id": 1,
                },
            )

            assert response.status_code == 200


class TestChatCompletionTool:
    """Tests for chat_completion tool."""

    def test_chat_completion_streaming_not_supported(
        self,
        test_client: TestClient,
    ) -> None:
        """Test that streaming returns info message."""
        with patch("server.get_model_router") as mock_router:
            mock_router.return_value = MagicMock()
            mock_router.return_value.select_model = AsyncMock(
                return_value=("claude-opus-4", MagicMock())
            )

            with patch("server.get_cost_tracker") as mock_tracker:
                mock_tracker.return_value = MagicMock()
                mock_tracker.return_value.check_budget = AsyncMock(
                    return_value=(True, 0.0, 100.0)
                )

                response = test_client.post(
                    "/mcp",
                    json={
                        "jsonrpc": "2.0",
                        "method": "tools/call",
                        "params": {
                            "name": "chat_completion",
                            "arguments": {
                                "project_id": "proj-123",
                                "agent_id": "agent-456",
                                "messages": [{"role": "user", "content": "Hello"}],
                                "stream": True,
                            },
                        },
                        "id": 1,
                    },
                )

                assert response.status_code == 200

    def test_chat_completion_success(self, test_client: TestClient) -> None:
        """Test successful chat completion."""
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Hello, world!"
        mock_response.choices[0].finish_reason = "stop"
        mock_response.usage = MagicMock()
        mock_response.usage.prompt_tokens = 10
        mock_response.usage.completion_tokens = 5

        with patch("server.get_model_router") as mock_router:
            mock_model_config = MagicMock()
            mock_model_config.provider.value = "anthropic"
            mock_router.return_value = MagicMock()
            mock_router.return_value.select_model = AsyncMock(
                return_value=("claude-opus-4", mock_model_config)
            )

            with patch("server.get_cost_tracker") as mock_tracker:
                mock_tracker.return_value = MagicMock()
                mock_tracker.return_value.check_budget = AsyncMock(
                    return_value=(True, 0.0, 100.0)
                )
                mock_tracker.return_value.record_usage = AsyncMock()

                with patch("server.get_provider") as mock_prov:
                    mock_prov.return_value = MagicMock()
                    mock_prov.return_value.router = MagicMock()
                    mock_prov.return_value.router.acompletion = AsyncMock(
                        return_value=mock_response
                    )

                    with patch("server.get_circuit_registry") as mock_reg:
                        mock_circuit = MagicMock()
                        mock_circuit.record_success = AsyncMock()
                        mock_reg.return_value = MagicMock()
                        mock_reg.return_value.get_circuit_sync.return_value = (
                            mock_circuit
                        )

                        response = test_client.post(
                            "/mcp",
                            json={
                                "jsonrpc": "2.0",
                                "method": "tools/call",
                                "params": {
                                    "name": "chat_completion",
                                    "arguments": {
                                        "project_id": "proj-123",
                                        "agent_id": "agent-456",
                                        "messages": [
                                            {"role": "user", "content": "Hello"}
                                        ],
                                    },
                                },
                                "id": 1,
                            },
                        )

                        assert response.status_code == 200
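The JSON-RPC tests use the standard error codes: -32600 for an invalid request (wrong jsonrpc version) and -32601 for an unknown method. For reference, a request/response pair of the kind these tests send; the payload values are illustrative:

# Example request and error-response shapes exercised by TestJSONRPCEndpoint (values illustrative).
request = {
    "jsonrpc": "2.0",
    "method": "tools/call",
    "params": {
        "name": "count_tokens",
        "arguments": {"project_id": "proj-123", "agent_id": "agent-456", "text": "Hello"},
    },
    "id": 1,
}

error_response = {
    "jsonrpc": "2.0",
    "id": 1,
    "error": {"code": -32601, "message": "Method not found"},  # unknown method
}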
316
mcp-servers/llm-gateway/tests/test_streaming.py
Normal file
@@ -0,0 +1,316 @@
"""
Tests for streaming module.
"""

import asyncio
import json
from unittest.mock import MagicMock

import pytest

from models import StreamChunk, UsageStats
from streaming import (
    StreamAccumulator,
    StreamBuffer,
    format_sse_chunk,
    format_sse_done,
    format_sse_error,
    stream_to_string,
    wrap_litellm_stream,
)


class TestStreamAccumulator:
    """Tests for StreamAccumulator class."""

    def test_initial_state(self) -> None:
        """Test initial accumulator state."""
        acc = StreamAccumulator()

        assert acc.request_id is not None
        assert acc.content == ""
        assert acc.chunks_received == 0
        assert acc.prompt_tokens == 0
        assert acc.completion_tokens == 0
        assert acc.model is None
        assert acc.finish_reason is None

    def test_custom_request_id(self) -> None:
        """Test accumulator with custom request ID."""
        acc = StreamAccumulator(request_id="custom-id")
        assert acc.request_id == "custom-id"

    def test_add_chunk_text(self) -> None:
        """Test adding text chunks."""
        acc = StreamAccumulator()

        acc.add_chunk("Hello")
        acc.add_chunk(", ")
        acc.add_chunk("world!")

        assert acc.content == "Hello, world!"
        assert acc.chunks_received == 3

    def test_add_chunk_with_finish_reason(self) -> None:
        """Test adding chunk with finish reason."""
        acc = StreamAccumulator()

        acc.add_chunk("Final", finish_reason="stop")

        assert acc.finish_reason == "stop"

    def test_add_chunk_with_model(self) -> None:
        """Test adding chunk with model info."""
        acc = StreamAccumulator()

        acc.add_chunk("Text", model="claude-opus-4")

        assert acc.model == "claude-opus-4"

    def test_add_chunk_with_usage(self) -> None:
        """Test adding chunk with usage stats."""
        acc = StreamAccumulator()

        acc.add_chunk(
            "Text",
            usage={"prompt_tokens": 10, "completion_tokens": 5},
        )

        assert acc.prompt_tokens == 10
        assert acc.completion_tokens == 5
        assert acc.total_tokens == 15

    def test_start_and_finish(self) -> None:
        """Test start and finish timing."""
        acc = StreamAccumulator()

        assert acc.duration_seconds is None

        acc.start()
        acc.finish()

        assert acc.duration_seconds is not None
        assert acc.duration_seconds >= 0

    def test_get_usage_stats(self) -> None:
        """Test getting usage stats."""
        acc = StreamAccumulator()
        acc.add_chunk("", usage={"prompt_tokens": 100, "completion_tokens": 50})

        stats = acc.get_usage_stats(cost_usd=0.01)

        assert stats.prompt_tokens == 100
        assert stats.completion_tokens == 50
        assert stats.total_tokens == 150
        assert stats.cost_usd == 0.01


class TestWrapLiteLLMStream:
    """Tests for wrap_litellm_stream function."""

    async def test_wrap_stream_basic(self) -> None:
        """Test wrapping a basic stream."""

        # Create mock stream chunks
        async def mock_stream():
            chunk1 = MagicMock()
            chunk1.choices = [MagicMock()]
            chunk1.choices[0].delta = MagicMock()
            chunk1.choices[0].delta.content = "Hello"
            chunk1.choices[0].finish_reason = None
            chunk1.model = "test-model"
            chunk1.usage = None
            yield chunk1

            chunk2 = MagicMock()
            chunk2.choices = [MagicMock()]
            chunk2.choices[0].delta = MagicMock()
            chunk2.choices[0].delta.content = " World"
            chunk2.choices[0].finish_reason = "stop"
            chunk2.model = "test-model"
            chunk2.usage = MagicMock()
            chunk2.usage.prompt_tokens = 5
            chunk2.usage.completion_tokens = 2
            yield chunk2

        accumulator = StreamAccumulator()
        chunks = []

        async for chunk in wrap_litellm_stream(mock_stream(), accumulator):
            chunks.append(chunk)

        assert len(chunks) == 2
        assert chunks[0].delta == "Hello"
        assert chunks[1].delta == " World"
        assert chunks[1].finish_reason == "stop"
        assert accumulator.content == "Hello World"

    async def test_wrap_stream_without_accumulator(self) -> None:
        """Test wrapping stream without accumulator."""

        async def mock_stream():
            chunk = MagicMock()
            chunk.choices = [MagicMock()]
            chunk.choices[0].delta = MagicMock()
            chunk.choices[0].delta.content = "Test"
            chunk.choices[0].finish_reason = None
            chunk.model = None
            chunk.usage = None
            yield chunk

        chunks = []
        async for chunk in wrap_litellm_stream(mock_stream()):
            chunks.append(chunk)

        assert len(chunks) == 1


class TestSSEFormatting:
    """Tests for SSE formatting functions."""

    def test_format_sse_chunk_basic(self) -> None:
        """Test formatting basic chunk."""
        chunk = StreamChunk(id="chunk-1", delta="Hello")
        result = format_sse_chunk(chunk)

        assert result.startswith("data: ")
        assert result.endswith("\n\n")

        # Parse the JSON
        data = json.loads(result[6:-2])
        assert data["id"] == "chunk-1"
        assert data["delta"] == "Hello"

    def test_format_sse_chunk_with_finish(self) -> None:
        """Test formatting chunk with finish reason."""
        chunk = StreamChunk(
            id="chunk-2",
            delta="",
            finish_reason="stop",
        )
        result = format_sse_chunk(chunk)
        data = json.loads(result[6:-2])

        assert data["finish_reason"] == "stop"

    def test_format_sse_chunk_with_usage(self) -> None:
        """Test formatting chunk with usage."""
        chunk = StreamChunk(
            id="chunk-3",
            delta="",
            finish_reason="stop",
            usage=UsageStats(
                prompt_tokens=10,
                completion_tokens=5,
                total_tokens=15,
                cost_usd=0.001,
            ),
        )
        result = format_sse_chunk(chunk)
        data = json.loads(result[6:-2])

        assert "usage" in data
        assert data["usage"]["prompt_tokens"] == 10

    def test_format_sse_done(self) -> None:
        """Test formatting done message."""
        result = format_sse_done()
        assert result == "data: [DONE]\n\n"

    def test_format_sse_error(self) -> None:
        """Test formatting error message."""
        result = format_sse_error("Something went wrong", code="ERROR_CODE")
        data = json.loads(result[6:-2])

        assert data["error"] == "Something went wrong"
        assert data["code"] == "ERROR_CODE"

    def test_format_sse_error_without_code(self) -> None:
        """Test formatting error without code."""
        result = format_sse_error("Error message")
        data = json.loads(result[6:-2])

        assert data["error"] == "Error message"
        assert "code" not in data


class TestStreamBuffer:
    """Tests for StreamBuffer class."""

    async def test_buffer_basic(self) -> None:
        """Test basic buffer operations."""
        buffer = StreamBuffer(max_size=10)

        # Producer
        async def produce():
            await buffer.put(StreamChunk(id="1", delta="Hello"))
            await buffer.put(StreamChunk(id="2", delta=" World"))
            await buffer.done()

        # Consumer
        chunks = []
        asyncio.create_task(produce())

        async for chunk in buffer:
            chunks.append(chunk)

        assert len(chunks) == 2
        assert chunks[0].delta == "Hello"
        assert chunks[1].delta == " World"

    async def test_buffer_error(self) -> None:
        """Test buffer with error."""
        buffer = StreamBuffer()

        async def produce():
            await buffer.put(StreamChunk(id="1", delta="Hello"))
            await buffer.error(ValueError("Test error"))

        asyncio.create_task(produce())

        with pytest.raises(ValueError, match="Test error"):
            async for _ in buffer:
                pass

    async def test_buffer_put_after_done(self) -> None:
        """Test putting after done raises."""
        buffer = StreamBuffer()
        await buffer.done()

        with pytest.raises(RuntimeError, match="closed"):
            await buffer.put(StreamChunk(id="1", delta="Test"))


class TestStreamToString:
    """Tests for stream_to_string function."""

    async def test_stream_to_string_basic(self) -> None:
        """Test converting stream to string."""

        async def mock_stream():
            yield StreamChunk(id="1", delta="Hello")
            yield StreamChunk(id="2", delta=" ")
            yield StreamChunk(id="3", delta="World")
            yield StreamChunk(
                id="4",
                delta="",
                finish_reason="stop",
                usage=UsageStats(prompt_tokens=5, completion_tokens=3),
            )

        content, usage = await stream_to_string(mock_stream())

        assert content == "Hello World"
        assert usage is not None
        assert usage.prompt_tokens == 5

    async def test_stream_to_string_no_usage(self) -> None:
        """Test stream without usage stats."""

        async def mock_stream():
            yield StreamChunk(id="1", delta="Test")

        content, usage = await stream_to_string(mock_stream())

        assert content == "Test"
        assert usage is None
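The SSE assertions pin down the wire format: each chunk is a "data: <json>" line terminated by a blank line, and the stream closes with "data: [DONE]". A minimal sketch of that framing; the real format_sse_chunk serializes a StreamChunk model, which is represented here by a plain dict for brevity:

import json

# Hedged sketch of the SSE framing these tests assert; not the repository's format_sse_chunk.
def sse_event(payload: dict) -> str:
    return f"data: {json.dumps(payload)}\n\n"

def sse_done() -> str:
    return "data: [DONE]\n\n"

frame = sse_event({"id": "chunk-1", "delta": "Hello"})
assert frame.startswith("data: ") and frame.endswith("\n\n")
assert sse_done() == "data: [DONE]\n\n"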
2839
mcp-servers/llm-gateway/uv.lock
generated
Normal file
File diff suppressed because it is too large