"""Tests for server module and MCP tools.""" from datetime import UTC, datetime from unittest.mock import AsyncMock, MagicMock import pytest from fastapi.testclient import TestClient class TestHealthCheck: """Tests for health check endpoint.""" @pytest.mark.asyncio async def test_health_check_healthy(self): """Test health check when all dependencies are connected.""" import server # Create a proper async context manager mock for database mock_conn = AsyncMock() mock_conn.fetchval = AsyncMock(return_value=1) mock_db = MagicMock() mock_db._pool = MagicMock() # Make acquire an async context manager mock_cm = AsyncMock() mock_cm.__aenter__.return_value = mock_conn mock_cm.__aexit__.return_value = None mock_db.acquire.return_value = mock_cm # Mock Redis mock_redis = AsyncMock() mock_redis.ping = AsyncMock(return_value=True) # Mock HTTP client for LLM Gateway mock_http_response = AsyncMock() mock_http_response.status_code = 200 mock_http_client = AsyncMock() mock_http_client.get = AsyncMock(return_value=mock_http_response) # Mock embeddings with Redis and HTTP client mock_embeddings = MagicMock() mock_embeddings._redis = mock_redis mock_embeddings._http_client = mock_http_client server._database = mock_db server._embeddings = mock_embeddings result = await server.health_check() assert result["status"] == "healthy" assert result["service"] == "knowledge-base" assert result["dependencies"]["database"] == "connected" assert result["dependencies"]["redis"] == "connected" assert result["dependencies"]["llm_gateway"] == "connected" @pytest.mark.asyncio async def test_health_check_no_database(self): """Test health check without database - should be unhealthy.""" import server server._database = None server._embeddings = None result = await server.health_check() assert result["status"] == "unhealthy" assert result["dependencies"]["database"] == "not initialized" @pytest.mark.asyncio async def test_health_check_degraded(self): """Test health check with database but no Redis - should be degraded.""" import server # Create a proper async context manager mock for database mock_conn = AsyncMock() mock_conn.fetchval = AsyncMock(return_value=1) mock_db = MagicMock() mock_db._pool = MagicMock() mock_cm = AsyncMock() mock_cm.__aenter__.return_value = mock_conn mock_cm.__aexit__.return_value = None mock_db.acquire.return_value = mock_cm # Mock embeddings without Redis mock_embeddings = MagicMock() mock_embeddings._redis = None mock_embeddings._http_client = None server._database = mock_db server._embeddings = mock_embeddings result = await server.health_check() assert result["status"] == "degraded" assert result["dependencies"]["database"] == "connected" assert result["dependencies"]["redis"] == "not initialized" class TestSearchKnowledgeTool: """Tests for search_knowledge MCP tool.""" @pytest.mark.asyncio async def test_search_success(self): """Test successful search.""" import server from models import SearchResponse, SearchResult mock_search = MagicMock() mock_search.search = AsyncMock( return_value=SearchResponse( query="test query", search_type="hybrid", results=[ SearchResult( id="id-1", content="Test content", score=0.95, source_path="/test/file.py", chunk_type="code", collection="default", ) ], total_results=1, search_time_ms=10.5, ) ) server._search = mock_search # Call the wrapped function via .fn result = await server.search_knowledge.fn( project_id="proj-123", agent_id="agent-456", query="test query", search_type="hybrid", collection=None, limit=10, threshold=0.7, file_types=None, ) assert result["success"] is True assert len(result["results"]) == 1 assert result["results"][0]["score"] == 0.95 @pytest.mark.asyncio async def test_search_invalid_type(self): """Test search with invalid search type.""" import server result = await server.search_knowledge.fn( project_id="proj-123", agent_id="agent-456", query="test", search_type="invalid", ) assert result["success"] is False assert "Invalid search type" in result["error"] @pytest.mark.asyncio async def test_search_invalid_file_type(self): """Test search with invalid file type.""" import server result = await server.search_knowledge.fn( project_id="proj-123", agent_id="agent-456", query="test", search_type="hybrid", collection=None, limit=10, threshold=0.7, file_types=["invalid_type"], ) assert result["success"] is False assert "Invalid file type" in result["error"] class TestIngestContentTool: """Tests for ingest_content MCP tool.""" @pytest.mark.asyncio async def test_ingest_success(self): """Test successful ingestion.""" import server from models import IngestResult mock_collections = MagicMock() mock_collections.ingest = AsyncMock( return_value=IngestResult( success=True, chunks_created=3, embeddings_generated=3, source_path="/test/file.py", collection="default", chunk_ids=["id-1", "id-2", "id-3"], ) ) server._collections = mock_collections result = await server.ingest_content.fn( project_id="proj-123", agent_id="agent-456", content="def hello(): pass", source_path="/test/file.py", collection="default", chunk_type="text", file_type=None, metadata=None, ) assert result["success"] is True assert result["chunks_created"] == 3 assert len(result["chunk_ids"]) == 3 @pytest.mark.asyncio async def test_ingest_invalid_chunk_type(self): """Test ingest with invalid chunk type.""" import server result = await server.ingest_content.fn( project_id="proj-123", agent_id="agent-456", content="test content", chunk_type="invalid", ) assert result["success"] is False assert "Invalid chunk type" in result["error"] @pytest.mark.asyncio async def test_ingest_invalid_file_type(self): """Test ingest with invalid file type.""" import server result = await server.ingest_content.fn( project_id="proj-123", agent_id="agent-456", content="test content", source_path=None, collection="default", chunk_type="text", file_type="invalid", metadata=None, ) assert result["success"] is False assert "Invalid file type" in result["error"] class TestDeleteContentTool: """Tests for delete_content MCP tool.""" @pytest.mark.asyncio async def test_delete_success(self): """Test successful deletion.""" import server from models import DeleteResult mock_collections = MagicMock() mock_collections.delete = AsyncMock( return_value=DeleteResult( success=True, chunks_deleted=5, ) ) server._collections = mock_collections result = await server.delete_content.fn( project_id="proj-123", agent_id="agent-456", source_path="/test/file.py", collection=None, chunk_ids=None, ) assert result["success"] is True assert result["chunks_deleted"] == 5 class TestListCollectionsTool: """Tests for list_collections MCP tool.""" @pytest.mark.asyncio async def test_list_collections_success(self): """Test listing collections.""" import server from models import CollectionInfo, ListCollectionsResponse mock_collections = MagicMock() mock_collections.list_collections = AsyncMock( return_value=ListCollectionsResponse( project_id="proj-123", collections=[ CollectionInfo( name="collection-1", project_id="proj-123", chunk_count=100, total_tokens=50000, file_types=["python"], created_at=datetime.now(UTC), updated_at=datetime.now(UTC), ) ], total_collections=1, ) ) server._collections = mock_collections result = await server.list_collections.fn( project_id="proj-123", agent_id="agent-456", ) assert result["success"] is True assert result["total_collections"] == 1 assert len(result["collections"]) == 1 class TestGetCollectionStatsTool: """Tests for get_collection_stats MCP tool.""" @pytest.mark.asyncio async def test_get_stats_success(self): """Test getting collection stats.""" import server from models import CollectionStats mock_collections = MagicMock() mock_collections.get_collection_stats = AsyncMock( return_value=CollectionStats( collection="test-collection", project_id="proj-123", chunk_count=100, unique_sources=10, total_tokens=50000, avg_chunk_size=500.0, chunk_types={"code": 60, "text": 40}, file_types={"python": 50, "javascript": 10}, ) ) server._collections = mock_collections result = await server.get_collection_stats.fn( project_id="proj-123", agent_id="agent-456", collection="test-collection", ) assert result["success"] is True assert result["chunk_count"] == 100 assert result["unique_sources"] == 10 class TestUpdateDocumentTool: """Tests for update_document MCP tool.""" @pytest.mark.asyncio async def test_update_success(self): """Test updating a document.""" import server from models import IngestResult mock_collections = MagicMock() mock_collections.update_document = AsyncMock( return_value=IngestResult( success=True, chunks_created=2, embeddings_generated=2, source_path="/test/file.py", collection="default", chunk_ids=["id-1", "id-2"], ) ) server._collections = mock_collections result = await server.update_document.fn( project_id="proj-123", agent_id="agent-456", source_path="/test/file.py", content="def updated(): pass", collection="default", chunk_type="text", file_type=None, metadata=None, ) assert result["success"] is True assert result["chunks_created"] == 2 @pytest.mark.asyncio async def test_update_invalid_chunk_type(self): """Test update with invalid chunk type.""" import server result = await server.update_document.fn( project_id="proj-123", agent_id="agent-456", source_path="/test/file.py", content="test", chunk_type="invalid", ) assert result["success"] is False assert "Invalid chunk type" in result["error"] class TestMCPToolsEndpoint: """Tests for /mcp/tools endpoint.""" def test_list_mcp_tools(self): """Test listing available MCP tools.""" import server client = TestClient(server.app) response = client.get("/mcp/tools") assert response.status_code == 200 data = response.json() assert "tools" in data assert len(data["tools"]) == 6 # 6 tools registered tool_names = [t["name"] for t in data["tools"]] assert "search_knowledge" in tool_names assert "ingest_content" in tool_names assert "delete_content" in tool_names assert "list_collections" in tool_names assert "get_collection_stats" in tool_names assert "update_document" in tool_names def test_tool_has_schema(self): """Test that each tool has input schema.""" import server client = TestClient(server.app) response = client.get("/mcp/tools") data = response.json() for tool in data["tools"]: assert "inputSchema" in tool assert "type" in tool["inputSchema"] assert tool["inputSchema"]["type"] == "object" class TestMCPRPCEndpoint: """Tests for /mcp JSON-RPC endpoint.""" def test_valid_jsonrpc_request(self): """Test valid JSON-RPC request.""" import server from models import SearchResponse, SearchResult mock_search = MagicMock() mock_search.search = AsyncMock( return_value=SearchResponse( query="test", search_type="hybrid", results=[ SearchResult( id="id-1", content="Test", score=0.9, source_path="/test.py", chunk_type="code", collection="default", ) ], total_results=1, search_time_ms=5.0, ) ) server._search = mock_search client = TestClient(server.app) response = client.post( "/mcp", json={ "jsonrpc": "2.0", "method": "search_knowledge", "params": { "project_id": "proj-123", "agent_id": "agent-456", "query": "test", }, "id": 1, }, ) assert response.status_code == 200 data = response.json() assert data["jsonrpc"] == "2.0" assert data["id"] == 1 assert "result" in data assert data["result"]["success"] is True def test_invalid_jsonrpc_version(self): """Test request with invalid JSON-RPC version.""" import server client = TestClient(server.app) response = client.post( "/mcp", json={ "jsonrpc": "1.0", "method": "search_knowledge", "params": {}, "id": 1, }, ) assert response.status_code == 400 data = response.json() assert data["error"]["code"] == -32600 assert "jsonrpc must be '2.0'" in data["error"]["message"] def test_missing_method(self): """Test request without method.""" import server client = TestClient(server.app) response = client.post( "/mcp", json={ "jsonrpc": "2.0", "params": {}, "id": 1, }, ) assert response.status_code == 400 data = response.json() assert data["error"]["code"] == -32600 assert "method is required" in data["error"]["message"] def test_unknown_method(self): """Test request with unknown method.""" import server client = TestClient(server.app) response = client.post( "/mcp", json={ "jsonrpc": "2.0", "method": "unknown_method", "params": {}, "id": 1, }, ) assert response.status_code == 404 data = response.json() assert data["error"]["code"] == -32601 assert "Method not found" in data["error"]["message"] def test_invalid_params(self): """Test request with invalid params.""" import server client = TestClient(server.app) response = client.post( "/mcp", json={ "jsonrpc": "2.0", "method": "search_knowledge", "params": {"invalid_param": "value"}, # Missing required params "id": 1, }, ) assert response.status_code == 400 data = response.json() assert data["error"]["code"] == -32602 class TestContentSizeLimits: """Tests for content size validation.""" @pytest.mark.asyncio async def test_ingest_rejects_oversized_content(self): """Test that ingest rejects content exceeding size limit.""" import server from config import get_settings settings = get_settings() # Create content larger than max size oversized_content = "x" * (settings.max_document_size + 1) result = await server.ingest_content.fn( project_id="proj-123", agent_id="agent-456", content=oversized_content, chunk_type="text", ) assert result["success"] is False assert "exceeds maximum" in result["error"] @pytest.mark.asyncio async def test_update_rejects_oversized_content(self): """Test that update rejects content exceeding size limit.""" import server from config import get_settings settings = get_settings() oversized_content = "x" * (settings.max_document_size + 1) result = await server.update_document.fn( project_id="proj-123", agent_id="agent-456", source_path="/test.py", content=oversized_content, chunk_type="text", ) assert result["success"] is False assert "exceeds maximum" in result["error"] @pytest.mark.asyncio async def test_ingest_accepts_valid_size_content(self): """Test that ingest accepts content within size limit.""" import server from models import IngestResult mock_collections = MagicMock() mock_collections.ingest = AsyncMock( return_value=IngestResult( success=True, chunks_created=1, embeddings_generated=1, source_path="/test.py", collection="default", chunk_ids=["id-1"], ) ) server._collections = mock_collections # Small content that's within limits # Pass all parameters to avoid Field default resolution issues result = await server.ingest_content.fn( project_id="proj-123", agent_id="agent-456", content="def hello(): pass", source_path="/test.py", collection="default", chunk_type="text", file_type=None, metadata=None, ) assert result["success"] is True