Files
Felipe Cardoso f482559e15 fix(llm-gateway): improve type safety and datetime consistency
- Add type annotations for mypy compliance
- Use UTC-aware datetimes consistently (datetime.now(UTC))
- Add type: ignore comments for LiteLLM incomplete stubs
- Fix import ordering and formatting
- Update pyproject.toml mypy configuration

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-03 20:56:05 +01:00

412 lines
13 KiB
Python

"""
Tests for models module.
"""
from datetime import UTC, datetime
import pytest
from models import (
AGENT_TYPE_MODEL_PREFERENCES,
MODEL_CONFIGS,
MODEL_GROUPS,
ChatMessage,
CompletionRequest,
CompletionResponse,
CostRecord,
EmbeddingRequest,
ModelConfig,
ModelGroup,
ModelGroupConfig,
ModelGroupInfo,
ModelInfo,
Provider,
StreamChunk,
UsageReport,
UsageStats,
)
class TestModelGroup:
    """Unit tests covering the ModelGroup enumeration."""

    def test_model_group_values(self) -> None:
        """Each member's string value matches the expected identifier."""
        expected = {
            ModelGroup.REASONING: "reasoning",
            ModelGroup.CODE: "code",
            ModelGroup.FAST: "fast",
            ModelGroup.VISION: "vision",
            ModelGroup.EMBEDDING: "embedding",
            ModelGroup.COST_OPTIMIZED: "cost_optimized",
            ModelGroup.SELF_HOSTED: "self_hosted",
        }
        for member, value in expected.items():
            assert member.value == value

    def test_model_group_from_string(self) -> None:
        """Members can be looked up by their string value."""
        for raw, member in (
            ("reasoning", ModelGroup.REASONING),
            ("code", ModelGroup.CODE),
            ("fast", ModelGroup.FAST),
        ):
            assert ModelGroup(raw) == member

    def test_model_group_invalid(self) -> None:
        """An unknown value is rejected with ValueError."""
        with pytest.raises(ValueError):
            ModelGroup("invalid_group")
class TestProvider:
    """Unit tests covering the Provider enumeration."""

    def test_provider_values(self) -> None:
        """Each provider member carries its lowercase identifier."""
        pairs = [
            (Provider.ANTHROPIC, "anthropic"),
            (Provider.OPENAI, "openai"),
            (Provider.GOOGLE, "google"),
            (Provider.ALIBABA, "alibaba"),
            (Provider.DEEPSEEK, "deepseek"),
        ]
        for member, ident in pairs:
            assert member.value == ident
class TestModelConfig:
    """Unit tests covering the ModelConfig dataclass and its registry."""

    def test_model_config_creation(self) -> None:
        """A fully specified config keeps its fields; defaults still apply."""
        cfg = ModelConfig(
            name="test-model",
            litellm_name="provider/test-model",
            provider=Provider.ANTHROPIC,
            cost_per_1m_input=10.0,
            cost_per_1m_output=30.0,
            context_window=100000,
            max_output_tokens=4096,
            supports_vision=True,
        )
        assert cfg.name == "test-model"
        assert cfg.provider == Provider.ANTHROPIC
        assert cfg.cost_per_1m_input == 10.0
        assert cfg.supports_vision is True
        # supports_streaming was not passed, so the default applies.
        assert cfg.supports_streaming is True

    def test_model_configs_exist(self) -> None:
        """The registry is non-empty and contains well-known models."""
        assert len(MODEL_CONFIGS) > 0
        for known in ("claude-opus-4", "gpt-4.1", "gemini-2.5-pro"):
            assert known in MODEL_CONFIGS
class TestModelGroupConfig:
    """Unit tests covering ModelGroupConfig and the group registry."""

    @staticmethod
    def _sample() -> ModelGroupConfig:
        """Build the group config reused by the tests below."""
        return ModelGroupConfig(
            primary="model-a",
            fallbacks=["model-b", "model-c"],
            description="Test group",
        )

    def test_model_group_config_creation(self) -> None:
        """Constructor stores primary, fallbacks, and description."""
        cfg = self._sample()
        assert cfg.primary == "model-a"
        assert cfg.fallbacks == ["model-b", "model-c"]
        assert cfg.description == "Test group"

    def test_get_all_models(self) -> None:
        """get_all_models lists the primary first, then the fallbacks."""
        assert self._sample().get_all_models() == ["model-a", "model-b", "model-c"]

    def test_model_groups_exist(self) -> None:
        """The group registry is populated with the core groups."""
        assert len(MODEL_GROUPS) > 0
        for group in (ModelGroup.REASONING, ModelGroup.CODE, ModelGroup.FAST):
            assert group in MODEL_GROUPS
