"""
Streaming support for LLM Gateway.
Provides async streaming wrappers for LiteLLM responses.
"""
import asyncio
import json
import logging
import time
import uuid
from collections.abc import AsyncIterator
from typing import Any
from models import StreamChunk, UsageStats
logger = logging.getLogger(__name__)
class StreamAccumulator:
"""
Accumulates streaming chunks for cost calculation.
Tracks:
- Full content for final response
- Token counts from chunks
- Timing information
"""
def __init__(self, request_id: str | None = None) -> None:
"""
Initialize accumulator.
Args:
request_id: Optional request ID for tracking
"""
self.request_id = request_id or str(uuid.uuid4())
self.content_parts: list[str] = []
self.chunks_received = 0
self.prompt_tokens = 0
self.completion_tokens = 0
self.model: str | None = None
self.finish_reason: str | None = None
self._started_at: float | None = None
self._finished_at: float | None = None
@property
def content(self) -> str:
"""Get accumulated content."""
return "".join(self.content_parts)
@property
def total_tokens(self) -> int:
"""Get total token count."""
return self.prompt_tokens + self.completion_tokens
@property
def duration_seconds(self) -> float | None:
"""Get stream duration in seconds."""
if self._started_at is None or self._finished_at is None:
return None
return self._finished_at - self._started_at
    def start(self) -> None:
        """Mark stream start."""
        self._started_at = time.time()

    def finish(self) -> None:
        """Mark stream finish."""
        self._finished_at = time.time()
def add_chunk(
self,
delta: str,
finish_reason: str | None = None,
model: str | None = None,
usage: dict[str, int] | None = None,
) -> None:
"""
Add a chunk to the accumulator.
Args:
delta: Content delta
finish_reason: Finish reason if this is the final chunk
model: Model name
usage: Usage stats if provided
"""
if delta:
self.content_parts.append(delta)
self.chunks_received += 1
if finish_reason:
self.finish_reason = finish_reason
if model:
self.model = model
if usage:
self.prompt_tokens = usage.get("prompt_tokens", self.prompt_tokens)
self.completion_tokens = usage.get(
"completion_tokens", self.completion_tokens
)
def get_usage_stats(self, cost_usd: float = 0.0) -> UsageStats:
"""Get usage statistics."""
return UsageStats(
prompt_tokens=self.prompt_tokens,
completion_tokens=self.completion_tokens,
total_tokens=self.total_tokens,
cost_usd=cost_usd,
)
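
# Illustrative sketch: driving StreamAccumulator by hand. The deltas, model
# name, token counts, and cost below are made-up example values, not module
# defaults.
def _example_accumulator_usage() -> None:
    acc = StreamAccumulator()
    acc.start()
    acc.add_chunk(delta="Hello, ", model="example-model")
    acc.add_chunk(
        delta="world!",
        finish_reason="stop",
        usage={"prompt_tokens": 12, "completion_tokens": 4},
    )
    acc.finish()
    stats = acc.get_usage_stats(cost_usd=0.0001)
    # acc.content == "Hello, world!" and stats.total_tokens == 16
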
async def wrap_litellm_stream(
stream: AsyncIterator[Any],
accumulator: StreamAccumulator | None = None,
) -> AsyncIterator[StreamChunk]:
"""
Wrap a LiteLLM stream into StreamChunk objects.
Args:
stream: LiteLLM async stream
accumulator: Optional accumulator for tracking
Yields:
StreamChunk objects
"""
if accumulator:
accumulator.start()
chunk_id = 0
try:
async for chunk in stream:
chunk_id += 1
# Extract data from LiteLLM chunk
delta = ""
finish_reason = None
usage = None
model = None
# Handle different chunk formats
if hasattr(chunk, "choices") and chunk.choices:
choice = chunk.choices[0]
if hasattr(choice, "delta"):
delta = getattr(choice.delta, "content", "") or ""
finish_reason = getattr(choice, "finish_reason", None)
if hasattr(chunk, "model"):
model = chunk.model
if hasattr(chunk, "usage") and chunk.usage:
usage = {
"prompt_tokens": getattr(chunk.usage, "prompt_tokens", 0),
"completion_tokens": getattr(chunk.usage, "completion_tokens", 0),
}
# Update accumulator
if accumulator:
accumulator.add_chunk(
delta=delta,
finish_reason=finish_reason,
model=model,
usage=usage,
)
# Create StreamChunk
stream_chunk = StreamChunk(
id=f"{accumulator.request_id if accumulator else 'stream'}-{chunk_id}",
delta=delta,
finish_reason=finish_reason,
)
# Add usage on final chunk
if finish_reason and accumulator:
stream_chunk.usage = accumulator.get_usage_stats()
yield stream_chunk
finally:
if accumulator:
accumulator.finish()
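
# Illustrative sketch: wrapping a provider stream with an accumulator. Assumes
# litellm.acompletion(..., stream=True) returns an async-iterable stream of
# chunks; the model name and messages are placeholder values.
async def _example_wrap_litellm_stream() -> None:
    import litellm  # imported locally so the sketch stays optional

    acc = StreamAccumulator(request_id="example-request")
    raw_stream = await litellm.acompletion(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Say hello"}],
        stream=True,
    )
    async for stream_chunk in wrap_litellm_stream(raw_stream, accumulator=acc):
        print(stream_chunk.delta, end="", flush=True)
    logger.info(
        "received %d chunks in %.2fs",
        acc.chunks_received,
        acc.duration_seconds or 0.0,
    )
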
def format_sse_chunk(chunk: StreamChunk) -> str:
"""
Format a StreamChunk as SSE data.
Args:
chunk: StreamChunk to format
Returns:
SSE-formatted string
"""
data: dict[str, Any] = {
"id": chunk.id,
"delta": chunk.delta,
}
if chunk.finish_reason:
data["finish_reason"] = chunk.finish_reason
if chunk.usage:
data["usage"] = chunk.usage.model_dump()
return f"data: {json.dumps(data)}\n\n"
def format_sse_done() -> str:
"""Format SSE done message."""
return "data: [DONE]\n\n"
def format_sse_error(error: str, code: str | None = None) -> str:
"""
Format an error as SSE data.
Args:
error: Error message
code: Error code
Returns:
SSE-formatted error string
"""
data = {"error": error}
if code:
data["code"] = code
return f"data: {json.dumps(data)}\n\n"
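
# Illustrative sketch: turning a StreamChunk stream into SSE-formatted lines.
# Any web framework whose streaming response accepts an async generator of
# strings could consume this body; no specific framework is assumed.
async def _example_sse_body(
    chunks: AsyncIterator[StreamChunk],
) -> AsyncIterator[str]:
    try:
        async for chunk in chunks:
            yield format_sse_chunk(chunk)
    except Exception as exc:
        # Surface the failure to the client as an SSE error event.
        yield format_sse_error(str(exc), code="stream_error")
    yield format_sse_done()
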
class StreamBuffer:
"""
Buffer for streaming responses with backpressure handling.
    Useful when chunks are produced faster than they can be consumed.
"""
def __init__(self, max_size: int = 100) -> None:
"""
Initialize buffer.
Args:
max_size: Maximum buffer size
"""
self._queue: asyncio.Queue[StreamChunk | None] = asyncio.Queue(maxsize=max_size)
self._done = False
self._error: Exception | None = None
async def put(self, chunk: StreamChunk) -> None:
"""
Put a chunk in the buffer.
Args:
chunk: Chunk to buffer
"""
if self._done:
raise RuntimeError("Buffer is closed")
await self._queue.put(chunk)
async def done(self) -> None:
"""Signal that streaming is complete."""
self._done = True
await self._queue.put(None)
async def error(self, err: Exception) -> None:
"""Signal an error occurred."""
self._error = err
self._done = True
await self._queue.put(None)
async def __aiter__(self) -> AsyncIterator[StreamChunk]:
"""Iterate over buffered chunks."""
while True:
chunk = await self._queue.get()
if chunk is None:
if self._error:
raise self._error
return
yield chunk
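
# Illustrative sketch: a producer task fills the StreamBuffer while the
# consumer drains it, so slow consumers exert backpressure through the bounded
# queue. Chunk ids and deltas are placeholder values.
async def _example_buffered_stream() -> None:
    buffer = StreamBuffer(max_size=10)

    async def produce() -> None:
        try:
            for i in range(3):
                await buffer.put(
                    StreamChunk(id=f"demo-{i}", delta=f"part {i} ", finish_reason=None)
                )
            await buffer.done()
        except Exception as exc:
            await buffer.error(exc)

    producer = asyncio.create_task(produce())
    async for chunk in buffer:
        print(chunk.delta, end="")
    await producer
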
async def stream_to_string(
stream: AsyncIterator[StreamChunk],
) -> tuple[str, UsageStats | None]:
"""
    Consume a stream and return the full content and any final usage stats.
Args:
stream: Stream to consume
Returns:
Tuple of (content, usage_stats)
"""
parts: list[str] = []
usage: UsageStats | None = None
async for chunk in stream:
if chunk.delta:
parts.append(chunk.delta)
if chunk.usage:
usage = chunk.usage
return "".join(parts), usage
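
# Illustrative sketch: collapsing any StreamChunk stream into a single string
# for callers that do not want incremental output.
async def _example_collect_stream(stream: AsyncIterator[StreamChunk]) -> str:
    content, usage = await stream_to_string(stream)
    if usage is not None:
        logger.info("completion used %d tokens", usage.total_tokens)
    return content
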
async def merge_streams(
*streams: AsyncIterator[StreamChunk],
) -> AsyncIterator[StreamChunk]:
"""
Merge multiple streams into one.
Useful for parallel requests where results should be combined.
Args:
*streams: Streams to merge
Yields:
Chunks from all streams in arrival order
"""
pending: set[asyncio.Task[tuple[int, StreamChunk | None]]] = set()
async def next_chunk(
idx: int, stream: AsyncIterator[StreamChunk]
) -> tuple[int, StreamChunk | None]:
try:
return idx, await stream.__anext__()
except StopAsyncIteration:
return idx, None
# Start initial tasks
active_streams = list(streams)
for idx, stream in enumerate(active_streams):
task = asyncio.create_task(next_chunk(idx, stream))
pending.add(task)
while pending:
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
for task in done:
idx, chunk = task.result()
if chunk is not None:
yield chunk
# Schedule next chunk from this stream
new_task = asyncio.create_task(next_chunk(idx, active_streams[idx]))
pending.add(new_task)
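
# Illustrative sketch: interleaving two wrapped streams in arrival order. The
# stream arguments stand in for any AsyncIterator[StreamChunk], e.g. two
# wrap_litellm_stream() calls issued in parallel.
async def _example_merge_streams(
    stream_a: AsyncIterator[StreamChunk],
    stream_b: AsyncIterator[StreamChunk],
) -> None:
    async for chunk in merge_streams(stream_a, stream_b):
        # When an accumulator was supplied upstream, chunk ids embed its
        # request_id, which lets callers attribute interleaved deltas back to
        # their source stream.
        print(chunk.id, chunk.delta)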