""" Tests for streaming module. """ import asyncio import json from unittest.mock import MagicMock import pytest from models import StreamChunk, UsageStats from streaming import ( StreamAccumulator, StreamBuffer, format_sse_chunk, format_sse_done, format_sse_error, stream_to_string, wrap_litellm_stream, ) class TestStreamAccumulator: """Tests for StreamAccumulator class.""" def test_initial_state(self) -> None: """Test initial accumulator state.""" acc = StreamAccumulator() assert acc.request_id is not None assert acc.content == "" assert acc.chunks_received == 0 assert acc.prompt_tokens == 0 assert acc.completion_tokens == 0 assert acc.model is None assert acc.finish_reason is None def test_custom_request_id(self) -> None: """Test accumulator with custom request ID.""" acc = StreamAccumulator(request_id="custom-id") assert acc.request_id == "custom-id" def test_add_chunk_text(self) -> None: """Test adding text chunks.""" acc = StreamAccumulator() acc.add_chunk("Hello") acc.add_chunk(", ") acc.add_chunk("world!") assert acc.content == "Hello, world!" assert acc.chunks_received == 3 def test_add_chunk_with_finish_reason(self) -> None: """Test adding chunk with finish reason.""" acc = StreamAccumulator() acc.add_chunk("Final", finish_reason="stop") assert acc.finish_reason == "stop" def test_add_chunk_with_model(self) -> None: """Test adding chunk with model info.""" acc = StreamAccumulator() acc.add_chunk("Text", model="claude-opus-4") assert acc.model == "claude-opus-4" def test_add_chunk_with_usage(self) -> None: """Test adding chunk with usage stats.""" acc = StreamAccumulator() acc.add_chunk( "Text", usage={"prompt_tokens": 10, "completion_tokens": 5}, ) assert acc.prompt_tokens == 10 assert acc.completion_tokens == 5 assert acc.total_tokens == 15 def test_start_and_finish(self) -> None: """Test start and finish timing.""" acc = StreamAccumulator() assert acc.duration_seconds is None acc.start() acc.finish() assert acc.duration_seconds is not None assert acc.duration_seconds >= 0 def test_get_usage_stats(self) -> None: """Test getting usage stats.""" acc = StreamAccumulator() acc.add_chunk("", usage={"prompt_tokens": 100, "completion_tokens": 50}) stats = acc.get_usage_stats(cost_usd=0.01) assert stats.prompt_tokens == 100 assert stats.completion_tokens == 50 assert stats.total_tokens == 150 assert stats.cost_usd == 0.01 class TestWrapLiteLLMStream: """Tests for wrap_litellm_stream function.""" async def test_wrap_stream_basic(self) -> None: """Test wrapping a basic stream.""" # Create mock stream chunks async def mock_stream(): chunk1 = MagicMock() chunk1.choices = [MagicMock()] chunk1.choices[0].delta = MagicMock() chunk1.choices[0].delta.content = "Hello" chunk1.choices[0].finish_reason = None chunk1.model = "test-model" chunk1.usage = None yield chunk1 chunk2 = MagicMock() chunk2.choices = [MagicMock()] chunk2.choices[0].delta = MagicMock() chunk2.choices[0].delta.content = " World" chunk2.choices[0].finish_reason = "stop" chunk2.model = "test-model" chunk2.usage = MagicMock() chunk2.usage.prompt_tokens = 5 chunk2.usage.completion_tokens = 2 yield chunk2 accumulator = StreamAccumulator() chunks = [] async for chunk in wrap_litellm_stream(mock_stream(), accumulator): chunks.append(chunk) assert len(chunks) == 2 assert chunks[0].delta == "Hello" assert chunks[1].delta == " World" assert chunks[1].finish_reason == "stop" assert accumulator.content == "Hello World" async def test_wrap_stream_without_accumulator(self) -> None: """Test wrapping stream without accumulator.""" async def mock_stream(): chunk = MagicMock() chunk.choices = [MagicMock()] chunk.choices[0].delta = MagicMock() chunk.choices[0].delta.content = "Test" chunk.choices[0].finish_reason = None chunk.model = None chunk.usage = None yield chunk chunks = [] async for chunk in wrap_litellm_stream(mock_stream()): chunks.append(chunk) assert len(chunks) == 1 class TestSSEFormatting: """Tests for SSE formatting functions.""" def test_format_sse_chunk_basic(self) -> None: """Test formatting basic chunk.""" chunk = StreamChunk(id="chunk-1", delta="Hello") result = format_sse_chunk(chunk) assert result.startswith("data: ") assert result.endswith("\n\n") # Parse the JSON data = json.loads(result[6:-2]) assert data["id"] == "chunk-1" assert data["delta"] == "Hello" def test_format_sse_chunk_with_finish(self) -> None: """Test formatting chunk with finish reason.""" chunk = StreamChunk( id="chunk-2", delta="", finish_reason="stop", ) result = format_sse_chunk(chunk) data = json.loads(result[6:-2]) assert data["finish_reason"] == "stop" def test_format_sse_chunk_with_usage(self) -> None: """Test formatting chunk with usage.""" chunk = StreamChunk( id="chunk-3", delta="", finish_reason="stop", usage=UsageStats( prompt_tokens=10, completion_tokens=5, total_tokens=15, cost_usd=0.001, ), ) result = format_sse_chunk(chunk) data = json.loads(result[6:-2]) assert "usage" in data assert data["usage"]["prompt_tokens"] == 10 def test_format_sse_done(self) -> None: """Test formatting done message.""" result = format_sse_done() assert result == "data: [DONE]\n\n" def test_format_sse_error(self) -> None: """Test formatting error message.""" result = format_sse_error("Something went wrong", code="ERROR_CODE") data = json.loads(result[6:-2]) assert data["error"] == "Something went wrong" assert data["code"] == "ERROR_CODE" def test_format_sse_error_without_code(self) -> None: """Test formatting error without code.""" result = format_sse_error("Error message") data = json.loads(result[6:-2]) assert data["error"] == "Error message" assert "code" not in data class TestStreamBuffer: """Tests for StreamBuffer class.""" async def test_buffer_basic(self) -> None: """Test basic buffer operations.""" buffer = StreamBuffer(max_size=10) # Producer async def produce(): await buffer.put(StreamChunk(id="1", delta="Hello")) await buffer.put(StreamChunk(id="2", delta=" World")) await buffer.done() # Consumer chunks = [] asyncio.create_task(produce()) async for chunk in buffer: chunks.append(chunk) assert len(chunks) == 2 assert chunks[0].delta == "Hello" assert chunks[1].delta == " World" async def test_buffer_error(self) -> None: """Test buffer with error.""" buffer = StreamBuffer() async def produce(): await buffer.put(StreamChunk(id="1", delta="Hello")) await buffer.error(ValueError("Test error")) asyncio.create_task(produce()) with pytest.raises(ValueError, match="Test error"): async for _ in buffer: pass async def test_buffer_put_after_done(self) -> None: """Test putting after done raises.""" buffer = StreamBuffer() await buffer.done() with pytest.raises(RuntimeError, match="closed"): await buffer.put(StreamChunk(id="1", delta="Test")) class TestStreamToString: """Tests for stream_to_string function.""" async def test_stream_to_string_basic(self) -> None: """Test converting stream to string.""" async def mock_stream(): yield StreamChunk(id="1", delta="Hello") yield StreamChunk(id="2", delta=" ") yield StreamChunk(id="3", delta="World") yield StreamChunk( id="4", delta="", finish_reason="stop", usage=UsageStats(prompt_tokens=5, completion_tokens=3), ) content, usage = await stream_to_string(mock_stream()) assert content == "Hello World" assert usage is not None assert usage.prompt_tokens == 5 async def test_stream_to_string_no_usage(self) -> None: """Test stream without usage stats.""" async def mock_stream(): yield StreamChunk(id="1", delta="Test") content, usage = await stream_to_string(mock_stream()) assert content == "Test" assert usage is None