fix(llm-gateway): improve type safety and datetime consistency
- Add type annotations for mypy compliance - Use UTC-aware datetimes consistently (datetime.now(UTC)) - Add type: ignore comments for LiteLLM incomplete stubs - Fix import ordering and formatting - Update pyproject.toml mypy configuration 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -108,19 +108,19 @@ class TestCostTracker:
|
||||
)
|
||||
|
||||
# Verify by getting usage report
|
||||
report = asyncio.run(
|
||||
tracker.get_project_usage("proj-123", period="day")
|
||||
)
|
||||
report = asyncio.run(tracker.get_project_usage("proj-123", period="day"))
|
||||
|
||||
assert report.total_requests == 1
|
||||
assert report.total_cost_usd == pytest.approx(0.01, rel=0.01)
|
||||
|
||||
def test_record_usage_disabled(self, tracker_settings: Settings) -> None:
|
||||
"""Test recording is skipped when disabled."""
|
||||
settings = Settings(**{
|
||||
**tracker_settings.model_dump(),
|
||||
"cost_tracking_enabled": False,
|
||||
})
|
||||
settings = Settings(
|
||||
**{
|
||||
**tracker_settings.model_dump(),
|
||||
"cost_tracking_enabled": False,
|
||||
}
|
||||
)
|
||||
fake_redis = fakeredis.aioredis.FakeRedis(decode_responses=True)
|
||||
disabled_tracker = CostTracker(redis_client=fake_redis, settings=settings)
|
||||
|
||||
@@ -157,9 +157,7 @@ class TestCostTracker:
|
||||
)
|
||||
|
||||
# Verify session usage
|
||||
session_usage = asyncio.run(
|
||||
tracker.get_session_usage("session-789")
|
||||
)
|
||||
session_usage = asyncio.run(tracker.get_session_usage("session-789"))
|
||||
|
||||
assert session_usage["session_id"] == "session-789"
|
||||
assert session_usage["total_cost_usd"] == pytest.approx(0.01, rel=0.01)
|
||||
@@ -199,9 +197,7 @@ class TestCostTracker:
|
||||
)
|
||||
)
|
||||
|
||||
report = asyncio.run(
|
||||
tracker.get_project_usage("proj-123", period="day")
|
||||
)
|
||||
report = asyncio.run(tracker.get_project_usage("proj-123", period="day"))
|
||||
|
||||
assert report.total_requests == 2
|
||||
assert len(report.by_model) == 2
|
||||
@@ -221,9 +217,7 @@ class TestCostTracker:
|
||||
)
|
||||
)
|
||||
|
||||
report = asyncio.run(
|
||||
tracker.get_agent_usage("agent-456", period="day")
|
||||
)
|
||||
report = asyncio.run(tracker.get_agent_usage("agent-456", period="day"))
|
||||
|
||||
assert report.entity_id == "agent-456"
|
||||
assert report.entity_type == "agent"
|
||||
@@ -243,12 +237,8 @@ class TestCostTracker:
|
||||
)
|
||||
|
||||
# Check different periods
|
||||
hour_report = asyncio.run(
|
||||
tracker.get_project_usage("proj-123", period="hour")
|
||||
)
|
||||
day_report = asyncio.run(
|
||||
tracker.get_project_usage("proj-123", period="day")
|
||||
)
|
||||
hour_report = asyncio.run(tracker.get_project_usage("proj-123", period="hour"))
|
||||
day_report = asyncio.run(tracker.get_project_usage("proj-123", period="day"))
|
||||
month_report = asyncio.run(
|
||||
tracker.get_project_usage("proj-123", period="month")
|
||||
)
|
||||
@@ -306,9 +296,7 @@ class TestCostTracker:
|
||||
|
||||
def test_check_budget_default_limit(self, tracker: CostTracker) -> None:
|
||||
"""Test budget check with default limit."""
|
||||
within, current, limit = asyncio.run(
|
||||
tracker.check_budget("proj-123")
|
||||
)
|
||||
within, current, limit = asyncio.run(tracker.check_budget("proj-123"))
|
||||
|
||||
assert limit == 1000.0 # Default from settings
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
Tests for exceptions module.
|
||||
"""
|
||||
|
||||
|
||||
from exceptions import (
|
||||
AllProvidersFailedError,
|
||||
CircuitOpenError,
|
||||
|
||||
@@ -172,7 +172,10 @@ class TestChatMessage:
|
||||
role="user",
|
||||
content=[
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{"type": "image_url", "image_url": {"url": "http://example.com/img.jpg"}},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "http://example.com/img.jpg"},
|
||||
},
|
||||
],
|
||||
)
|
||||
assert isinstance(msg.content, list)
|
||||
|
||||
@@ -79,14 +79,23 @@ class TestModelRouter:
|
||||
|
||||
def test_get_preferred_group_for_agent(self, router: ModelRouter) -> None:
|
||||
"""Test getting preferred group for agent types."""
|
||||
assert router.get_preferred_group_for_agent("product_owner") == ModelGroup.REASONING
|
||||
assert router.get_preferred_group_for_agent("software_engineer") == ModelGroup.CODE
|
||||
assert router.get_preferred_group_for_agent("devops_engineer") == ModelGroup.FAST
|
||||
assert (
|
||||
router.get_preferred_group_for_agent("product_owner")
|
||||
== ModelGroup.REASONING
|
||||
)
|
||||
assert (
|
||||
router.get_preferred_group_for_agent("software_engineer") == ModelGroup.CODE
|
||||
)
|
||||
assert (
|
||||
router.get_preferred_group_for_agent("devops_engineer") == ModelGroup.FAST
|
||||
)
|
||||
|
||||
def test_get_preferred_group_unknown_agent(self, router: ModelRouter) -> None:
|
||||
"""Test getting preferred group for unknown agent."""
|
||||
# Should default to REASONING
|
||||
assert router.get_preferred_group_for_agent("unknown_type") == ModelGroup.REASONING
|
||||
assert (
|
||||
router.get_preferred_group_for_agent("unknown_type") == ModelGroup.REASONING
|
||||
)
|
||||
|
||||
def test_select_model_by_group(self, router: ModelRouter) -> None:
|
||||
"""Test selecting model by group."""
|
||||
@@ -99,9 +108,7 @@ class TestModelRouter:
|
||||
|
||||
def test_select_model_by_group_string(self, router: ModelRouter) -> None:
|
||||
"""Test selecting model by group string."""
|
||||
model_name, config = asyncio.run(
|
||||
router.select_model(model_group="code")
|
||||
)
|
||||
model_name, config = asyncio.run(router.select_model(model_group="code"))
|
||||
|
||||
assert model_name == "claude-sonnet-4"
|
||||
|
||||
@@ -174,9 +181,7 @@ class TestModelRouter:
|
||||
limited_router = ModelRouter(settings=settings, circuit_registry=registry)
|
||||
|
||||
with pytest.raises(AllProvidersFailedError) as exc_info:
|
||||
asyncio.run(
|
||||
limited_router.select_model(model_group=ModelGroup.REASONING)
|
||||
)
|
||||
asyncio.run(limited_router.select_model(model_group=ModelGroup.REASONING))
|
||||
|
||||
assert exc_info.value.model_group == "reasoning"
|
||||
assert len(exc_info.value.attempted_models) > 0
|
||||
@@ -195,18 +200,14 @@ class TestModelRouter:
|
||||
|
||||
def test_get_available_models_for_group_string(self, router: ModelRouter) -> None:
|
||||
"""Test getting available models with string group."""
|
||||
models = asyncio.run(
|
||||
router.get_available_models_for_group("code")
|
||||
)
|
||||
models = asyncio.run(router.get_available_models_for_group("code"))
|
||||
|
||||
assert len(models) > 0
|
||||
|
||||
def test_get_available_models_invalid_group(self, router: ModelRouter) -> None:
|
||||
"""Test getting models for invalid group."""
|
||||
with pytest.raises(InvalidModelGroupError):
|
||||
asyncio.run(
|
||||
router.get_available_models_for_group("invalid")
|
||||
)
|
||||
asyncio.run(router.get_available_models_for_group("invalid"))
|
||||
|
||||
def test_get_all_model_groups(self, router: ModelRouter) -> None:
|
||||
"""Test getting all model groups info."""
|
||||
|
||||
@@ -36,6 +36,7 @@ def test_client(test_settings: Settings) -> TestClient:
|
||||
mock_provider.return_value.get_available_models.return_value = {}
|
||||
|
||||
from server import app
|
||||
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@@ -243,7 +244,9 @@ class TestListModelsTool:
|
||||
"""Test listing models by group."""
|
||||
with patch("server.get_model_router") as mock_router:
|
||||
mock_router.return_value = MagicMock()
|
||||
mock_router.return_value.parse_model_group.return_value = ModelGroup.REASONING
|
||||
mock_router.return_value.parse_model_group.return_value = (
|
||||
ModelGroup.REASONING
|
||||
)
|
||||
mock_router.return_value.get_available_models_for_group = AsyncMock(
|
||||
return_value=[]
|
||||
)
|
||||
@@ -340,9 +343,7 @@ class TestChatCompletionTool:
|
||||
"arguments": {
|
||||
"project_id": "proj-123",
|
||||
"agent_id": "agent-456",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
],
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"stream": True,
|
||||
},
|
||||
},
|
||||
@@ -388,7 +389,9 @@ class TestChatCompletionTool:
|
||||
mock_circuit = MagicMock()
|
||||
mock_circuit.record_success = AsyncMock()
|
||||
mock_reg.return_value = MagicMock()
|
||||
mock_reg.return_value.get_circuit_sync.return_value = mock_circuit
|
||||
mock_reg.return_value.get_circuit_sync.return_value = (
|
||||
mock_circuit
|
||||
)
|
||||
|
||||
response = test_client.post(
|
||||
"/mcp",
|
||||
|
||||
@@ -110,6 +110,7 @@ class TestWrapLiteLLMStream:
|
||||
|
||||
async def test_wrap_stream_basic(self) -> None:
|
||||
"""Test wrapping a basic stream."""
|
||||
|
||||
# Create mock stream chunks
|
||||
async def mock_stream():
|
||||
chunk1 = MagicMock()
|
||||
@@ -146,6 +147,7 @@ class TestWrapLiteLLMStream:
|
||||
|
||||
async def test_wrap_stream_without_accumulator(self) -> None:
|
||||
"""Test wrapping stream without accumulator."""
|
||||
|
||||
async def mock_stream():
|
||||
chunk = MagicMock()
|
||||
chunk.choices = [MagicMock()]
|
||||
@@ -284,6 +286,7 @@ class TestStreamToString:
|
||||
|
||||
async def test_stream_to_string_basic(self) -> None:
|
||||
"""Test converting stream to string."""
|
||||
|
||||
async def mock_stream():
|
||||
yield StreamChunk(id="1", delta="Hello")
|
||||
yield StreamChunk(id="2", delta=" ")
|
||||
@@ -303,6 +306,7 @@ class TestStreamToString:
|
||||
|
||||
async def test_stream_to_string_no_usage(self) -> None:
|
||||
"""Test stream without usage stats."""
|
||||
|
||||
async def mock_stream():
|
||||
yield StreamChunk(id="1", delta="Test")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user