feat(llm-gateway): implement LLM Gateway MCP Server (#56)
Implements complete LLM Gateway MCP Server with:
- FastMCP server with 4 tools: chat_completion, list_models, get_usage, count_tokens
- LiteLLM Router with multi-provider failover chains
- Circuit breaker pattern for fault tolerance
- Redis-based cost tracking per project/agent
- Comprehensive test suite (209 tests, 92% coverage)

Model groups defined per ADR-004:
- reasoning: claude-opus-4 → gpt-4.1 → gemini-2.5-pro
- code: claude-sonnet-4 → gpt-4.1 → deepseek-coder
- fast: claude-haiku → gpt-4.1-mini → gemini-2.0-flash

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
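For reference, the failover chains above map onto LiteLLM's Router roughly as follows. This is a minimal sketch, not code from this commit: the provider-prefixed model IDs, fallback group names, and retry count are illustrative assumptions.

# Sketch: ADR-004 "reasoning" group wired as a LiteLLM Router failover chain.
# Model IDs, group names, and num_retries are assumptions, not taken from the repo.
from litellm import Router

router = Router(
    model_list=[
        {"model_name": "reasoning", "litellm_params": {"model": "anthropic/claude-opus-4"}},
        {"model_name": "reasoning-fallback-1", "litellm_params": {"model": "openai/gpt-4.1"}},
        {"model_name": "reasoning-fallback-2", "litellm_params": {"model": "gemini/gemini-2.5-pro"}},
    ],
    # On failure of "reasoning", try the fallback groups in order.
    fallbacks=[{"reasoning": ["reasoning-fallback-1", "reasoning-fallback-2"]}],
    num_retries=2,
)

# Usage (inside the chat_completion tool): await router.acompletion(model="reasoning", messages=[...], stream=True)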
mcp-servers/llm-gateway/streaming.py (new file, 344 additions)
@@ -0,0 +1,344 @@
"""
Streaming support for LLM Gateway.

Provides async streaming wrappers for LiteLLM responses.
"""

import asyncio
import json
import logging
import uuid
from collections.abc import AsyncIterator
from typing import Any

from models import StreamChunk, UsageStats

logger = logging.getLogger(__name__)


class StreamAccumulator:
    """
    Accumulates streaming chunks for cost calculation.

    Tracks:
    - Full content for final response
    - Token counts from chunks
    - Timing information
    """

    def __init__(self, request_id: str | None = None) -> None:
        """
        Initialize accumulator.

        Args:
            request_id: Optional request ID for tracking
        """
        self.request_id = request_id or str(uuid.uuid4())
        self.content_parts: list[str] = []
        self.chunks_received = 0
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.model: str | None = None
        self.finish_reason: str | None = None
        self._started_at: float | None = None
        self._finished_at: float | None = None

    @property
    def content(self) -> str:
        """Get accumulated content."""
        return "".join(self.content_parts)

    @property
    def total_tokens(self) -> int:
        """Get total token count."""
        return self.prompt_tokens + self.completion_tokens

    @property
    def duration_seconds(self) -> float | None:
        """Get stream duration in seconds."""
        if self._started_at is None or self._finished_at is None:
            return None
        return self._finished_at - self._started_at

    def start(self) -> None:
        """Mark stream start."""
        import time
        self._started_at = time.time()

    def finish(self) -> None:
        """Mark stream finish."""
        import time
        self._finished_at = time.time()

    def add_chunk(
        self,
        delta: str,
        finish_reason: str | None = None,
        model: str | None = None,
        usage: dict[str, int] | None = None,
    ) -> None:
        """
        Add a chunk to the accumulator.

        Args:
            delta: Content delta
            finish_reason: Finish reason if this is the final chunk
            model: Model name
            usage: Usage stats if provided
        """
        if delta:
            self.content_parts.append(delta)
        self.chunks_received += 1

        if finish_reason:
            self.finish_reason = finish_reason

        if model:
            self.model = model

        if usage:
            self.prompt_tokens = usage.get("prompt_tokens", self.prompt_tokens)
            self.completion_tokens = usage.get(
                "completion_tokens", self.completion_tokens
            )

    def get_usage_stats(self, cost_usd: float = 0.0) -> UsageStats:
        """Get usage statistics."""
        return UsageStats(
            prompt_tokens=self.prompt_tokens,
            completion_tokens=self.completion_tokens,
            total_tokens=self.total_tokens,
            cost_usd=cost_usd,
        )


async def wrap_litellm_stream(
    stream: AsyncIterator[Any],
    accumulator: StreamAccumulator | None = None,
) -> AsyncIterator[StreamChunk]:
    """
    Wrap a LiteLLM stream into StreamChunk objects.

    Args:
        stream: LiteLLM async stream
        accumulator: Optional accumulator for tracking

    Yields:
        StreamChunk objects
    """
    if accumulator:
        accumulator.start()

    chunk_id = 0
    try:
        async for chunk in stream:
            chunk_id += 1

            # Extract data from LiteLLM chunk
            delta = ""
            finish_reason = None
            usage = None
            model = None

            # Handle different chunk formats
            if hasattr(chunk, "choices") and chunk.choices:
                choice = chunk.choices[0]
                if hasattr(choice, "delta"):
                    delta = getattr(choice.delta, "content", "") or ""
                finish_reason = getattr(choice, "finish_reason", None)

            if hasattr(chunk, "model"):
                model = chunk.model

            if hasattr(chunk, "usage") and chunk.usage:
                usage = {
                    "prompt_tokens": getattr(chunk.usage, "prompt_tokens", 0),
                    "completion_tokens": getattr(chunk.usage, "completion_tokens", 0),
                }

            # Update accumulator
            if accumulator:
                accumulator.add_chunk(
                    delta=delta,
                    finish_reason=finish_reason,
                    model=model,
                    usage=usage,
                )

            # Create StreamChunk
            stream_chunk = StreamChunk(
                id=f"{accumulator.request_id if accumulator else 'stream'}-{chunk_id}",
                delta=delta,
                finish_reason=finish_reason,
            )

            # Add usage on final chunk
            if finish_reason and accumulator:
                stream_chunk.usage = accumulator.get_usage_stats()

            yield stream_chunk

    finally:
        if accumulator:
            accumulator.finish()


def format_sse_chunk(chunk: StreamChunk) -> str:
    """
    Format a StreamChunk as SSE data.

    Args:
        chunk: StreamChunk to format

    Returns:
        SSE-formatted string
    """
    data = {
        "id": chunk.id,
        "delta": chunk.delta,
    }
    if chunk.finish_reason:
        data["finish_reason"] = chunk.finish_reason
    if chunk.usage:
        data["usage"] = chunk.usage.model_dump()

    return f"data: {json.dumps(data)}\n\n"


def format_sse_done() -> str:
    """Format SSE done message."""
    return "data: [DONE]\n\n"


def format_sse_error(error: str, code: str | None = None) -> str:
    """
    Format an error as SSE data.

    Args:
        error: Error message
        code: Error code

    Returns:
        SSE-formatted error string
    """
    data = {"error": error}
    if code:
        data["code"] = code
    return f"data: {json.dumps(data)}\n\n"


class StreamBuffer:
    """
    Buffer for streaming responses with backpressure handling.

    Useful when producing chunks faster than they can be consumed.
    """

    def __init__(self, max_size: int = 100) -> None:
        """
        Initialize buffer.

        Args:
            max_size: Maximum buffer size
        """
        self._queue: asyncio.Queue[StreamChunk | None] = asyncio.Queue(maxsize=max_size)
        self._done = False
        self._error: Exception | None = None

    async def put(self, chunk: StreamChunk) -> None:
        """
        Put a chunk in the buffer.

        Args:
            chunk: Chunk to buffer
        """
        if self._done:
            raise RuntimeError("Buffer is closed")
        await self._queue.put(chunk)

    async def done(self) -> None:
        """Signal that streaming is complete."""
        self._done = True
        await self._queue.put(None)

    async def error(self, err: Exception) -> None:
        """Signal an error occurred."""
        self._error = err
        self._done = True
        await self._queue.put(None)

    async def __aiter__(self) -> AsyncIterator[StreamChunk]:
        """Iterate over buffered chunks."""
        while True:
            chunk = await self._queue.get()
            if chunk is None:
                if self._error:
                    raise self._error
                return
            yield chunk


async def stream_to_string(stream: AsyncIterator[StreamChunk]) -> tuple[str, UsageStats | None]:
    """
    Consume a stream and return full content.

    Args:
        stream: Stream to consume

    Returns:
        Tuple of (content, usage_stats)
    """
    parts: list[str] = []
    usage: UsageStats | None = None

    async for chunk in stream:
        if chunk.delta:
            parts.append(chunk.delta)
        if chunk.usage:
            usage = chunk.usage

    return "".join(parts), usage


async def merge_streams(
    *streams: AsyncIterator[StreamChunk],
) -> AsyncIterator[StreamChunk]:
    """
    Merge multiple streams into one.

    Useful for parallel requests where results should be combined.

    Args:
        *streams: Streams to merge

    Yields:
        Chunks from all streams in arrival order
    """
    pending: set[asyncio.Task[tuple[int, StreamChunk | None]]] = set()

    async def next_chunk(
        idx: int, stream: AsyncIterator[StreamChunk]
    ) -> tuple[int, StreamChunk | None]:
        try:
            return idx, await stream.__anext__()
        except StopAsyncIteration:
            return idx, None

    # Start initial tasks
    active_streams = list(streams)
    for idx, stream in enumerate(active_streams):
        task = asyncio.create_task(next_chunk(idx, stream))
        pending.add(task)

    while pending:
        done, pending = await asyncio.wait(
            pending, return_when=asyncio.FIRST_COMPLETED
        )

        for task in done:
            idx, chunk = task.result()
            if chunk is not None:
                yield chunk
                # Schedule next chunk from this stream
                new_task = asyncio.create_task(next_chunk(idx, active_streams[idx]))
                pending.add(new_task)
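
For context, a minimal sketch of how the helpers in this file might be composed inside the gateway's chat_completion tool. The `router` handle (a LiteLLM Router) and the "fast" model group are assumptions for illustration, not part of this file.

# Sketch only: wrap_litellm_stream + StreamAccumulator + SSE formatting together.
# `router` and the "fast" model group are assumed, not defined in streaming.py.
async def stream_chat_sse(router: Any, messages: list[dict[str, str]]) -> AsyncIterator[str]:
    acc = StreamAccumulator()
    raw = await router.acompletion(model="fast", messages=messages, stream=True)
    try:
        async for chunk in wrap_litellm_stream(raw, accumulator=acc):
            yield format_sse_chunk(chunk)  # one SSE event per delta
    except Exception as exc:  # surface provider errors to the client as SSE
        yield format_sse_error(str(exc))
    yield format_sse_done()
    logger.info("request %s streamed %d chunks", acc.request_id, acc.chunks_received)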