forked from cardosofelipe/fast-next-template
- Add type annotations for mypy compliance - Use UTC-aware datetimes consistently (datetime.now(UTC)) - Add type: ignore comments for LiteLLM incomplete stubs - Fix import ordering and formatting - Update pyproject.toml mypy configuration 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
317 lines
9.3 KiB
Python
317 lines
9.3 KiB
Python
"""
|
|
Tests for streaming module.
|
|
"""
|
|
|
|
import asyncio
import json
from collections.abc import AsyncIterator
from typing import Any
from unittest.mock import MagicMock

import pytest

from models import StreamChunk, UsageStats
from streaming import (
    StreamAccumulator,
    StreamBuffer,
    format_sse_chunk,
    format_sse_done,
    format_sse_error,
    stream_to_string,
    wrap_litellm_stream,
)
|
|
|
|
|
|
class TestStreamAccumulator:
    """Unit tests for the StreamAccumulator helper."""

    def test_initial_state(self) -> None:
        """A freshly built accumulator is empty and carries no metadata."""
        fresh = StreamAccumulator()

        assert fresh.request_id is not None
        assert fresh.content == ""
        assert fresh.chunks_received == 0
        assert (fresh.prompt_tokens, fresh.completion_tokens) == (0, 0)
        assert fresh.model is None
        assert fresh.finish_reason is None

    def test_custom_request_id(self) -> None:
        """An explicitly supplied request ID is kept verbatim."""
        assert StreamAccumulator(request_id="custom-id").request_id == "custom-id"

    def test_add_chunk_text(self) -> None:
        """Text deltas are concatenated in arrival order and counted."""
        accumulator = StreamAccumulator()
        for piece in ("Hello", ", ", "world!"):
            accumulator.add_chunk(piece)

        assert accumulator.content == "Hello, world!"
        assert accumulator.chunks_received == 3

    def test_add_chunk_with_finish_reason(self) -> None:
        """A finish reason supplied alongside a chunk is recorded."""
        accumulator = StreamAccumulator()
        accumulator.add_chunk("Final", finish_reason="stop")

        assert accumulator.finish_reason == "stop"

    def test_add_chunk_with_model(self) -> None:
        """A model name supplied alongside a chunk is recorded."""
        accumulator = StreamAccumulator()
        accumulator.add_chunk("Text", model="claude-opus-4")

        assert accumulator.model == "claude-opus-4"

    def test_add_chunk_with_usage(self) -> None:
        """Token counts from a usage mapping are stored and summed."""
        accumulator = StreamAccumulator()
        token_usage = {"prompt_tokens": 10, "completion_tokens": 5}
        accumulator.add_chunk("Text", usage=token_usage)

        assert accumulator.prompt_tokens == 10
        assert accumulator.completion_tokens == 5
        assert accumulator.total_tokens == 15

    def test_start_and_finish(self) -> None:
        """Duration is unknown before timing and non-negative afterwards."""
        accumulator = StreamAccumulator()
        assert accumulator.duration_seconds is None

        accumulator.start()
        accumulator.finish()

        elapsed = accumulator.duration_seconds
        assert elapsed is not None
        assert elapsed >= 0

    def test_get_usage_stats(self) -> None:
        """Usage stats mirror the accumulated tokens plus the given cost."""
        accumulator = StreamAccumulator()
        accumulator.add_chunk(
            "", usage={"prompt_tokens": 100, "completion_tokens": 50}
        )

        stats = accumulator.get_usage_stats(cost_usd=0.01)

        assert (stats.prompt_tokens, stats.completion_tokens, stats.total_tokens) == (
            100,
            50,
            150,
        )
        assert stats.cost_usd == 0.01
