"""
Tests for server module.
"""

from collections.abc import Iterator
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi.testclient import TestClient

from config import Settings
from models import ModelGroup


def _post_rpc(
    client: TestClient,
    method: str,
    params: dict[str, Any] | None = None,
    *,
    jsonrpc: str = "2.0",
) -> Any:
    """POST a JSON-RPC request to ``/mcp`` and return the HTTP response.

    Args:
        client: Test client used to send the request.
        method: Value of the JSON-RPC ``method`` field.
        params: Value of the ``params`` field; the key is omitted when None.
        jsonrpc: Value of the ``jsonrpc`` version field.

    Returns:
        The response object returned by ``client.post``.
    """
    payload: dict[str, Any] = {"jsonrpc": jsonrpc, "method": method}
    if params is not None:
        payload["params"] = params
    payload["id"] = 1
    return client.post("/mcp", json=payload)


def _call_tool(client: TestClient, name: str, arguments: dict[str, Any]) -> Any:
    """Send a JSON-RPC ``tools/call`` request for tool ``name``.

    Args:
        client: Test client used to send the request.
        name: Tool name placed in ``params.name``.
        arguments: Tool arguments placed in ``params.arguments``.

    Returns:
        The response object returned by ``client.post``.
    """
    return _post_rpc(client, "tools/call", {"name": name, "arguments": arguments})


@pytest.fixture
def test_settings() -> Settings:
    """Test settings with mock API keys."""
    return Settings(
        anthropic_api_key="test-anthropic-key",
        openai_api_key="test-openai-key",
        google_api_key="test-google-key",
        cost_tracking_enabled=False,  # Disable for most tests
        litellm_cache_enabled=False,
    )


@pytest.fixture
def test_client(test_settings: Settings) -> Iterator[TestClient]:
    """Yield a test client whose server dependencies stay mocked during the test.

    The client is yielded (not returned) from inside the ``with`` block so that
    the patches on ``server.get_settings`` and ``server.get_provider`` remain
    active while the test body runs. Returning from inside the block would
    undo both patches before the test sends its first request.

    Args:
        test_settings: Settings object returned by the patched ``get_settings``.

    Yields:
        A ``TestClient`` bound to the server's FastAPI ``app``.
    """
    with (
        patch("server.get_settings", return_value=test_settings),
        patch("server.get_provider") as mock_provider,
    ):
        provider = MagicMock()
        provider.is_available = True
        provider.router = MagicMock()
        provider.get_available_models.return_value = {}
        mock_provider.return_value = provider

        # Imported inside the fixture rather than at module top level.
        from server import app

        yield TestClient(app)


class TestHealthEndpoint:
    """Tests for health check endpoint."""

    def test_health_check(self, test_client: TestClient) -> None:
        """Test health check returns healthy status."""
        with (
            patch("server.get_settings") as mock_settings,
            patch("server.get_provider") as mock_provider,
        ):
            mock_settings.return_value = Settings(
                anthropic_api_key="test-key",
            )
            mock_provider.return_value = MagicMock()
            mock_provider.return_value.is_available = True

            response = test_client.get("/health")

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "healthy"
        assert data["service"] == "llm-gateway"


class TestToolDiscoveryEndpoint:
    """Tests for tool discovery endpoint."""

    def test_list_tools(self, test_client: TestClient) -> None:
        """Test listing available tools."""
        response = test_client.get("/mcp/tools")

        assert response.status_code == 200
        data = response.json()
        assert "tools" in data
        assert len(data["tools"]) == 4  # 4 tools defined

        tool_names = [t["name"] for t in data["tools"]]
        assert "chat_completion" in tool_names
        assert "list_models" in tool_names
        assert "get_usage" in tool_names
        assert "count_tokens" in tool_names

    def test_tool_has_schema(self, test_client: TestClient) -> None:
        """Test that tools have input schemas."""
        response = test_client.get("/mcp/tools")
        data = response.json()

        for tool in data["tools"]:
            assert "inputSchema" in tool
            assert "type" in tool["inputSchema"]
            assert tool["inputSchema"]["type"] == "object"


class TestJSONRPCEndpoint:
    """Tests for JSON-RPC endpoint.

    The error cases below expect HTTP 200 with an ``error`` member in the
    body, so a 200 status alone does not show that a call succeeded.
    """

    def test_invalid_jsonrpc_version(self, test_client: TestClient) -> None:
        """Test invalid JSON-RPC version."""
        response = _post_rpc(test_client, "tools/list", jsonrpc="1.0")  # Invalid

        assert response.status_code == 200
        data = response.json()
        assert "error" in data
        assert data["error"]["code"] == -32600

    def test_tools_list(self, test_client: TestClient) -> None:
        """Test tools/list method."""
        response = _post_rpc(test_client, "tools/list")

        assert response.status_code == 200
        data = response.json()
        assert "result" in data
        assert "tools" in data["result"]

    def test_unknown_method(self, test_client: TestClient) -> None:
        """Test unknown method."""
        response = _post_rpc(test_client, "unknown/method")

        assert response.status_code == 200
        data = response.json()
        assert "error" in data
        assert data["error"]["code"] == -32601

    def test_unknown_tool(self, test_client: TestClient) -> None:
        """Test unknown tool."""
        response = _call_tool(test_client, "unknown_tool", {})

        assert response.status_code == 200
        data = response.json()
        assert "error" in data
        assert "Unknown tool" in data["error"]["message"]


