fix(llm-gateway): improve type safety and datetime consistency

- Add type annotations for mypy compliance
- Use UTC-aware datetimes consistently (datetime.now(UTC))
- Add type: ignore comments for LiteLLM incomplete stubs
- Fix import ordering and formatting
- Update pyproject.toml mypy configuration

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-03 20:56:05 +01:00
parent 6e8b0b022a
commit f482559e15
15 changed files with 111 additions and 105 deletions

View File

@@ -108,19 +108,19 @@ class TestCostTracker:
)
# Verify by getting usage report
report = asyncio.run(
tracker.get_project_usage("proj-123", period="day")
)
report = asyncio.run(tracker.get_project_usage("proj-123", period="day"))
assert report.total_requests == 1
assert report.total_cost_usd == pytest.approx(0.01, rel=0.01)
def test_record_usage_disabled(self, tracker_settings: Settings) -> None:
"""Test recording is skipped when disabled."""
settings = Settings(**{
**tracker_settings.model_dump(),
"cost_tracking_enabled": False,
})
settings = Settings(
**{
**tracker_settings.model_dump(),
"cost_tracking_enabled": False,
}
)
fake_redis = fakeredis.aioredis.FakeRedis(decode_responses=True)
disabled_tracker = CostTracker(redis_client=fake_redis, settings=settings)
@@ -157,9 +157,7 @@ class TestCostTracker:
)
# Verify session usage
session_usage = asyncio.run(
tracker.get_session_usage("session-789")
)
session_usage = asyncio.run(tracker.get_session_usage("session-789"))
assert session_usage["session_id"] == "session-789"
assert session_usage["total_cost_usd"] == pytest.approx(0.01, rel=0.01)
@@ -199,9 +197,7 @@ class TestCostTracker:
)
)
report = asyncio.run(
tracker.get_project_usage("proj-123", period="day")
)
report = asyncio.run(tracker.get_project_usage("proj-123", period="day"))
assert report.total_requests == 2
assert len(report.by_model) == 2
@@ -221,9 +217,7 @@ class TestCostTracker:
)
)
report = asyncio.run(
tracker.get_agent_usage("agent-456", period="day")
)
report = asyncio.run(tracker.get_agent_usage("agent-456", period="day"))
assert report.entity_id == "agent-456"
assert report.entity_type == "agent"
@@ -243,12 +237,8 @@ class TestCostTracker:
)
# Check different periods
hour_report = asyncio.run(
tracker.get_project_usage("proj-123", period="hour")
)
day_report = asyncio.run(
tracker.get_project_usage("proj-123", period="day")
)
hour_report = asyncio.run(tracker.get_project_usage("proj-123", period="hour"))
day_report = asyncio.run(tracker.get_project_usage("proj-123", period="day"))
month_report = asyncio.run(
tracker.get_project_usage("proj-123", period="month")
)
@@ -306,9 +296,7 @@ class TestCostTracker:
def test_check_budget_default_limit(self, tracker: CostTracker) -> None:
"""Test budget check with default limit."""
within, current, limit = asyncio.run(
tracker.check_budget("proj-123")
)
within, current, limit = asyncio.run(tracker.check_budget("proj-123"))
assert limit == 1000.0 # Default from settings

View File

