"""Unit tests exercising the names imported from the ``models`` module."""

from datetime import UTC, datetime

import pytest

from models import (
    AGENT_TYPE_MODEL_PREFERENCES,
    MODEL_CONFIGS,
    MODEL_GROUPS,
    ChatMessage,
    CompletionRequest,
    CompletionResponse,
    CostRecord,
    EmbeddingRequest,
    ModelConfig,
    ModelGroup,
    ModelGroupConfig,
    ModelGroupInfo,
    ModelInfo,
    Provider,
    StreamChunk,
    UsageReport,
    UsageStats,
)


class TestModelGroup:
    """Checks on the ModelGroup enum."""

    def test_model_group_values(self) -> None:
        """Every listed member carries the expected string value."""
        expected_values = (
            (ModelGroup.REASONING, "reasoning"),
            (ModelGroup.CODE, "code"),
            (ModelGroup.FAST, "fast"),
            (ModelGroup.VISION, "vision"),
            (ModelGroup.EMBEDDING, "embedding"),
            (ModelGroup.COST_OPTIMIZED, "cost_optimized"),
            (ModelGroup.SELF_HOSTED, "self_hosted"),
        )
        for member, raw_value in expected_values:
            assert member.value == raw_value

    def test_model_group_from_string(self) -> None:
        """Looking a member up by its string value yields that member."""
        lookups = (
            ("reasoning", ModelGroup.REASONING),
            ("code", ModelGroup.CODE),
            ("fast", ModelGroup.FAST),
        )
        for raw_value, member in lookups:
            assert ModelGroup(raw_value) == member

    def test_model_group_invalid(self) -> None:
        """An unknown string is rejected with ValueError."""
        with pytest.raises(ValueError):
            ModelGroup("invalid_group")


class TestProvider:
    """Checks on the Provider enum."""

    def test_provider_values(self) -> None:
        """Every listed provider carries the expected string value."""
        expected_values = (
            (Provider.ANTHROPIC, "anthropic"),
            (Provider.OPENAI, "openai"),
            (Provider.GOOGLE, "google"),
            (Provider.ALIBABA, "alibaba"),
            (Provider.DEEPSEEK, "deepseek"),
        )
        for member, raw_value in expected_values:
            assert member.value == raw_value


class TestModelConfig:
    """Checks on ModelConfig and the MODEL_CONFIGS registry."""

    def test_model_config_creation(self) -> None:
        """Constructor arguments are readable back as attributes."""
        model_config = ModelConfig(
            name="test-model",
            litellm_name="provider/test-model",
            provider=Provider.ANTHROPIC,
            cost_per_1m_input=10.0,
            cost_per_1m_output=30.0,
            context_window=100000,
            max_output_tokens=4096,
            supports_vision=True,
        )

        assert model_config.name == "test-model"
        assert model_config.provider == Provider.ANTHROPIC
        assert model_config.cost_per_1m_input == 10.0
        assert model_config.supports_vision is True
        # supports_streaming was not passed above, so this is its default.
        assert model_config.supports_streaming is True

    def test_model_configs_exist(self) -> None:
        """The registry is non-empty and contains the named models."""
        assert len(MODEL_CONFIGS) > 0
        for model_name in ("claude-opus-4", "gpt-4.1", "gemini-2.5-pro"):
            assert model_name in MODEL_CONFIGS


class TestModelGroupConfig:
    """Checks on ModelGroupConfig and the MODEL_GROUPS registry."""

    def test_model_group_config_creation(self) -> None:
        """Constructor arguments are readable back as attributes."""
        group_config = ModelGroupConfig(
            primary="model-a",
            fallbacks=["model-b", "model-c"],
            description="Test group",
        )

        assert group_config.primary == "model-a"
        assert group_config.fallbacks == ["model-b", "model-c"]
        assert group_config.description == "Test group"

    def test_get_all_models(self) -> None:
        """get_all_models lists the primary first, then the fallbacks."""
        group_config = ModelGroupConfig(
            primary="model-a",
            fallbacks=["model-b", "model-c"],
            description="Test group",
        )

        ordered_models = group_config.get_all_models()

        assert ordered_models == ["model-a", "model-b", "model-c"]

    def test_model_groups_exist(self) -> None:
        """The registry is non-empty and contains the named groups."""
        assert len(MODEL_GROUPS) > 0
        for group in (ModelGroup.REASONING, ModelGroup.CODE, ModelGroup.FAST):
            assert group in MODEL_GROUPS


