# SPIKE-005: LLM Provider Abstraction
- **Status:** Completed
- **Date:** 2025-12-29
- **Author:** Architecture Team
- **Related Issue:** #5
## Objective
Research the best approach for unified LLM provider abstraction with support for multiple providers, automatic failover, and cost tracking.
## Research Questions
- What libraries exist for unified LLM access?
- How to implement automatic failover between providers?
- How to track token usage and costs per agent/project?
- What caching strategies can reduce API costs?
## Findings
### 1. LiteLLM - Recommended Solution
LiteLLM provides a unified interface to 100+ LLM providers using the OpenAI SDK format.
Key Features:
- Unified API across providers (Anthropic, OpenAI, local, etc.)
- Built-in failover and load balancing
- Token counting and cost tracking
- Streaming support
- Async support
- Caching with Redis
Installation:
```bash
pip install litellm
```
### 2. Basic Usage
```python
import os

import litellm
from litellm import completion, acompletion

# Configure providers
litellm.api_key = os.getenv("ANTHROPIC_API_KEY")
litellm.set_verbose = True  # For debugging

# Synchronous call
response = completion(
    model="claude-3-5-sonnet-20241022",
    messages=[{"role": "user", "content": "Hello!"}],
)

# Async call (for FastAPI)
response = await acompletion(
    model="claude-3-5-sonnet-20241022",
    messages=[{"role": "user", "content": "Hello!"}],
)
```
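The token counting and cost tracking listed under Key Features are also exposed as standalone helpers. A minimal sketch, assuming the `token_counter` and `completion_cost` helpers shipped in recent LiteLLM releases (verify against the installed version):

```python
import os

from litellm import completion, completion_cost, token_counter

messages = [{"role": "user", "content": "Summarize our sprint goals."}]

# Count prompt tokens before sending the request
prompt_tokens = token_counter(model="claude-3-5-sonnet-20241022", messages=messages)

response = completion(
    model="claude-3-5-sonnet-20241022",
    messages=messages,
    api_key=os.getenv("ANTHROPIC_API_KEY"),
)

# Estimate the USD cost of this call from the returned usage block
cost_usd = completion_cost(completion_response=response)
print(f"prompt_tokens={prompt_tokens}, cost=${cost_usd:.6f}")
```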
### 3. Model Naming Convention
LiteLLM infers the provider from the model name, using a prefix where one is needed:
| Provider | Model Format |
|---|---|
| Anthropic | claude-3-5-sonnet-20241022 |
| OpenAI | gpt-4-turbo |
| Azure OpenAI | azure/deployment-name |
| Ollama | ollama/llama3 |
| Together AI | together_ai/togethercomputer/llama-2-70b |
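In practice, switching providers is just a change of model string. A short sketch (the Ollama base URL is an assumption matching the default local port):

```python
from litellm import acompletion

# Hosted Anthropic model -- no prefix required
response = await acompletion(
    model="claude-3-5-sonnet-20241022",
    messages=[{"role": "user", "content": "Hello!"}],
)

# Local Ollama model -- "ollama/" prefix plus the server's base URL
response = await acompletion(
    model="ollama/llama3",
    api_base="http://localhost:11434",  # default Ollama port; adjust as needed
    messages=[{"role": "user", "content": "Hello!"}],
)
```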
### 4. Failover Configuration
```python
import os

from litellm import Router

# Define model list with fallbacks
model_list = [
    {
        "model_name": "primary-agent",
        "litellm_params": {
            "model": "claude-3-5-sonnet-20241022",
            "api_key": os.getenv("ANTHROPIC_API_KEY"),
        },
        "model_info": {"id": 1},
    },
    {
        "model_name": "primary-agent",  # Same name = fallback
        "litellm_params": {
            "model": "gpt-4-turbo",
            "api_key": os.getenv("OPENAI_API_KEY"),
        },
        "model_info": {"id": 2},
    },
    {
        "model_name": "primary-agent",
        "litellm_params": {
            "model": "ollama/llama3",
            "api_base": "http://localhost:11434",
        },
        "model_info": {"id": 3},
    },
]

# Initialize router with failover
router = Router(
    model_list=model_list,
    fallbacks=[
        {"primary-agent": ["primary-agent"]}  # Try all models with same name
    ],
    routing_strategy="simple-shuffle",  # or "latency-based-routing"
    num_retries=3,
    retry_after=5,  # seconds
    timeout=60,
)

# Use router
response = await router.acompletion(
    model="primary-agent",
    messages=[{"role": "user", "content": "Hello!"}],
)
```
### 5. Syndarix LLM Gateway Architecture
```python
# app/services/llm_gateway.py
from litellm import Router, acompletion

from app.core.config import settings
from app.models.agent import AgentType
from app.services.cost_tracker import CostTracker
from app.services.events import EventBus


class LLMGateway:
    """Unified LLM gateway with failover and cost tracking."""

    def __init__(self):
        self.router = self._build_router()
        self.cost_tracker = CostTracker()
        self.event_bus = EventBus()

    def _build_router(self) -> Router:
        """Build LiteLLM router from configuration."""
        model_list = []

        # Add Anthropic models
        if settings.ANTHROPIC_API_KEY:
            model_list.extend([
                {
                    "model_name": "high-reasoning",
                    "litellm_params": {
                        "model": "claude-3-5-sonnet-20241022",
                        "api_key": settings.ANTHROPIC_API_KEY,
                    },
                },
                {
                    "model_name": "fast-response",
                    "litellm_params": {
                        "model": "claude-3-haiku-20240307",
                        "api_key": settings.ANTHROPIC_API_KEY,
                    },
                },
            ])

        # Add OpenAI fallbacks
        if settings.OPENAI_API_KEY:
            model_list.extend([
                {
                    "model_name": "high-reasoning",
                    "litellm_params": {
                        "model": "gpt-4-turbo",
                        "api_key": settings.OPENAI_API_KEY,
                    },
                },
                {
                    "model_name": "fast-response",
                    "litellm_params": {
                        "model": "gpt-4o-mini",
                        "api_key": settings.OPENAI_API_KEY,
                    },
                },
            ])

        # Add local models (Ollama)
        if settings.OLLAMA_URL:
            model_list.append({
                "model_name": "local-fallback",
                "litellm_params": {
                    "model": "ollama/llama3",
                    "api_base": settings.OLLAMA_URL,
                },
            })

        return Router(
            model_list=model_list,
            fallbacks=[
                {"high-reasoning": ["high-reasoning", "local-fallback"]},
                {"fast-response": ["fast-response", "local-fallback"]},
            ],
            routing_strategy="latency-based-routing",
            num_retries=3,
            timeout=120,
        )

    async def complete(
        self,
        agent_id: str,
        project_id: str,
        messages: list[dict],
        model_preference: str = "high-reasoning",
        stream: bool = False,
        **kwargs,
    ) -> dict:
        """
        Generate a completion with automatic failover and cost tracking.

        Args:
            agent_id: The calling agent's ID
            project_id: The project context
            messages: Chat messages
            model_preference: "high-reasoning" or "fast-response"
            stream: Whether to stream the response
            **kwargs: Additional LiteLLM parameters

        Returns:
            Completion response dictionary
        """
        try:
            if stream:
                return self._stream_completion(
                    agent_id, project_id, messages, model_preference, **kwargs
                )

            response = await self.router.acompletion(
                model=model_preference,
                messages=messages,
                **kwargs,
            )

            # Track usage
            await self._track_usage(
                agent_id=agent_id,
                project_id=project_id,
                model=response.model,
                usage=response.usage,
            )

            return {
                "content": response.choices[0].message.content,
                "model": response.model,
                "usage": {
                    "prompt_tokens": response.usage.prompt_tokens,
                    "completion_tokens": response.usage.completion_tokens,
                    "total_tokens": response.usage.total_tokens,
                },
            }

        except Exception as e:
            # Publish error event
            await self.event_bus.publish(f"project:{project_id}", {
                "type": "llm_error",
                "agent_id": agent_id,
                "error": str(e),
            })
            raise

    async def _stream_completion(
        self,
        agent_id: str,
        project_id: str,
        messages: list[dict],
        model_preference: str,
        **kwargs,
    ):
        """Stream a completion response."""
        response = await self.router.acompletion(
            model=model_preference,
            messages=messages,
            stream=True,
            **kwargs,
        )

        async for chunk in response:
            if chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content

    async def _track_usage(
        self,
        agent_id: str,
        project_id: str,
        model: str,
        usage: dict,
    ):
        """Track token usage and costs."""
        await self.cost_tracker.record_usage(
            agent_id=agent_id,
            project_id=project_id,
            model=model,
            prompt_tokens=usage.prompt_tokens,
            completion_tokens=usage.completion_tokens,
        )
```
### 6. Cost Tracking
```python
# app/services/cost_tracker.py
from datetime import datetime

from sqlalchemy.ext.asyncio import AsyncSession

from app.models.usage import TokenUsage

# Cost per 1M tokens (approximate)
MODEL_COSTS = {
    "claude-3-5-sonnet-20241022": {"input": 3.00, "output": 15.00},
    "claude-3-haiku-20240307": {"input": 0.25, "output": 1.25},
    "gpt-4-turbo": {"input": 10.00, "output": 30.00},
    "gpt-4o-mini": {"input": 0.15, "output": 0.60},
    "ollama/llama3": {"input": 0.00, "output": 0.00},  # Local
}


class CostTracker:
    def __init__(self, db: AsyncSession):
        self.db = db

    async def record_usage(
        self,
        agent_id: str,
        project_id: str,
        model: str,
        prompt_tokens: int,
        completion_tokens: int,
    ):
        """Record token usage and calculate cost."""
        costs = MODEL_COSTS.get(model, {"input": 0, "output": 0})
        input_cost = (prompt_tokens / 1_000_000) * costs["input"]
        output_cost = (completion_tokens / 1_000_000) * costs["output"]
        total_cost = input_cost + output_cost

        usage = TokenUsage(
            agent_id=agent_id,
            project_id=project_id,
            model=model,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
            cost_usd=total_cost,
            timestamp=datetime.utcnow(),
        )
        self.db.add(usage)
        await self.db.commit()

    async def get_project_usage(
        self,
        project_id: str,
        start_date: datetime = None,
        end_date: datetime = None,
    ) -> dict:
        """Get usage summary for a project."""
        # Query aggregated usage
        ...

    async def check_budget(
        self,
        project_id: str,
        budget_limit: float,
    ) -> bool:
        """Check if project is within budget."""
        usage = await self.get_project_usage(project_id)
        return usage["total_cost_usd"] < budget_limit
```
### 7. Caching with Redis
```python
import litellm
from litellm import Cache

from app.core.config import settings

# Configure Redis cache
litellm.cache = Cache(
    type="redis",
    host=settings.REDIS_HOST,
    port=settings.REDIS_PORT,
    password=settings.REDIS_PASSWORD,
)

# Enable caching
litellm.enable_cache()

# Cached completions (same input = cached response)
response = await litellm.acompletion(
    model="claude-3-5-sonnet-20241022",
    messages=[{"role": "user", "content": "What is 2+2?"}],
    cache={"ttl": 3600},  # Cache for 1 hour
)
```
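The same Redis-backed cache can also serve embeddings (useful for RAG, as noted in the recommendations). A sketch, assuming the per-call `caching` flag supported by recent LiteLLM releases and an OpenAI embedding model:

```python
from litellm import aembedding

embeddings = await aembedding(
    model="text-embedding-3-small",
    input=["What is the project budget?", "List open tasks"],
    caching=True,  # identical inputs are served from Redis until the entry expires
)
```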
### 8. Agent Type Model Mapping
```python
# app/models/agent_type.py
from enum import Enum

from sqlalchemy import Column, Enum as SQLEnum, Float, Integer, String, Text
from sqlalchemy.dialects.postgresql import UUID

from app.db.base import Base


class ModelPreference(str, Enum):
    HIGH_REASONING = "high-reasoning"
    FAST_RESPONSE = "fast-response"
    COST_OPTIMIZED = "cost-optimized"


class AgentType(Base):
    __tablename__ = "agent_types"

    id = Column(UUID, primary_key=True)
    name = Column(String(50), unique=True)
    role = Column(String(50))

    # LLM configuration
    model_preference = Column(
        SQLEnum(ModelPreference),
        default=ModelPreference.HIGH_REASONING,
    )
    max_tokens = Column(Integer, default=4096)
    temperature = Column(Float, default=0.7)

    # System prompt
    system_prompt = Column(Text)


# Mapping agent types to models
AGENT_MODEL_MAPPING = {
    "Product Owner": ModelPreference.HIGH_REASONING,
    "Project Manager": ModelPreference.FAST_RESPONSE,
    "Business Analyst": ModelPreference.HIGH_REASONING,
    "Software Architect": ModelPreference.HIGH_REASONING,
    "Software Engineer": ModelPreference.HIGH_REASONING,
    "UI/UX Designer": ModelPreference.HIGH_REASONING,
    "QA Engineer": ModelPreference.FAST_RESPONSE,
    "DevOps Engineer": ModelPreference.FAST_RESPONSE,
    "AI/ML Engineer": ModelPreference.HIGH_REASONING,
    "Security Expert": ModelPreference.HIGH_REASONING,
}
```
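Illustrative glue code (not part of the spike) showing how the mapping might feed the gateway's `model_preference` parameter:

```python
from app.models.agent_type import AGENT_MODEL_MAPPING, ModelPreference
from app.services.llm_gateway import LLMGateway

gateway = LLMGateway()


async def run_agent_turn(
    agent_type_name: str,
    agent_id: str,
    project_id: str,
    messages: list[dict],
) -> dict:
    # Fall back to the cheaper tier when the agent type is unknown
    preference = AGENT_MODEL_MAPPING.get(agent_type_name, ModelPreference.FAST_RESPONSE)
    return await gateway.complete(
        agent_id=agent_id,
        project_id=project_id,
        messages=messages,
        model_preference=preference.value,
    )
```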
## Rate Limiting Strategy
```python
import asyncio

from litellm import Router

from app.core.config import settings

# Configure rate limits per model (model_list as defined in section 4)
router = Router(
    model_list=model_list,
    redis_host=settings.REDIS_HOST,
    redis_port=settings.REDIS_PORT,
    routing_strategy="usage-based-routing",  # Route based on rate limits
)


# Custom rate limiter
class RateLimiter:
    def __init__(self, requests_per_minute: int = 60):
        self.rpm = requests_per_minute
        self.semaphore = asyncio.Semaphore(requests_per_minute)

    async def acquire(self):
        await self.semaphore.acquire()
        # Release after 60 seconds
        asyncio.create_task(self._release_after(60))

    async def _release_after(self, seconds: int):
        await asyncio.sleep(seconds)
        self.semaphore.release()
```
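How the custom `RateLimiter` might wrap gateway calls; the 30 requests/minute limit is an arbitrary illustrative value:

```python
limiter = RateLimiter(requests_per_minute=30)


async def rate_limited_complete(gateway, **kwargs) -> dict:
    await limiter.acquire()  # blocks once the per-minute budget is exhausted
    return await gateway.complete(**kwargs)
```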
## Recommendations

- **Use LiteLLM as the unified abstraction layer**
  - Simplifies multi-provider support
  - Built-in failover and retry
  - Consistent API across providers
- **Configure model groups by use case**
  - `high-reasoning`: Complex analysis, architecture decisions
  - `fast-response`: Quick tasks, simple queries
  - `cost-optimized`: Non-critical, high-volume tasks
- **Implement automatic failover chain**
  - Primary: Claude 3.5 Sonnet
  - Fallback 1: GPT-4 Turbo
  - Fallback 2: Local Llama 3 (if available)
- **Track all usage and costs** (see the budget-check sketch after this list)
  - Per agent, per project
  - Set budget alerts
  - Generate usage reports
- **Cache frequently repeated queries**
  - Use Redis-backed cache
  - Cache embeddings for RAG
  - Cache deterministic transformations
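A budget-check sketch for the usage-tracking recommendation above; the monthly limit and event payload are assumptions for illustration:

```python
from sqlalchemy.ext.asyncio import AsyncSession

from app.services.cost_tracker import CostTracker
from app.services.events import EventBus


async def enforce_project_budget(
    db: AsyncSession,
    project_id: str,
    monthly_limit_usd: float = 100.0,
) -> bool:
    # Compare accumulated spend against the project's budget
    tracker = CostTracker(db)
    within_budget = await tracker.check_budget(project_id, monthly_limit_usd)
    if not within_budget:
        # Emit an alert event for the UI / notifications layer
        await EventBus().publish(f"project:{project_id}", {
            "type": "budget_exceeded",
            "limit_usd": monthly_limit_usd,
        })
    return within_budget
```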
## Decision
Adopt LiteLLM as the unified LLM abstraction layer with automatic failover, usage-based routing, and Redis-backed caching.
Spike completed. Findings will inform ADR-004: LLM Provider Integration Architecture.