@@ -2,7 +2,6 @@
Tests for exceptions module.
"""
from exceptions import (
AllProvidersFailedError,
CircuitOpenError,

View File

@@ -172,7 +172,10 @@ class TestChatMessage:
role="user",
content=[
{"type": "text", "text": "What's in this image?"},
{"type": "image_url", "image_url": {"url": "http://example.com/img.jpg"}},
{
"type": "image_url",
"image_url": {"url": "http://example.com/img.jpg"},
},
],
)
assert isinstance(msg.content, list)

View File

@@ -79,14 +79,23 @@ class TestModelRouter:
def test_get_preferred_group_for_agent(self, router: ModelRouter) -> None:
"""Test getting preferred group for agent types."""
assert router.get_preferred_group_for_agent("product_owner") == ModelGroup.REASONING
assert router.get_preferred_group_for_agent("software_engineer") == ModelGroup.CODE
assert router.get_preferred_group_for_agent("devops_engineer") == ModelGroup.FAST
assert (
router.get_preferred_group_for_agent("product_owner")
== ModelGroup.REASONING
)
assert (
router.get_preferred_group_for_agent("software_engineer") == ModelGroup.CODE
)
assert (
router.get_preferred_group_for_agent("devops_engineer") == ModelGroup.FAST
)
def test_get_preferred_group_unknown_agent(self, router: ModelRouter) -> None:
"""Test getting preferred group for unknown agent."""
# Should default to REASONING
assert router.get_preferred_group_for_agent("unknown_type") == ModelGroup.REASONING
assert (
router.get_preferred_group_for_agent("unknown_type") == ModelGroup.REASONING
)
def test_select_model_by_group(self, router: ModelRouter) -> None:
"""Test selecting model by group."""
@@ -99,9 +108,7 @@ class TestModelRouter:
def test_select_model_by_group_string(self, router: ModelRouter) -> None:
"""Test selecting model by group string."""
model_name, config = asyncio.run(
router.select_model(model_group="code")
)
model_name, config = asyncio.run(router.select_model(model_group="code"))
assert model_name == "claude-sonnet-4"
@@ -174,9 +181,7 @@ class TestModelRouter:
limited_router = ModelRouter(settings=settings, circuit_registry=registry)
with pytest.raises(AllProvidersFailedError) as exc_info:
asyncio.run(
limited_router.select_model(model_group=ModelGroup.REASONING)
)
asyncio.run(limited_router.select_model(model_group=ModelGroup.REASONING))
assert exc_info.value.model_group == "reasoning"
assert len(exc_info.value.attempted_models) > 0
@@ -195,18 +200,14 @@ class TestModelRouter:
def test_get_available_models_for_group_string(self, router: ModelRouter) -> None:
"""Test getting available models with string group."""
models = asyncio.run(
router.get_available_models_for_group("code")
)
models = asyncio.run(router.get_available_models_for_group("code"))
assert len(models) > 0
def test_get_available_models_invalid_group(self, router: ModelRouter) -> None:
"""Test getting models for invalid group."""
with pytest.raises(InvalidModelGroupError):
asyncio.run(
router.get_available_models_for_group("invalid")
)
asyncio.run(router.get_available_models_for_group("invalid"))
def test_get_all_model_groups(self, router: ModelRouter) -> None:
"""Test getting all model groups info."""

View File

@@ -36,6 +36,7 @@ def test_client(test_settings: Settings) -> TestClient:
mock_provider.return_value.get_available_models.return_value = {}
from server import app
return TestClient(app)
@@ -243,7 +244,9 @@ class TestListModelsTool:
"""Test listing models by group."""
with patch("server.get_model_router") as mock_router:
mock_router.return_value = MagicMock()
mock_router.return_value.parse_model_group.return_value = ModelGroup.REASONING
mock_router.return_value.parse_model_group.return_value = (
ModelGroup.REASONING
)
mock_router.return_value.get_available_models_for_group = AsyncMock(
return_value=[]
)
@@ -340,9 +343,7 @@ class TestChatCompletionTool:
"arguments": {
"project_id": "proj-123",
"agent_id": "agent-456",
"messages": [
{"role": "user", "content": "Hello"}
],
"messages": [{"role": "user", "content": "Hello"}],
"stream": True,
},
},
@@ -388,7 +389,9 @@ class TestChatCompletionTool:
mock_circuit = MagicMock()
mock_circuit.record_success = AsyncMock()
mock_reg.return_value = MagicMock()
mock_reg.return_value.get_circuit_sync.return_value = mock_circuit
mock_reg.return_value.get_circuit_sync.return_value = (
mock_circuit
)
response = test_client.post(
"/mcp",

View File

@@ -110,6 +110,7 @@ class TestWrapLiteLLMStream:
async def test_wrap_stream_basic(self) -> None:
"""Test wrapping a basic stream."""
# Create mock stream chunks
async def mock_stream():
chunk1 = MagicMock()
@@ -146,6 +147,7 @@ class TestWrapLiteLLMStream:
async def test_wrap_stream_without_accumulator(self) -> None:
"""Test wrapping stream without accumulator."""
async def mock_stream():
chunk = MagicMock()
chunk.choices = [MagicMock()]
@@ -284,6 +286,7 @@ class TestStreamToString:
async def test_stream_to_string_basic(self) -> None:
"""Test converting stream to string."""
async def mock_stream():
yield StreamChunk(id="1", delta="Hello")
yield StreamChunk(id="2", delta=" ")
@@ -303,6 +306,7 @@ class TestStreamToString:
async def test_stream_to_string_no_usage(self) -> None:
"""Test stream without usage stats."""
async def mock_stream():
yield StreamChunk(id="1", delta="Test")