"""
|
|
Streaming support for LLM Gateway.
|
|
|
|
Provides async streaming wrappers for LiteLLM responses.
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import uuid
|
|
from collections.abc import AsyncIterator
|
|
from typing import Any
|
|
|
|
from models import StreamChunk, UsageStats
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class StreamAccumulator:
    """
    Accumulates streaming chunks for cost calculation.

    Tracks:
    - Full content for final response
    - Token counts from chunks
    - Timing information
    """

    def __init__(self, request_id: str | None = None) -> None:
        """
        Initialize accumulator.

        Args:
            request_id: Optional request ID for tracking
        """
        self.request_id = request_id or str(uuid.uuid4())
        self.content_parts: list[str] = []
        self.chunks_received = 0
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.model: str | None = None
        self.finish_reason: str | None = None
        self._started_at: float | None = None
        self._finished_at: float | None = None

    @property
    def content(self) -> str:
        """Get accumulated content."""
        return "".join(self.content_parts)

    @property
    def total_tokens(self) -> int:
        """Get total token count."""
        return self.prompt_tokens + self.completion_tokens

    @property
    def duration_seconds(self) -> float | None:
        """Get stream duration in seconds."""
        if self._started_at is None or self._finished_at is None:
            return None
        return self._finished_at - self._started_at

    def start(self) -> None:
        """Mark stream start."""
        import time

        self._started_at = time.time()

    def finish(self) -> None:
        """Mark stream finish."""
        import time

        self._finished_at = time.time()

    def add_chunk(
        self,
        delta: str,
        finish_reason: str | None = None,
        model: str | None = None,
        usage: dict[str, int] | None = None,
    ) -> None:
        """
        Add a chunk to the accumulator.

        Args:
            delta: Content delta
            finish_reason: Finish reason if this is the final chunk
            model: Model name
            usage: Usage stats if provided
        """
        if delta:
            self.content_parts.append(delta)
        self.chunks_received += 1

        if finish_reason:
            self.finish_reason = finish_reason

        if model:
            self.model = model

        if usage:
            self.prompt_tokens = usage.get("prompt_tokens", self.prompt_tokens)
            self.completion_tokens = usage.get(
                "completion_tokens", self.completion_tokens
            )

    def get_usage_stats(self, cost_usd: float = 0.0) -> UsageStats:
        """Get usage statistics."""
        return UsageStats(
            prompt_tokens=self.prompt_tokens,
            completion_tokens=self.completion_tokens,
            total_tokens=self.total_tokens,
            cost_usd=cost_usd,
        )


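# Usage sketch for StreamAccumulator (illustrative only; the literal deltas,
# token counts, and cost below are made-up values, not part of this module):
#
#     acc = StreamAccumulator(request_id="req-123")
#     acc.start()
#     acc.add_chunk(delta="Hello, ")
#     acc.add_chunk(
#         delta="world!",
#         finish_reason="stop",
#         usage={"prompt_tokens": 12, "completion_tokens": 4},
#     )
#     acc.finish()
#     acc.content        # -> "Hello, world!"
#     acc.total_tokens   # -> 16
#     stats = acc.get_usage_stats(cost_usd=0.0003)

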
async def wrap_litellm_stream(
    stream: AsyncIterator[Any],
    accumulator: StreamAccumulator | None = None,
) -> AsyncIterator[StreamChunk]:
    """
    Wrap a LiteLLM stream into StreamChunk objects.

    Args:
        stream: LiteLLM async stream
        accumulator: Optional accumulator for tracking

    Yields:
        StreamChunk objects
    """
    if accumulator:
        accumulator.start()

    chunk_id = 0
    try:
        async for chunk in stream:
            chunk_id += 1

            # Extract data from LiteLLM chunk
            delta = ""
            finish_reason = None
            usage = None
            model = None

            # Handle different chunk formats
            if hasattr(chunk, "choices") and chunk.choices:
                choice = chunk.choices[0]
                if hasattr(choice, "delta"):
                    delta = getattr(choice.delta, "content", "") or ""
                finish_reason = getattr(choice, "finish_reason", None)

            if hasattr(chunk, "model"):
                model = chunk.model

            if hasattr(chunk, "usage") and chunk.usage:
                usage = {
                    "prompt_tokens": getattr(chunk.usage, "prompt_tokens", 0),
                    "completion_tokens": getattr(chunk.usage, "completion_tokens", 0),
                }

            # Update accumulator
            if accumulator:
                accumulator.add_chunk(
                    delta=delta,
                    finish_reason=finish_reason,
                    model=model,
                    usage=usage,
                )

            # Create StreamChunk
            stream_chunk = StreamChunk(
                id=f"{accumulator.request_id if accumulator else 'stream'}-{chunk_id}",
                delta=delta,
                finish_reason=finish_reason,
            )

            # Add usage on final chunk
            if finish_reason and accumulator:
                stream_chunk.usage = accumulator.get_usage_stats()

            yield stream_chunk

    finally:
        if accumulator:
            accumulator.finish()


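# Usage sketch (illustrative; assumes LiteLLM's `acompletion` called with
# `stream=True`, which yields OpenAI-style chunks -- the model name and prompt
# are placeholders, and the gateway may obtain its raw stream differently):
#
#     import litellm
#
#     async def generate() -> None:
#         raw = await litellm.acompletion(
#             model="gpt-4o-mini",
#             messages=[{"role": "user", "content": "Hi"}],
#             stream=True,
#         )
#         acc = StreamAccumulator()
#         async for chunk in wrap_litellm_stream(raw, accumulator=acc):
#             print(chunk.delta, end="")
#         logger.info("streamed %s tokens", acc.total_tokens)

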
def format_sse_chunk(chunk: StreamChunk) -> str:
    """
    Format a StreamChunk as SSE data.

    Args:
        chunk: StreamChunk to format

    Returns:
        SSE-formatted string
    """
    data: dict[str, Any] = {
        "id": chunk.id,
        "delta": chunk.delta,
    }
    if chunk.finish_reason:
        data["finish_reason"] = chunk.finish_reason
    if chunk.usage:
        data["usage"] = chunk.usage.model_dump()

    return f"data: {json.dumps(data)}\n\n"


def format_sse_done() -> str:
    """Format SSE done message."""
    return "data: [DONE]\n\n"


def format_sse_error(error: str, code: str | None = None) -> str:
    """
    Format an error as SSE data.

    Args:
        error: Error message
        code: Error code

    Returns:
        SSE-formatted error string
    """
    data = {"error": error}
    if code:
        data["code"] = code
    return f"data: {json.dumps(data)}\n\n"


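# Usage sketch for the SSE helpers (illustrative; assumes a FastAPI/Starlette
# endpoint -- `StreamingResponse` is Starlette's, the `sse_events` wrapper and
# the route that would return it are hypothetical):
#
#     from fastapi.responses import StreamingResponse
#
#     async def sse_events(stream: AsyncIterator[StreamChunk]) -> AsyncIterator[str]:
#         try:
#             async for chunk in stream:
#                 yield format_sse_chunk(chunk)
#         except Exception as exc:  # noqa: BLE001
#             yield format_sse_error(str(exc), code="stream_error")
#         finally:
#             yield format_sse_done()
#
#     # return StreamingResponse(sse_events(stream), media_type="text/event-stream")

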
class StreamBuffer:
    """
    Buffer for streaming responses with backpressure handling.

    Useful when producing chunks faster than they can be consumed.
    """

    def __init__(self, max_size: int = 100) -> None:
        """
        Initialize buffer.

        Args:
            max_size: Maximum buffer size
        """
        self._queue: asyncio.Queue[StreamChunk | None] = asyncio.Queue(maxsize=max_size)
        self._done = False
        self._error: Exception | None = None

    async def put(self, chunk: StreamChunk) -> None:
        """
        Put a chunk in the buffer.

        Args:
            chunk: Chunk to buffer
        """
        if self._done:
            raise RuntimeError("Buffer is closed")
        await self._queue.put(chunk)

    async def done(self) -> None:
        """Signal that streaming is complete."""
        self._done = True
        await self._queue.put(None)

    async def error(self, err: Exception) -> None:
        """Signal an error occurred."""
        self._error = err
        self._done = True
        await self._queue.put(None)

    async def __aiter__(self) -> AsyncIterator[StreamChunk]:
        """Iterate over buffered chunks."""
        while True:
            chunk = await self._queue.get()
            if chunk is None:
                if self._error:
                    raise self._error
                return
            yield chunk


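# Usage sketch for StreamBuffer (illustrative; the `producer` helper and its
# `source` argument are made up for the example):
#
#     buffer = StreamBuffer(max_size=50)
#
#     async def producer(source: AsyncIterator[StreamChunk]) -> None:
#         try:
#             async for chunk in source:
#                 await buffer.put(chunk)
#         except Exception as exc:  # noqa: BLE001
#             await buffer.error(exc)
#         else:
#             await buffer.done()
#
#     # Consume concurrently:  async for chunk in buffer: ...

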
async def stream_to_string(
    stream: AsyncIterator[StreamChunk],
) -> tuple[str, UsageStats | None]:
    """
    Consume a stream and return full content.

    Args:
        stream: Stream to consume

    Returns:
        Tuple of (content, usage_stats)
    """
    parts: list[str] = []
    usage: UsageStats | None = None

    async for chunk in stream:
        if chunk.delta:
            parts.append(chunk.delta)
        if chunk.usage:
            usage = chunk.usage

    return "".join(parts), usage


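# Usage sketch (illustrative; `wrapped` stands for any stream produced by
# wrap_litellm_stream):
#
#     content, usage = await stream_to_string(wrapped)
#     if usage is not None:
#         logger.info("completion used %s tokens", usage.total_tokens)

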
async def merge_streams(
    *streams: AsyncIterator[StreamChunk],
) -> AsyncIterator[StreamChunk]:
    """
    Merge multiple streams into one.

    Useful for parallel requests where results should be combined.

    Args:
        *streams: Streams to merge

    Yields:
        Chunks from all streams in arrival order
    """
    pending: set[asyncio.Task[tuple[int, StreamChunk | None]]] = set()

    async def next_chunk(
        idx: int, stream: AsyncIterator[StreamChunk]
    ) -> tuple[int, StreamChunk | None]:
        try:
            return idx, await stream.__anext__()
        except StopAsyncIteration:
            return idx, None

    # Start initial tasks
    active_streams = list(streams)
    for idx, stream in enumerate(active_streams):
        task = asyncio.create_task(next_chunk(idx, stream))
        pending.add(task)

    while pending:
        done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)

        for task in done:
            idx, chunk = task.result()
            if chunk is not None:
                yield chunk
                # Schedule next chunk from this stream
                new_task = asyncio.create_task(next_chunk(idx, active_streams[idx]))
                pending.add(new_task)
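

# Usage sketch for merge_streams (illustrative; `wrapped_a` and `wrapped_b` are
# assumed to be two independently wrapped LiteLLM streams, `handle` is hypothetical):
#
#     async for chunk in merge_streams(wrapped_a, wrapped_b):
#         # Chunks interleave in arrival order; when each stream has its own
#         # accumulator, the chunk id prefix identifies the originating request.
#         handle(chunk)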