class TestAgentTypePreferences:
    """Checks on the AGENT_TYPE_MODEL_PREFERENCES mapping."""

    def test_agent_preferences_exist(self) -> None:
        """The mapping is non-empty and contains the named agent types."""
        assert len(AGENT_TYPE_MODEL_PREFERENCES) > 0
        for agent_type in ("product_owner", "software_engineer"):
            assert agent_type in AGENT_TYPE_MODEL_PREFERENCES

    def test_agent_preference_values(self) -> None:
        """Selected agent types map to the expected model groups."""
        expected_groups = (
            ("product_owner", ModelGroup.REASONING),
            ("software_engineer", ModelGroup.CODE),
            ("devops_engineer", ModelGroup.FAST),
        )
        for agent_type, group in expected_groups:
            assert AGENT_TYPE_MODEL_PREFERENCES[agent_type] == group


class TestChatMessage:
    """Checks on ChatMessage."""

    def test_chat_message_creation(self) -> None:
        """A message built from role and content leaves the extras as None."""
        message = ChatMessage(role="user", content="Hello")

        assert message.role == "user"
        assert message.content == "Hello"
        assert message.name is None
        assert message.tool_calls is None

    def test_chat_message_with_optional(self) -> None:
        """The optional name and tool_calls fields can be supplied."""
        calls = [{"id": "call_1", "function": {"name": "test"}}]
        message = ChatMessage(
            role="assistant",
            content="Response",
            name="assistant_1",
            tool_calls=calls,
        )

        assert message.name == "assistant_1"
        assert message.tool_calls is not None

    def test_chat_message_list_content(self) -> None:
        """Content may be a list of parts (text plus an image URL here)."""
        text_part = {"type": "text", "text": "What's in this image?"}
        image_part = {
            "type": "image_url",
            "image_url": {"url": "http://example.com/img.jpg"},
        }

        message = ChatMessage(role="user", content=[text_part, image_part])

        assert isinstance(message.content, list)


class TestCompletionRequest:
    """Checks on CompletionRequest."""

    def test_completion_request_minimal(self) -> None:
        """Only the required fields are passed; the rest take defaults."""
        request = CompletionRequest(
            project_id="proj-123",
            agent_id="agent-456",
            messages=[ChatMessage(role="user", content="Hi")],
        )

        assert request.project_id == "proj-123"
        assert request.agent_id == "agent-456"
        assert len(request.messages) == 1
        # None of the following were passed above, so these are defaults.
        assert request.model_group == ModelGroup.REASONING
        assert request.max_tokens == 4096
        assert request.temperature == 0.7

    def test_completion_request_full(self) -> None:
        """Explicitly supplied optional fields are kept on the request."""
        request = CompletionRequest(
            project_id="proj-123",
            agent_id="agent-456",
            messages=[ChatMessage(role="user", content="Hi")],
            model_group=ModelGroup.CODE,
            model_override="claude-sonnet-4",
            max_tokens=8192,
            temperature=0.5,
            stream=True,
            session_id="session-789",
            metadata={"key": "value"},
        )

        assert request.model_group == ModelGroup.CODE
        assert request.model_override == "claude-sonnet-4"
        assert request.max_tokens == 8192
        assert request.stream is True

    def test_completion_request_validation(self) -> None:
        """Zero max_tokens and a negative temperature each raise ValueError."""
        invalid_overrides = (
            {"max_tokens": 0},
            {"temperature": -0.1},
        )
        for override in invalid_overrides:
            with pytest.raises(ValueError):
                CompletionRequest(
                    project_id="proj-123",
                    agent_id="agent-456",
                    messages=[ChatMessage(role="user", content="Hi")],
                    **override,
                )


class TestUsageStats:
    """Checks on UsageStats."""

    def test_usage_stats_default(self) -> None:
        """With no arguments every counter and the cost are zero."""
        usage = UsageStats()

        assert usage.prompt_tokens == 0
        assert usage.completion_tokens == 0
        assert usage.total_tokens == 0
        assert usage.cost_usd == 0.0

    def test_usage_stats_custom(self) -> None:
        """Explicitly supplied counters are kept."""
        usage = UsageStats(
            prompt_tokens=100,
            completion_tokens=50,
            total_tokens=150,
            cost_usd=0.001,
        )

        assert usage.prompt_tokens == 100
        assert usage.total_tokens == 150

    def test_usage_stats_from_response(self) -> None:
        """from_response fills in token counts, their total and the cost."""
        opus_config = MODEL_CONFIGS["claude-opus-4"]

        usage = UsageStats.from_response(
            prompt_tokens=1000,
            completion_tokens=500,
            model_config=opus_config,
        )

        assert usage.prompt_tokens == 1000
        assert usage.completion_tokens == 500
        assert usage.total_tokens == 1500
        # Expected cost assumes claude-opus-4 is priced at 15 / 75 USD per 1M
        # input / output tokens (pricing lives in models, not shown here):
        # 1000/1M * 15 + 500/1M * 75 = 0.015 + 0.0375 = 0.0525
        assert usage.cost_usd == pytest.approx(0.0525, rel=0.01)


