Files
syndarix/mcp-servers/knowledge-base/tests/test_server.py
Felipe Cardoso 953af52d0e fix(mcp-kb): address critical issues from deep review
- Fix SQL HAVING clause bug by using CTE approach (closes #73)
- Add /mcp JSON-RPC 2.0 endpoint for tool execution (closes #74)
- Add /mcp/tools endpoint for tool discovery (closes #75)
- Add content size limits to prevent DoS attacks (closes #78)
- Add comprehensive tests for new endpoints

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-04 01:03:58 +01:00

605 lines
18 KiB
Python

"""Tests for server module and MCP tools."""
import json
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi.testclient import TestClient
class TestHealthCheck:
    """Tests for the ``server.health_check`` coroutine."""

    @pytest.mark.asyncio
    async def test_health_check_healthy(self):
        """A working database mock yields a healthy, connected report."""
        import server

        # Connection stub whose fetchval() resolves to 1.
        mock_conn = AsyncMock()
        mock_conn.fetchval = AsyncMock(return_value=1)

        # acquire() must return an async context manager that yields the
        # connection stub on __aenter__.
        mock_cm = AsyncMock()
        mock_cm.__aenter__.return_value = mock_conn
        mock_cm.__aexit__.return_value = None

        mock_db = MagicMock()
        mock_db._pool = MagicMock()
        mock_db.acquire.return_value = mock_cm

        # patch.object restores the module global on exit, so the mock
        # cannot leak into tests that run afterwards.
        with patch.object(server, "_database", mock_db):
            result = await server.health_check()

        assert result["status"] == "healthy"
        assert result["service"] == "knowledge-base"
        assert result["database"] == "connected"

    @pytest.mark.asyncio
    async def test_health_check_no_database(self):
        """With no database object set, the report says "not initialized"."""
        import server

        with patch.object(server, "_database", None):
            result = await server.health_check()

        assert result["database"] == "not initialized"
class TestSearchKnowledgeTool:
    """Tests for search_knowledge MCP tool."""

    @pytest.mark.asyncio
    async def test_search_success(self):
        """A mocked search backend's single hit is returned as a success."""
        import server
        from models import SearchResponse, SearchResult

        mock_search = MagicMock()
        mock_search.search = AsyncMock(
            return_value=SearchResponse(
                query="test query",
                search_type="hybrid",
                results=[
                    SearchResult(
                        id="id-1",
                        content="Test content",
                        score=0.95,
                        source_path="/test/file.py",
                        chunk_type="code",
                        collection="default",
                    )
                ],
                total_results=1,
                search_time_ms=10.5,
            )
        )

        # Scope the replacement of the module global to this test only;
        # patch.object puts the previous value back on exit.
        with patch.object(server, "_search", mock_search):
            # Call the wrapped function via .fn
            result = await server.search_knowledge.fn(
                project_id="proj-123",
                agent_id="agent-456",
                query="test query",
                search_type="hybrid",
                collection=None,
                limit=10,
                threshold=0.7,
                file_types=None,
            )

        assert result["success"] is True
        assert len(result["results"]) == 1
        assert result["results"][0]["score"] == 0.95

    @pytest.mark.asyncio
    async def test_search_invalid_type(self):
        """An unknown search_type produces an "Invalid search type" error."""
        import server

        result = await server.search_knowledge.fn(
            project_id="proj-123",
            agent_id="agent-456",
            query="test",
            search_type="invalid",
        )

        assert result["success"] is False
        assert "Invalid search type" in result["error"]

    @pytest.mark.asyncio
    async def test_search_invalid_file_type(self):
        """An unknown entry in file_types produces an "Invalid file type" error."""
        import server

        result = await server.search_knowledge.fn(
            project_id="proj-123",
            agent_id="agent-456",
            query="test",
            search_type="hybrid",
            collection=None,
            limit=10,
            threshold=0.7,
            file_types=["invalid_type"],
        )

        assert result["success"] is False
        assert "Invalid file type" in result["error"]
class TestIngestContentTool:
    """Tests for ingest_content MCP tool."""

    @pytest.mark.asyncio
    async def test_ingest_success(self):
        """A mocked collection manager's IngestResult is returned as a success."""
        import server
        from models import IngestResult

        mock_collections = MagicMock()
        mock_collections.ingest = AsyncMock(
            return_value=IngestResult(
                success=True,
                chunks_created=3,
                embeddings_generated=3,
                source_path="/test/file.py",
                collection="default",
                chunk_ids=["id-1", "id-2", "id-3"],
            )
        )

        # patch.object restores server._collections afterwards so the mock
        # does not leak into other tests.
        with patch.object(server, "_collections", mock_collections):
            result = await server.ingest_content.fn(
                project_id="proj-123",
                agent_id="agent-456",
                content="def hello(): pass",
                source_path="/test/file.py",
                collection="default",
                chunk_type="text",
                file_type=None,
                metadata=None,
            )

        assert result["success"] is True
        assert result["chunks_created"] == 3
        assert len(result["chunk_ids"]) == 3

    @pytest.mark.asyncio
    async def test_ingest_invalid_chunk_type(self):
        """An unknown chunk_type produces an "Invalid chunk type" error."""
        import server

        result = await server.ingest_content.fn(
            project_id="proj-123",
            agent_id="agent-456",
            content="test content",
            chunk_type="invalid",
        )

        assert result["success"] is False
        assert "Invalid chunk type" in result["error"]

    @pytest.mark.asyncio
    async def test_ingest_invalid_file_type(self):
        """An unknown file_type produces an "Invalid file type" error."""
        import server

        result = await server.ingest_content.fn(
            project_id="proj-123",
            agent_id="agent-456",
            content="test content",
            source_path=None,
            collection="default",
            chunk_type="text",
            file_type="invalid",
            metadata=None,
        )

        assert result["success"] is False
        assert "Invalid file type" in result["error"]
class TestDeleteContentTool:
    """Tests for delete_content MCP tool."""

    @pytest.mark.asyncio
    async def test_delete_success(self):
        """A mocked DeleteResult is returned with its chunks_deleted count."""
        import server
        from models import DeleteResult

        mock_collections = MagicMock()
        mock_collections.delete = AsyncMock(
            return_value=DeleteResult(
                success=True,
                chunks_deleted=5,
            )
        )

        # Limit the replacement of the module global to this test; the
        # previous value is restored when the block exits.
        with patch.object(server, "_collections", mock_collections):
            result = await server.delete_content.fn(
                project_id="proj-123",
                agent_id="agent-456",
                source_path="/test/file.py",
                collection=None,
                chunk_ids=None,
            )

        assert result["success"] is True
        assert result["chunks_deleted"] == 5
class TestListCollectionsTool:
    """Tests for list_collections MCP tool."""

    @pytest.mark.asyncio
    async def test_list_collections_success(self):
        """A mocked response holding one collection is returned as a success."""
        import server
        from models import CollectionInfo, ListCollectionsResponse

        mock_collections = MagicMock()
        mock_collections.list_collections = AsyncMock(
            return_value=ListCollectionsResponse(
                project_id="proj-123",
                collections=[
                    CollectionInfo(
                        name="collection-1",
                        project_id="proj-123",
                        chunk_count=100,
                        total_tokens=50000,
                        file_types=["python"],
                        created_at=datetime.now(UTC),
                        updated_at=datetime.now(UTC),
                    )
                ],
                total_collections=1,
            )
        )

        # patch.object restores server._collections afterwards so the mock
        # does not leak into other tests.
        with patch.object(server, "_collections", mock_collections):
            result = await server.list_collections.fn(
                project_id="proj-123",
                agent_id="agent-456",
            )

        assert result["success"] is True
        assert result["total_collections"] == 1
        assert len(result["collections"]) == 1
class TestGetCollectionStatsTool:
    """Tests for get_collection_stats MCP tool."""

    @pytest.mark.asyncio
    async def test_get_stats_success(self):
        """Mocked CollectionStats values appear in the tool's result."""
        import server
        from models import CollectionStats

        mock_collections = MagicMock()
        mock_collections.get_collection_stats = AsyncMock(
            return_value=CollectionStats(
                collection="test-collection",
                project_id="proj-123",
                chunk_count=100,
                unique_sources=10,
                total_tokens=50000,
                avg_chunk_size=500.0,
                chunk_types={"code": 60, "text": 40},
                file_types={"python": 50, "javascript": 10},
            )
        )

        # Limit the replacement of the module global to this test; the
        # previous value is restored when the block exits.
        with patch.object(server, "_collections", mock_collections):
            result = await server.get_collection_stats.fn(
                project_id="proj-123",
                agent_id="agent-456",
                collection="test-collection",
            )

        assert result["success"] is True
        assert result["chunk_count"] == 100
        assert result["unique_sources"] == 10
class TestUpdateDocumentTool:
    """Tests for update_document MCP tool."""

    @pytest.mark.asyncio
    async def test_update_success(self):
        """A mocked IngestResult from update_document is returned as a success."""
        import server
        from models import IngestResult

        mock_collections = MagicMock()
        mock_collections.update_document = AsyncMock(
            return_value=IngestResult(
                success=True,
                chunks_created=2,
                embeddings_generated=2,
                source_path="/test/file.py",
                collection="default",
                chunk_ids=["id-1", "id-2"],
            )
        )

        # patch.object restores server._collections afterwards so the mock
        # does not leak into other tests.
        with patch.object(server, "_collections", mock_collections):
            result = await server.update_document.fn(
                project_id="proj-123",
                agent_id="agent-456",
                source_path="/test/file.py",
                content="def updated(): pass",
                collection="default",
                chunk_type="text",
                file_type=None,
                metadata=None,
            )

        assert result["success"] is True
        assert result["chunks_created"] == 2

    @pytest.mark.asyncio
    async def test_update_invalid_chunk_type(self):
        """An unknown chunk_type produces an "Invalid chunk type" error."""
        import server

        result = await server.update_document.fn(
            project_id="proj-123",
            agent_id="agent-456",
            source_path="/test/file.py",
            content="test",
            chunk_type="invalid",
        )

        assert result["success"] is False
        assert "Invalid chunk type" in result["error"]
class TestMCPToolsEndpoint:
    """Tests for the GET /mcp/tools discovery endpoint."""

    def test_list_mcp_tools(self):
        """The endpoint answers 200 and lists the six expected tool names."""
        import server

        response = TestClient(server.app).get("/mcp/tools")
        assert response.status_code == 200

        payload = response.json()
        assert "tools" in payload
        tools = payload["tools"]
        assert len(tools) == 6  # 6 tools registered

        registered = [entry["name"] for entry in tools]
        for expected in (
            "search_knowledge",
            "ingest_content",
            "delete_content",
            "list_collections",
            "get_collection_stats",
            "update_document",
        ):
            assert expected in registered

    def test_tool_has_schema(self):
        """Every listed tool carries an inputSchema whose type is "object"."""
        import server

        response = TestClient(server.app).get("/mcp/tools")

        for entry in response.json()["tools"]:
            assert "inputSchema" in entry
            schema = entry["inputSchema"]
            assert "type" in schema
            assert schema["type"] == "object"
class TestMCPRPCEndpoint:
    """Tests for /mcp JSON-RPC endpoint."""

    @staticmethod
    def _post_rpc(payload):
        """POST ``payload`` as JSON to /mcp on a fresh TestClient and return the response."""
        import server

        return TestClient(server.app).post("/mcp", json=payload)

    def test_valid_jsonrpc_request(self):
        """A well-formed search_knowledge call returns a JSON-RPC result."""
        import server
        from models import SearchResponse, SearchResult

        mock_search = MagicMock()
        mock_search.search = AsyncMock(
            return_value=SearchResponse(
                query="test",
                search_type="hybrid",
                results=[
                    SearchResult(
                        id="id-1",
                        content="Test",
                        score=0.9,
                        source_path="/test.py",
                        chunk_type="code",
                        collection="default",
                    )
                ],
                total_results=1,
                search_time_ms=5.0,
            )
        )

        # The request must run while the mock is installed; patch.object
        # restores server._search once the block exits.
        with patch.object(server, "_search", mock_search):
            response = self._post_rpc(
                {
                    "jsonrpc": "2.0",
                    "method": "search_knowledge",
                    "params": {
                        "project_id": "proj-123",
                        "agent_id": "agent-456",
                        "query": "test",
                    },
                    "id": 1,
                }
            )

        assert response.status_code == 200
        data = response.json()
        assert data["jsonrpc"] == "2.0"
        assert data["id"] == 1
        assert "result" in data
        assert data["result"]["success"] is True

    def test_invalid_jsonrpc_version(self):
        """A jsonrpc value other than "2.0" is rejected with 400 / -32600."""
        response = self._post_rpc(
            {
                "jsonrpc": "1.0",
                "method": "search_knowledge",
                "params": {},
                "id": 1,
            }
        )

        assert response.status_code == 400
        data = response.json()
        assert data["error"]["code"] == -32600
        assert "jsonrpc must be '2.0'" in data["error"]["message"]

    def test_missing_method(self):
        """A request without a method field is rejected with 400 / -32600."""
        response = self._post_rpc(
            {
                "jsonrpc": "2.0",
                "params": {},
                "id": 1,
            }
        )

        assert response.status_code == 400
        data = response.json()
        assert data["error"]["code"] == -32600
        assert "method is required" in data["error"]["message"]

    def test_unknown_method(self):
        """An unregistered method name is rejected with 404 / -32601."""
        response = self._post_rpc(
            {
                "jsonrpc": "2.0",
                "method": "unknown_method",
                "params": {},
                "id": 1,
            }
        )

        assert response.status_code == 404
        data = response.json()
        assert data["error"]["code"] == -32601
        assert "Method not found" in data["error"]["message"]

    def test_invalid_params(self):
        """Params that do not match the tool are rejected with 400 / -32602."""
        response = self._post_rpc(
            {
                "jsonrpc": "2.0",
                "method": "search_knowledge",
                "params": {"invalid_param": "value"},  # Missing required params
                "id": 1,
            }
        )

        assert response.status_code == 400
        data = response.json()
        assert data["error"]["code"] == -32602
class TestContentSizeLimits:
    """Tests for content size validation."""

    @staticmethod
    def _oversized_content():
        """Return a string one character longer than settings.max_document_size."""
        from config import get_settings

        return "x" * (get_settings().max_document_size + 1)

    @pytest.mark.asyncio
    async def test_ingest_rejects_oversized_content(self):
        """ingest_content fails with "exceeds maximum" for oversized content."""
        import server

        result = await server.ingest_content.fn(
            project_id="proj-123",
            agent_id="agent-456",
            content=self._oversized_content(),
            chunk_type="text",
        )

        assert result["success"] is False
        assert "exceeds maximum" in result["error"]

    @pytest.mark.asyncio
    async def test_update_rejects_oversized_content(self):
        """update_document fails with "exceeds maximum" for oversized content."""
        import server

        result = await server.update_document.fn(
            project_id="proj-123",
            agent_id="agent-456",
            source_path="/test.py",
            content=self._oversized_content(),
            chunk_type="text",
        )

        assert result["success"] is False
        assert "exceeds maximum" in result["error"]

    @pytest.mark.asyncio
    async def test_ingest_accepts_valid_size_content(self):
        """ingest_content succeeds for small content with a mocked backend."""
        import server
        from models import IngestResult

        mock_collections = MagicMock()
        mock_collections.ingest = AsyncMock(
            return_value=IngestResult(
                success=True,
                chunks_created=1,
                embeddings_generated=1,
                source_path="/test.py",
                collection="default",
                chunk_ids=["id-1"],
            )
        )

        # patch.object restores server._collections afterwards so the mock
        # does not leak into other tests.
        with patch.object(server, "_collections", mock_collections):
            # Small content that's within limits
            # Pass all parameters to avoid Field default resolution issues
            result = await server.ingest_content.fn(
                project_id="proj-123",
                agent_id="agent-456",
                content="def hello(): pass",
                source_path="/test.py",
                collection="default",
                chunk_type="text",
                file_type=None,
                metadata=None,
            )

        assert result["success"] is True