class TestCountTokensTool:
    """Tests for count_tokens tool."""

    def test_count_tokens(self, test_client: TestClient) -> None:
        """Test counting tokens."""
        response = _call_tool(
            test_client,
            "count_tokens",
            {
                "project_id": "proj-123",
                "agent_id": "agent-456",
                "text": "Hello, world!",
            },
        )

        assert response.status_code == 200
        data = response.json()
        assert "result" in data

    def test_count_tokens_with_model(self, test_client: TestClient) -> None:
        """Test counting tokens with specific model."""
        response = _call_tool(
            test_client,
            "count_tokens",
            {
                "project_id": "proj-123",
                "agent_id": "agent-456",
                "text": "Hello, world!",
                "model": "gpt-4",
            },
        )

        # NOTE(review): only the HTTP status is checked here; the response
        # body is not inspected.
        assert response.status_code == 200


class TestListModelsTool:
    """Tests for list_models tool."""

    def test_list_all_models(self, test_client: TestClient) -> None:
        """Test listing all models."""
        with patch("server.get_model_router") as mock_router:
            router = MagicMock()
            router.get_available_models_for_group = AsyncMock(return_value=[])
            mock_router.return_value = router

            response = _call_tool(
                test_client,
                "list_models",
                {
                    "project_id": "proj-123",
                    "agent_id": "agent-456",
                },
            )

        # NOTE(review): only the HTTP status is checked here.
        assert response.status_code == 200

    def test_list_models_by_group(self, test_client: TestClient) -> None:
        """Test listing models by group."""
        with patch("server.get_model_router") as mock_router:
            router = MagicMock()
            router.parse_model_group.return_value = ModelGroup.REASONING
            router.get_available_models_for_group = AsyncMock(return_value=[])
            mock_router.return_value = router

            response = _call_tool(
                test_client,
                "list_models",
                {
                    "project_id": "proj-123",
                    "agent_id": "agent-456",
                    "model_group": "reasoning",
                },
            )

        # NOTE(review): only the HTTP status is checked here.
        assert response.status_code == 200


class TestGetUsageTool:
    """Tests for get_usage tool."""

    def test_get_usage(self, test_client: TestClient) -> None:
        """Test getting usage."""
        with patch("server.get_cost_tracker") as mock_tracker:
            mock_report = MagicMock()
            mock_report.total_requests = 10
            mock_report.total_tokens = 1000
            mock_report.total_cost_usd = 0.50
            mock_report.by_model = {}
            mock_report.period_start.isoformat.return_value = "2024-01-01T00:00:00"
            mock_report.period_end.isoformat.return_value = "2024-01-02T00:00:00"

            tracker = MagicMock()
            tracker.get_project_usage = AsyncMock(return_value=mock_report)
            tracker.get_agent_usage = AsyncMock(return_value=mock_report)
            mock_tracker.return_value = tracker

            response = _call_tool(
                test_client,
                "get_usage",
                {
                    "project_id": "proj-123",
                    "agent_id": "agent-456",
                    "period": "day",
                },
            )

        # NOTE(review): only the HTTP status is checked here; the mocked
        # report values are not compared against the response body.
        assert response.status_code == 200


class TestChatCompletionTool:
    """Tests for chat_completion tool."""

    def test_chat_completion_streaming_not_supported(
        self,
        test_client: TestClient,
    ) -> None:
        """Test that streaming returns info message."""
        with (
            patch("server.get_model_router") as mock_router,
            patch("server.get_cost_tracker") as mock_tracker,
        ):
            mock_router.return_value = MagicMock()
            mock_router.return_value.select_model = AsyncMock(
                return_value=("claude-opus-4", MagicMock())
            )

            mock_tracker.return_value = MagicMock()
            mock_tracker.return_value.check_budget = AsyncMock(
                return_value=(True, 0.0, 100.0)
            )

            response = _call_tool(
                test_client,
                "chat_completion",
                {
                    "project_id": "proj-123",
                    "agent_id": "agent-456",
                    "messages": [{"role": "user", "content": "Hello"}],
                    "stream": True,
                },
            )

        # NOTE(review): the info message itself is not asserted, only the
        # HTTP status.
        assert response.status_code == 200

    def test_chat_completion_success(self, test_client: TestClient) -> None:
        """Test successful chat completion."""
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Hello, world!"
        mock_response.choices[0].finish_reason = "stop"
        mock_response.usage = MagicMock()
        mock_response.usage.prompt_tokens = 10
        mock_response.usage.completion_tokens = 5

        with (
            patch("server.get_model_router") as mock_router,
            patch("server.get_cost_tracker") as mock_tracker,
            patch("server.get_provider") as mock_prov,
            patch("server.get_circuit_registry") as mock_reg,
        ):
            mock_model_config = MagicMock()
            mock_model_config.provider.value = "anthropic"
            mock_router.return_value = MagicMock()
            mock_router.return_value.select_model = AsyncMock(
                return_value=("claude-opus-4", mock_model_config)
            )

            mock_tracker.return_value = MagicMock()
            mock_tracker.return_value.check_budget = AsyncMock(
                return_value=(True, 0.0, 100.0)
            )
            mock_tracker.return_value.record_usage = AsyncMock()

            mock_prov.return_value = MagicMock()
            mock_prov.return_value.router = MagicMock()
            mock_prov.return_value.router.acompletion = AsyncMock(
                return_value=mock_response
            )

            mock_circuit = MagicMock()
            mock_circuit.record_success = AsyncMock()
            mock_reg.return_value = MagicMock()
            mock_reg.return_value.get_circuit_sync.return_value = mock_circuit

            response = _call_tool(
                test_client,
                "chat_completion",
                {
                    "project_id": "proj-123",
                    "agent_id": "agent-456",
                    "messages": [{"role": "user", "content": "Hello"}],
                },
            )

        # NOTE(review): only the HTTP status is checked here; the mocked
        # completion content is not compared against the response body.
        assert response.status_code == 200