class TestCompletionResponse:
    """Checks on CompletionResponse."""

    def test_completion_response_creation(self) -> None:
        """Supplied fields are kept; finish_reason falls back to "stop"."""
        completion = CompletionResponse(
            id="resp-123",
            model="claude-opus-4",
            provider="anthropic",
            content="Hello, world!",
        )

        assert completion.id == "resp-123"
        assert completion.model == "claude-opus-4"
        assert completion.provider == "anthropic"
        assert completion.content == "Hello, world!"
        assert completion.finish_reason == "stop"


class TestStreamChunk:
    """Checks on StreamChunk."""

    def test_stream_chunk_creation(self) -> None:
        """A chunk built from id and delta has no finish_reason."""
        piece = StreamChunk(id="chunk-1", delta="Hello")

        assert piece.id == "chunk-1"
        assert piece.delta == "Hello"
        assert piece.finish_reason is None

    def test_stream_chunk_final(self) -> None:
        """A closing chunk can carry a finish_reason and usage stats."""
        final_usage = UsageStats(
            prompt_tokens=10,
            completion_tokens=5,
            total_tokens=15,
        )
        piece = StreamChunk(
            id="chunk-last",
            delta="",
            finish_reason="stop",
            usage=final_usage,
        )

        assert piece.finish_reason == "stop"
        assert piece.usage is not None


class TestEmbeddingRequest:
    """Checks on EmbeddingRequest."""

    def test_embedding_request_creation(self) -> None:
        """Supplied fields are kept and the model name takes its default."""
        request = EmbeddingRequest(
            project_id="proj-123",
            agent_id="agent-456",
            texts=["Hello", "World"],
        )

        assert request.project_id == "proj-123"
        assert len(request.texts) == 2
        # model was not passed above, so this is its default.
        assert request.model == "text-embedding-3-large"

    def test_embedding_request_validation(self) -> None:
        """An empty texts list raises ValueError."""
        with pytest.raises(ValueError):
            EmbeddingRequest(
                project_id="proj-123",
                agent_id="agent-456",
                texts=[],  # Invalid - must have at least 1
            )


class TestCostRecord:
    """Checks on CostRecord."""

    def test_cost_record_creation(self) -> None:
        """Supplied fields are kept and a timestamp is populated."""
        cost_record = CostRecord(
            project_id="proj-123",
            agent_id="agent-456",
            model="claude-opus-4",
            prompt_tokens=100,
            completion_tokens=50,
            cost_usd=0.01,
        )

        assert cost_record.project_id == "proj-123"
        assert cost_record.cost_usd == 0.01
        assert cost_record.timestamp is not None


class TestUsageReport:
    """Checks on UsageReport."""

    def test_usage_report_creation(self) -> None:
        """Supplied fields are kept; the totals start at zero."""
        moment = datetime.now(UTC)

        usage_report = UsageReport(
            entity_id="proj-123",
            entity_type="project",
            period="day",
            period_start=moment,
            period_end=moment,
        )

        assert usage_report.entity_id == "proj-123"
        assert usage_report.entity_type == "project"
        assert usage_report.total_requests == 0
        assert usage_report.total_cost_usd == 0.0


class TestModelInfo:
    """Checks on ModelInfo."""

    def test_model_info_from_config(self) -> None:
        """from_config carries over details of the claude-opus-4 config."""
        opus_config = MODEL_CONFIGS["claude-opus-4"]

        model_info = ModelInfo.from_config(opus_config, available=True)

        assert model_info.name == "claude-opus-4"
        assert model_info.provider == "anthropic"
        assert model_info.available is True
        assert model_info.supports_vision is True


class TestModelGroupInfo:
    """Checks on ModelGroupInfo."""

    def test_model_group_info_creation(self) -> None:
        """Supplied fields are kept, including the fallback list."""
        group_info = ModelGroupInfo(
            name="reasoning",
            description="Complex analysis",
            primary_model="claude-opus-4",
            fallback_models=["gpt-4.1"],
        )

        assert group_info.name == "reasoning"
        assert len(group_info.fallback_models) == 1