|
|
|
|
|
|
class TestWrapLiteLLMStream:
    """Tests for wrap_litellm_stream function."""

    @staticmethod
    def _make_litellm_chunk(
        content: str,
        *,
        finish_reason: str | None = None,
        model: str | None = None,
        prompt_tokens: int | None = None,
        completion_tokens: int | None = None,
    ) -> MagicMock:
        """Build a MagicMock shaped like a LiteLLM streaming chunk.

        Args:
            content: Text placed at ``choices[0].delta.content``.
            finish_reason: Value for ``choices[0].finish_reason``.
            model: Value for ``chunk.model``.
            prompt_tokens: Prompt token count for ``chunk.usage``.
            completion_tokens: Completion token count for ``chunk.usage``.

        Returns:
            The mock chunk. ``chunk.usage`` is ``None`` unless at least one
            token count is given.
        """
        chunk = MagicMock()
        chunk.choices = [MagicMock()]
        chunk.choices[0].delta = MagicMock()
        chunk.choices[0].delta.content = content
        chunk.choices[0].finish_reason = finish_reason
        # Set every attribute the wrapper might read explicitly: an unset
        # attribute on a MagicMock is itself a truthy MagicMock, which would
        # masquerade as real model/usage data.
        chunk.model = model
        if prompt_tokens is None and completion_tokens is None:
            chunk.usage = None
        else:
            chunk.usage = MagicMock()
            chunk.usage.prompt_tokens = prompt_tokens
            chunk.usage.completion_tokens = completion_tokens
        return chunk

    async def test_wrap_stream_basic(self) -> None:
        """Test wrapping a basic stream."""

        async def mock_stream() -> AsyncIterator[MagicMock]:
            yield self._make_litellm_chunk("Hello", model="test-model")
            yield self._make_litellm_chunk(
                " World",
                finish_reason="stop",
                model="test-model",
                prompt_tokens=5,
                completion_tokens=2,
            )

        accumulator = StreamAccumulator()
        chunks = [
            chunk async for chunk in wrap_litellm_stream(mock_stream(), accumulator)
        ]

        assert len(chunks) == 2
        assert chunks[0].delta == "Hello"
        assert chunks[1].delta == " World"
        assert chunks[1].finish_reason == "stop"
        assert accumulator.content == "Hello World"

    async def test_wrap_stream_without_accumulator(self) -> None:
        """Test wrapping stream without accumulator."""

        async def mock_stream() -> AsyncIterator[MagicMock]:
            yield self._make_litellm_chunk("Test")

        chunks = [chunk async for chunk in wrap_litellm_stream(mock_stream())]

        assert len(chunks) == 1
        # Check the content too; a length check alone passes for any output.
        assert chunks[0].delta == "Test"
|
|
|
|
|
|
class TestSSEFormatting:
    """Tests for SSE formatting functions."""

    @staticmethod
    def _parse_sse_data(frame: str) -> Any:
        """Validate the SSE framing of *frame* and return its decoded JSON payload.

        A frame must start with the ``data: `` prefix and end with a blank-line
        terminator. Asserting the framing first means a formatting regression
        fails with a clear message instead of a confusing ``JSONDecodeError``
        from a blind fixed-width slice.
        """
        prefix, suffix = "data: ", "\n\n"
        assert frame.startswith(prefix), f"missing SSE data prefix: {frame!r}"
        assert frame.endswith(suffix), f"missing SSE terminator: {frame!r}"
        return json.loads(frame[len(prefix) : -len(suffix)])

    def test_format_sse_chunk_basic(self) -> None:
        """Test formatting basic chunk."""
        chunk = StreamChunk(id="chunk-1", delta="Hello")

        data = self._parse_sse_data(format_sse_chunk(chunk))

        assert data["id"] == "chunk-1"
        assert data["delta"] == "Hello"

    def test_format_sse_chunk_with_finish(self) -> None:
        """Test formatting chunk with finish reason."""
        chunk = StreamChunk(
            id="chunk-2",
            delta="",
            finish_reason="stop",
        )

        data = self._parse_sse_data(format_sse_chunk(chunk))

        assert data["finish_reason"] == "stop"

    def test_format_sse_chunk_with_usage(self) -> None:
        """Test formatting chunk with usage."""
        chunk = StreamChunk(
            id="chunk-3",
            delta="",
            finish_reason="stop",
            usage=UsageStats(
                prompt_tokens=10,
                completion_tokens=5,
                total_tokens=15,
                cost_usd=0.001,
            ),
        )

        data = self._parse_sse_data(format_sse_chunk(chunk))

        assert "usage" in data
        assert data["usage"]["prompt_tokens"] == 10
        assert data["usage"]["completion_tokens"] == 5
        assert data["usage"]["total_tokens"] == 15

    def test_format_sse_done(self) -> None:
        """Test formatting done message."""
        result = format_sse_done()
        assert result == "data: [DONE]\n\n"

    def test_format_sse_error(self) -> None:
        """Test formatting error message."""
        result = format_sse_error("Something went wrong", code="ERROR_CODE")

        data = self._parse_sse_data(result)

        assert data["error"] == "Something went wrong"
        assert data["code"] == "ERROR_CODE"

    def test_format_sse_error_without_code(self) -> None:
        """Test formatting error without code."""
        result = format_sse_error("Error message")

        data = self._parse_sse_data(result)

        assert data["error"] == "Error message"
        assert "code" not in data