class TestAgentTypePreferences:
    """Unit tests covering the agent-type → model-group preference map."""

    def test_agent_preferences_exist(self) -> None:
        """The preference map is populated with the core agent types."""
        assert len(AGENT_TYPE_MODEL_PREFERENCES) > 0
        assert "product_owner" in AGENT_TYPE_MODEL_PREFERENCES
        assert "software_engineer" in AGENT_TYPE_MODEL_PREFERENCES

    def test_agent_preference_values(self) -> None:
        """Each agent type maps to its expected model group."""
        expected = {
            "product_owner": ModelGroup.REASONING,
            "software_engineer": ModelGroup.CODE,
            "devops_engineer": ModelGroup.FAST,
        }
        for agent_type, group in expected.items():
            assert AGENT_TYPE_MODEL_PREFERENCES[agent_type] == group
class TestChatMessage:
    """Unit tests covering the ChatMessage model."""

    def test_chat_message_creation(self) -> None:
        """A minimal message keeps role/content; optionals default to None."""
        message = ChatMessage(role="user", content="Hello")
        assert message.role == "user"
        assert message.content == "Hello"
        assert message.name is None
        assert message.tool_calls is None

    def test_chat_message_with_optional(self) -> None:
        """Optional name and tool_calls round-trip through the model."""
        message = ChatMessage(
            role="assistant",
            content="Response",
            name="assistant_1",
            tool_calls=[{"id": "call_1", "function": {"name": "test"}}],
        )
        assert message.name == "assistant_1"
        assert message.tool_calls is not None

    def test_chat_message_list_content(self) -> None:
        """Multimodal content may be a list of typed parts (e.g. images)."""
        parts = [
            {"type": "text", "text": "What's in this image?"},
            {
                "type": "image_url",
                "image_url": {"url": "http://example.com/img.jpg"},
            },
        ]
        message = ChatMessage(role="user", content=parts)
        assert isinstance(message.content, list)
class TestCompletionRequest:
    """Unit tests covering the CompletionRequest model."""

    def test_completion_request_minimal(self) -> None:
        """Required fields only; omitted fields fall back to defaults."""
        request = CompletionRequest(
            project_id="proj-123",
            agent_id="agent-456",
            messages=[ChatMessage(role="user", content="Hi")],
        )
        assert request.project_id == "proj-123"
        assert request.agent_id == "agent-456"
        assert len(request.messages) == 1
        # Defaults applied when the caller omits them.
        assert request.model_group == ModelGroup.REASONING
        assert request.max_tokens == 4096
        assert request.temperature == 0.7

    def test_completion_request_full(self) -> None:
        """All optional fields are stored as provided."""
        request = CompletionRequest(
            project_id="proj-123",
            agent_id="agent-456",
            messages=[ChatMessage(role="user", content="Hi")],
            model_group=ModelGroup.CODE,
            model_override="claude-sonnet-4",
            max_tokens=8192,
            temperature=0.5,
            stream=True,
            session_id="session-789",
            metadata={"key": "value"},
        )
        assert request.model_group == ModelGroup.CODE
        assert request.model_override == "claude-sonnet-4"
        assert request.max_tokens == 8192
        assert request.stream is True

    def test_completion_request_validation(self) -> None:
        """Out-of-range max_tokens and temperature are rejected."""
        # max_tokens must be positive.
        with pytest.raises(ValueError):
            CompletionRequest(
                project_id="proj-123",
                agent_id="agent-456",
                messages=[ChatMessage(role="user", content="Hi")],
                max_tokens=0,
            )
        # temperature must not be negative.
        with pytest.raises(ValueError):
            CompletionRequest(
                project_id="proj-123",
                agent_id="agent-456",
                messages=[ChatMessage(role="user", content="Hi")],
                temperature=-0.1,
            )
class TestUsageStats:
    """Unit tests covering the UsageStats model."""

    def test_usage_stats_default(self) -> None:
        """All counters start at zero."""
        stats = UsageStats()
        assert (
            stats.prompt_tokens,
            stats.completion_tokens,
            stats.total_tokens,
            stats.cost_usd,
        ) == (0, 0, 0, 0.0)

    def test_usage_stats_custom(self) -> None:
        """Explicit token counts and cost are stored unchanged."""
        stats = UsageStats(
            prompt_tokens=100,
            completion_tokens=50,
            total_tokens=150,
            cost_usd=0.001,
        )
        assert stats.prompt_tokens == 100
        assert stats.total_tokens == 150

    def test_usage_stats_from_response(self) -> None:
        """from_response sums tokens and prices them per the model config."""
        stats = UsageStats.from_response(
            prompt_tokens=1000,
            completion_tokens=500,
            model_config=MODEL_CONFIGS["claude-opus-4"],
        )
        assert stats.prompt_tokens == 1000
        assert stats.completion_tokens == 500
        assert stats.total_tokens == 1500
        # Expected cost: 1000/1M * 15 + 500/1M * 75 = 0.015 + 0.0375 = 0.0525
        assert stats.cost_usd == pytest.approx(0.0525, rel=0.01)
