""" Streaming support for LLM Gateway. Provides async streaming wrappers for LiteLLM responses. """ import asyncio import json import logging import uuid from collections.abc import AsyncIterator from typing import Any from models import StreamChunk, UsageStats logger = logging.getLogger(__name__) class StreamAccumulator: """ Accumulates streaming chunks for cost calculation. Tracks: - Full content for final response - Token counts from chunks - Timing information """ def __init__(self, request_id: str | None = None) -> None: """ Initialize accumulator. Args: request_id: Optional request ID for tracking """ self.request_id = request_id or str(uuid.uuid4()) self.content_parts: list[str] = [] self.chunks_received = 0 self.prompt_tokens = 0 self.completion_tokens = 0 self.model: str | None = None self.finish_reason: str | None = None self._started_at: float | None = None self._finished_at: float | None = None @property def content(self) -> str: """Get accumulated content.""" return "".join(self.content_parts) @property def total_tokens(self) -> int: """Get total token count.""" return self.prompt_tokens + self.completion_tokens @property def duration_seconds(self) -> float | None: """Get stream duration in seconds.""" if self._started_at is None or self._finished_at is None: return None return self._finished_at - self._started_at def start(self) -> None: """Mark stream start.""" import time self._started_at = time.time() def finish(self) -> None: """Mark stream finish.""" import time self._finished_at = time.time() def add_chunk( self, delta: str, finish_reason: str | None = None, model: str | None = None, usage: dict[str, int] | None = None, ) -> None: """ Add a chunk to the accumulator. Args: delta: Content delta finish_reason: Finish reason if this is the final chunk model: Model name usage: Usage stats if provided """ if delta: self.content_parts.append(delta) self.chunks_received += 1 if finish_reason: self.finish_reason = finish_reason if model: self.model = model if usage: self.prompt_tokens = usage.get("prompt_tokens", self.prompt_tokens) self.completion_tokens = usage.get( "completion_tokens", self.completion_tokens ) def get_usage_stats(self, cost_usd: float = 0.0) -> UsageStats: """Get usage statistics.""" return UsageStats( prompt_tokens=self.prompt_tokens, completion_tokens=self.completion_tokens, total_tokens=self.total_tokens, cost_usd=cost_usd, ) async def wrap_litellm_stream( stream: AsyncIterator[Any], accumulator: StreamAccumulator | None = None, ) -> AsyncIterator[StreamChunk]: """ Wrap a LiteLLM stream into StreamChunk objects. 
async def wrap_litellm_stream(
    stream: AsyncIterator[Any],
    accumulator: StreamAccumulator | None = None,
) -> AsyncIterator[StreamChunk]:
    """
    Wrap a LiteLLM stream into StreamChunk objects.

    Args:
        stream: LiteLLM async stream
        accumulator: Optional accumulator for tracking

    Yields:
        StreamChunk objects
    """
    if accumulator:
        accumulator.start()

    chunk_id = 0
    try:
        async for chunk in stream:
            chunk_id += 1

            # Extract data from LiteLLM chunk
            delta = ""
            finish_reason = None
            usage = None
            model = None

            # Handle different chunk formats
            if hasattr(chunk, "choices") and chunk.choices:
                choice = chunk.choices[0]
                if hasattr(choice, "delta"):
                    delta = getattr(choice.delta, "content", "") or ""
                finish_reason = getattr(choice, "finish_reason", None)

            if hasattr(chunk, "model"):
                model = chunk.model

            if hasattr(chunk, "usage") and chunk.usage:
                usage = {
                    "prompt_tokens": getattr(chunk.usage, "prompt_tokens", 0),
                    "completion_tokens": getattr(chunk.usage, "completion_tokens", 0),
                }

            # Update accumulator
            if accumulator:
                accumulator.add_chunk(
                    delta=delta,
                    finish_reason=finish_reason,
                    model=model,
                    usage=usage,
                )

            # Create StreamChunk
            stream_chunk = StreamChunk(
                id=f"{accumulator.request_id if accumulator else 'stream'}-{chunk_id}",
                delta=delta,
                finish_reason=finish_reason,
            )

            # Add usage on final chunk
            if finish_reason and accumulator:
                stream_chunk.usage = accumulator.get_usage_stats()

            yield stream_chunk
    finally:
        if accumulator:
            accumulator.finish()


def format_sse_chunk(chunk: StreamChunk) -> str:
    """
    Format a StreamChunk as SSE data.

    Args:
        chunk: StreamChunk to format

    Returns:
        SSE-formatted string
    """
    data: dict[str, Any] = {
        "id": chunk.id,
        "delta": chunk.delta,
    }
    if chunk.finish_reason:
        data["finish_reason"] = chunk.finish_reason
    if chunk.usage:
        data["usage"] = chunk.usage.model_dump()
    return f"data: {json.dumps(data)}\n\n"


def format_sse_done() -> str:
    """Format SSE done message."""
    return "data: [DONE]\n\n"


def format_sse_error(error: str, code: str | None = None) -> str:
    """
    Format an error as SSE data.

    Args:
        error: Error message
        code: Error code

    Returns:
        SSE-formatted error string
    """
    data = {"error": error}
    if code:
        data["code"] = code
    return f"data: {json.dumps(data)}\n\n"


class StreamBuffer:
    """
    Buffer for streaming responses with backpressure handling.

    Useful when producing chunks faster than they can be consumed.
    """

    def __init__(self, max_size: int = 100) -> None:
        """
        Initialize buffer.

        Args:
            max_size: Maximum buffer size
        """
        self._queue: asyncio.Queue[StreamChunk | None] = asyncio.Queue(
            maxsize=max_size
        )
        self._done = False
        self._error: Exception | None = None

    async def put(self, chunk: StreamChunk) -> None:
        """
        Put a chunk in the buffer.

        Args:
            chunk: Chunk to buffer
        """
        if self._done:
            raise RuntimeError("Buffer is closed")
        await self._queue.put(chunk)

    async def done(self) -> None:
        """Signal that streaming is complete."""
        self._done = True
        await self._queue.put(None)

    async def error(self, err: Exception) -> None:
        """Signal an error occurred."""
        self._error = err
        self._done = True
        await self._queue.put(None)

    async def __aiter__(self) -> AsyncIterator[StreamChunk]:
        """Iterate over buffered chunks."""
        while True:
            chunk = await self._queue.get()
            if chunk is None:
                if self._error:
                    raise self._error
                return
            yield chunk
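
# Minimal sketch (not wired into any route) of how a gateway handler might compose
# these helpers into an SSE response body. `litellm_stream` is a placeholder for the
# async stream the gateway receives from LiteLLM when streaming is requested.
async def _example_sse_stream(
    litellm_stream: AsyncIterator[Any],
) -> AsyncIterator[str]:
    """Yield SSE-formatted strings for a raw LiteLLM stream."""
    accumulator = StreamAccumulator()
    try:
        async for chunk in wrap_litellm_stream(litellm_stream, accumulator):
            yield format_sse_chunk(chunk)
    except Exception as exc:  # surface upstream failures to the client as SSE
        yield format_sse_error(str(exc), code="upstream_error")
    yield format_sse_done()
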
async def stream_to_string(
    stream: AsyncIterator[StreamChunk],
) -> tuple[str, UsageStats | None]:
    """
    Consume a stream and return full content.

    Args:
        stream: Stream to consume

    Returns:
        Tuple of (content, usage_stats)
    """
    parts: list[str] = []
    usage: UsageStats | None = None

    async for chunk in stream:
        if chunk.delta:
            parts.append(chunk.delta)
        if chunk.usage:
            usage = chunk.usage

    return "".join(parts), usage


async def merge_streams(
    *streams: AsyncIterator[StreamChunk],
) -> AsyncIterator[StreamChunk]:
    """
    Merge multiple streams into one.

    Useful for parallel requests where results should be combined.

    Args:
        *streams: Streams to merge

    Yields:
        Chunks from all streams in arrival order
    """
    pending: set[asyncio.Task[tuple[int, StreamChunk | None]]] = set()

    async def next_chunk(
        idx: int, stream: AsyncIterator[StreamChunk]
    ) -> tuple[int, StreamChunk | None]:
        try:
            return idx, await stream.__anext__()
        except StopAsyncIteration:
            return idx, None

    # Start initial tasks
    active_streams = list(streams)
    for idx, stream in enumerate(active_streams):
        task = asyncio.create_task(next_chunk(idx, stream))
        pending.add(task)

    while pending:
        done, pending = await asyncio.wait(
            pending, return_when=asyncio.FIRST_COMPLETED
        )
        for task in done:
            idx, chunk = task.result()
            if chunk is not None:
                yield chunk
                # Schedule next chunk from this stream
                new_task = asyncio.create_task(
                    next_chunk(idx, active_streams[idx])
                )
                pending.add(new_task)
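

# Illustrative sketch of the producer/consumer pattern StreamBuffer is built for,
# drained here with stream_to_string. The chunk ids and deltas are made up, and the
# StreamChunk keyword arguments assume the same constructor shape used in
# wrap_litellm_stream above.
async def _example_buffered_stream() -> str:
    """Produce chunks into a StreamBuffer and consume them concurrently."""
    buffer = StreamBuffer(max_size=10)

    async def produce() -> None:
        try:
            deltas = ["Hello", ", ", "world"]
            for i, delta in enumerate(deltas):
                await buffer.put(
                    StreamChunk(
                        id=f"demo-{i}",
                        delta=delta,
                        finish_reason="stop" if i == len(deltas) - 1 else None,
                    )
                )
            await buffer.done()
        except Exception as exc:
            await buffer.error(exc)

    producer = asyncio.create_task(produce())
    # aiter() turns the buffer into the AsyncIterator that stream_to_string expects.
    content, _usage = await stream_to_string(aiter(buffer))
    await producer
    return content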