|
|
|
|
|
|
class TestStreamBuffer:
    """Tests for StreamBuffer class."""

    # Upper bound on how long a consumer may wait for the buffer. A broken
    # producer then fails the test instead of hanging the whole run.
    _TIMEOUT_S = 5.0

    async def test_buffer_basic(self) -> None:
        """Test basic buffer operations."""
        buffer = StreamBuffer(max_size=10)

        async def produce() -> None:
            await buffer.put(StreamChunk(id="1", delta="Hello"))
            await buffer.put(StreamChunk(id="2", delta=" World"))
            await buffer.done()

        async def consume() -> list[StreamChunk]:
            return [chunk async for chunk in buffer]

        # Keep a strong reference to the producer task. The event loop holds
        # tasks only weakly, so an unreferenced task can be garbage collected
        # mid-execution.
        producer = asyncio.create_task(produce())

        chunks = await asyncio.wait_for(consume(), timeout=self._TIMEOUT_S)
        # Awaiting the task re-raises any exception from the producer instead
        # of silently discarding it.
        await producer

        assert len(chunks) == 2
        assert chunks[0].delta == "Hello"
        assert chunks[1].delta == " World"

    async def test_buffer_error(self) -> None:
        """Test buffer with error."""
        buffer = StreamBuffer()

        async def produce() -> None:
            await buffer.put(StreamChunk(id="1", delta="Hello"))
            await buffer.error(ValueError("Test error"))

        async def consume() -> None:
            async for _ in buffer:
                pass

        # Strong reference, as in test_buffer_basic.
        producer = asyncio.create_task(produce())

        with pytest.raises(ValueError, match="Test error"):
            await asyncio.wait_for(consume(), timeout=self._TIMEOUT_S)
        # Surface any unexpected failure inside the producer itself.
        await producer

    async def test_buffer_put_after_done(self) -> None:
        """Test putting after done raises."""
        buffer = StreamBuffer()
        await buffer.done()

        with pytest.raises(RuntimeError, match="closed"):
            await buffer.put(StreamChunk(id="1", delta="Test"))
|
|
|
|
|
|
class TestStreamToString:
    """Tests for stream_to_string function."""

    async def test_stream_to_string_basic(self) -> None:
        """Test converting stream to string."""

        async def mock_stream() -> AsyncIterator[StreamChunk]:
            yield StreamChunk(id="1", delta="Hello")
            yield StreamChunk(id="2", delta=" ")
            yield StreamChunk(id="3", delta="World")
            # The final chunk carries no text, only the finish reason and usage.
            yield StreamChunk(
                id="4",
                delta="",
                finish_reason="stop",
                usage=UsageStats(prompt_tokens=5, completion_tokens=3),
            )

        content, usage = await stream_to_string(mock_stream())

        assert content == "Hello World"
        assert usage is not None
        assert usage.prompt_tokens == 5
        # Check both counts so a swapped or dropped field is caught.
        assert usage.completion_tokens == 3

    async def test_stream_to_string_no_usage(self) -> None:
        """Test stream without usage stats."""

        async def mock_stream() -> AsyncIterator[StreamChunk]:
            yield StreamChunk(id="1", delta="Test")

        content, usage = await stream_to_string(mock_stream())

        assert content == "Test"
        assert usage is None
|