class TestCompletionResponse:
    """Unit tests covering the CompletionResponse model."""

    def test_completion_response_creation(self) -> None:
        """Required fields are stored; finish_reason defaults to 'stop'."""
        resp = CompletionResponse(
            id="resp-123",
            model="claude-opus-4",
            provider="anthropic",
            content="Hello, world!",
        )
        assert resp.id == "resp-123"
        assert resp.model == "claude-opus-4"
        assert resp.provider == "anthropic"
        assert resp.content == "Hello, world!"
        # Not passed explicitly — checks the model default.
        assert resp.finish_reason == "stop"
class TestStreamChunk:
    """Unit tests covering the StreamChunk model."""

    def test_stream_chunk_creation(self) -> None:
        """A mid-stream chunk has a delta and no finish reason yet."""
        chunk = StreamChunk(id="chunk-1", delta="Hello")
        assert chunk.id == "chunk-1"
        assert chunk.delta == "Hello"
        assert chunk.finish_reason is None

    def test_stream_chunk_final(self) -> None:
        """The terminal chunk carries a finish reason and usage totals."""
        final = StreamChunk(
            id="chunk-last",
            delta="",
            finish_reason="stop",
            usage=UsageStats(prompt_tokens=10, completion_tokens=5, total_tokens=15),
        )
        assert final.finish_reason == "stop"
        assert final.usage is not None
class TestEmbeddingRequest:
    """Unit tests covering the EmbeddingRequest model."""

    def test_embedding_request_creation(self) -> None:
        """Required fields are stored; model falls back to the default."""
        request = EmbeddingRequest(
            project_id="proj-123",
            agent_id="agent-456",
            texts=["Hello", "World"],
        )
        assert request.project_id == "proj-123"
        assert len(request.texts) == 2
        # Not passed explicitly — checks the model default.
        assert request.model == "text-embedding-3-large"

    def test_embedding_request_validation(self) -> None:
        """An empty texts list is rejected (at least one item required)."""
        with pytest.raises(ValueError):
            EmbeddingRequest(
                project_id="proj-123",
                agent_id="agent-456",
                texts=[],
            )
class TestCostRecord:
    """Unit tests covering the CostRecord dataclass."""

    def test_cost_record_creation(self) -> None:
        """Fields are stored and a timestamp is filled in automatically."""
        rec = CostRecord(
            project_id="proj-123",
            agent_id="agent-456",
            model="claude-opus-4",
            prompt_tokens=100,
            completion_tokens=50,
            cost_usd=0.01,
        )
        assert rec.project_id == "proj-123"
        assert rec.cost_usd == 0.01
        # timestamp was not passed — it must be auto-populated.
        assert rec.timestamp is not None
class TestUsageReport:
    """Unit tests covering the UsageReport model."""

    def test_usage_report_creation(self) -> None:
        """Identity fields are stored; aggregate counters default to zero."""
        moment = datetime.now(UTC)
        report = UsageReport(
            entity_id="proj-123",
            entity_type="project",
            period="day",
            period_start=moment,
            period_end=moment,
        )
        assert report.entity_id == "proj-123"
        assert report.entity_type == "project"
        # Aggregates start empty for a freshly created report.
        assert report.total_requests == 0
        assert report.total_cost_usd == 0.0
class TestModelInfo:
    """Unit tests covering the ModelInfo model."""

    def test_model_info_from_config(self) -> None:
        """from_config copies fields from the registry entry."""
        info = ModelInfo.from_config(MODEL_CONFIGS["claude-opus-4"], available=True)
        assert info.name == "claude-opus-4"
        assert info.provider == "anthropic"
        assert info.available is True
        assert info.supports_vision is True
class TestModelGroupInfo:
    """Unit tests covering the ModelGroupInfo model."""

    def test_model_group_info_creation(self) -> None:
        """Constructor stores the group name, models, and description."""
        info = ModelGroupInfo(
            name="reasoning",
            description="Complex analysis",
            primary_model="claude-opus-4",
            fallback_models=["gpt-4.1"],
        )
        assert info.name == "reasoning"
        assert len(info.fallback_models) == 1