# Files
# syndarix/mcp-servers/knowledge-base/tests/test_server.py
# Felipe Cardoso 4154dd5268 feat: enhance database transactions, add Makefiles, and improve Docker setup
# - Refactored database batch operations to ensure transaction atomicity and simplify nested structure.
# - Added `Makefile` for `knowledge-base` and `llm-gateway` modules to streamline development workflows.
# - Simplified `Dockerfile` for `llm-gateway` by removing multi-stage builds and optimizing dependencies.
# - Improved code readability in `collection_manager` and `failover` modules with refined logic.
# - Minor fixes in `test_server` and Redis health check handling for better diagnostics.
# 2026-01-05 00:49:19 +01:00
#
# 655 lines
# 20 KiB
# Python

"""Tests for server module and MCP tools."""
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock
import pytest
from fastapi.testclient import TestClient
class TestHealthCheck:
    """Tests for the ``health_check`` coroutine of the server module."""

    @staticmethod
    def _mock_database():
        """Return a database mock whose ``acquire()`` yields a connection.

        ``acquire()`` returns an async context manager. Entering it produces a
        connection mock whose ``fetchval`` coroutine resolves to ``1``.
        """
        connection = AsyncMock()
        connection.fetchval = AsyncMock(return_value=1)

        # acquire() must be usable with ``async with``.
        acquire_cm = AsyncMock()
        acquire_cm.__aenter__.return_value = connection
        acquire_cm.__aexit__.return_value = None

        database = MagicMock()
        database._pool = MagicMock()
        database.acquire.return_value = acquire_cm
        return database

    @pytest.mark.asyncio
    async def test_health_check_healthy(self):
        """All dependencies mocked as reachable -> "healthy", all "connected"."""
        import server

        # Redis client whose ping() resolves to True.
        redis_client = AsyncMock()
        redis_client.ping = AsyncMock(return_value=True)

        # HTTP client for the LLM Gateway; get() resolves to a 200 response.
        gateway_response = AsyncMock()
        gateway_response.status_code = 200
        http_client = AsyncMock()
        http_client.get = AsyncMock(return_value=gateway_response)

        # The embeddings object carries both the Redis and HTTP clients.
        embeddings = MagicMock()
        embeddings._redis = redis_client
        embeddings._http_client = http_client

        server._database = self._mock_database()
        server._embeddings = embeddings

        outcome = await server.health_check()

        dependencies = outcome["dependencies"]
        assert outcome["status"] == "healthy"
        assert outcome["service"] == "knowledge-base"
        assert dependencies["database"] == "connected"
        assert dependencies["redis"] == "connected"
        assert dependencies["llm_gateway"] == "connected"

    @pytest.mark.asyncio
    async def test_health_check_no_database(self):
        """No database configured -> overall status is "unhealthy"."""
        import server

        server._database = None
        server._embeddings = None

        outcome = await server.health_check()

        assert outcome["status"] == "unhealthy"
        assert outcome["dependencies"]["database"] == "not initialized"

    @pytest.mark.asyncio
    async def test_health_check_degraded(self):
        """Database reachable but Redis absent -> overall status is "degraded"."""
        import server

        database = self._mock_database()

        # Embeddings object present, but with neither Redis nor an HTTP client.
        embeddings = MagicMock()
        embeddings._redis = None
        embeddings._http_client = None

        server._database = database
        server._embeddings = embeddings

        outcome = await server.health_check()

        dependencies = outcome["dependencies"]
        assert outcome["status"] == "degraded"
        assert dependencies["database"] == "connected"
        assert dependencies["redis"] == "not initialized"
class TestSearchKnowledgeTool:
    """Tests for the search_knowledge MCP tool.

    Every test calls the undecorated coroutine through ``.fn`` and passes all
    parameters explicitly. Calling ``.fn`` bypasses the tool framework, so an
    omitted parameter may be bound to an unresolved ``Field`` default object
    instead of a real value. That would make the test depend on the order in
    which the tool validates its arguments.
    """

    @pytest.mark.asyncio
    async def test_search_success(self):
        """Test a successful search against a mocked search backend."""
        import server
        from models import SearchResponse, SearchResult

        mock_search = MagicMock()
        mock_search.search = AsyncMock(
            return_value=SearchResponse(
                query="test query",
                search_type="hybrid",
                results=[
                    SearchResult(
                        id="id-1",
                        content="Test content",
                        score=0.95,
                        source_path="/test/file.py",
                        chunk_type="code",
                        collection="default",
                    )
                ],
                total_results=1,
                search_time_ms=10.5,
            )
        )
        server._search = mock_search

        # Call the wrapped function via .fn
        result = await server.search_knowledge.fn(
            project_id="proj-123",
            agent_id="agent-456",
            query="test query",
            search_type="hybrid",
            collection=None,
            limit=10,
            threshold=0.7,
            file_types=None,
        )

        assert result["success"] is True
        assert len(result["results"]) == 1
        assert result["results"][0]["score"] == 0.95

    @pytest.mark.asyncio
    async def test_search_invalid_type(self):
        """Test that an unknown search type is rejected with an error."""
        import server

        # All other parameters are valid, so only search_type can fail.
        result = await server.search_knowledge.fn(
            project_id="proj-123",
            agent_id="agent-456",
            query="test",
            search_type="invalid",
            collection=None,
            limit=10,
            threshold=0.7,
            file_types=None,
        )

        assert result["success"] is False
        assert "Invalid search type" in result["error"]

    @pytest.mark.asyncio
    async def test_search_invalid_file_type(self):
        """Test that an unknown file type filter is rejected with an error."""
        import server

        result = await server.search_knowledge.fn(
            project_id="proj-123",
            agent_id="agent-456",
            query="test",
            search_type="hybrid",
            collection=None,
            limit=10,
            threshold=0.7,
            file_types=["invalid_type"],
        )

        assert result["success"] is False
        assert "Invalid file type" in result["error"]
class TestIngestContentTool:
    """Tests for the ingest_content MCP tool.

    Every test calls ``.fn`` with all parameters passed explicitly, so no
    parameter falls back to an unresolved ``Field`` default object.
    """

    @pytest.mark.asyncio
    async def test_ingest_success(self):
        """Test successful ingestion against a mocked collection manager."""
        import server
        from models import IngestResult

        mock_collections = MagicMock()
        mock_collections.ingest = AsyncMock(
            return_value=IngestResult(
                success=True,
                chunks_created=3,
                embeddings_generated=3,
                source_path="/test/file.py",
                collection="default",
                chunk_ids=["id-1", "id-2", "id-3"],
            )
        )
        server._collections = mock_collections

        result = await server.ingest_content.fn(
            project_id="proj-123",
            agent_id="agent-456",
            content="def hello(): pass",
            source_path="/test/file.py",
            collection="default",
            chunk_type="text",
            file_type=None,
            metadata=None,
        )

        assert result["success"] is True
        assert result["chunks_created"] == 3
        assert len(result["chunk_ids"]) == 3

    @pytest.mark.asyncio
    async def test_ingest_invalid_chunk_type(self):
        """Test that an unknown chunk type is rejected with an error."""
        import server

        # All other parameters are valid, so only chunk_type can fail.
        result = await server.ingest_content.fn(
            project_id="proj-123",
            agent_id="agent-456",
            content="test content",
            source_path=None,
            collection="default",
            chunk_type="invalid",
            file_type=None,
            metadata=None,
        )

        assert result["success"] is False
        assert "Invalid chunk type" in result["error"]

    @pytest.mark.asyncio
    async def test_ingest_invalid_file_type(self):
        """Test that an unknown file type is rejected with an error."""
        import server

        result = await server.ingest_content.fn(
            project_id="proj-123",
            agent_id="agent-456",
            content="test content",
            source_path=None,
            collection="default",
            chunk_type="text",
            file_type="invalid",
            metadata=None,
        )

        assert result["success"] is False
        assert "Invalid file type" in result["error"]
class TestDeleteContentTool:
    """Tests covering the delete_content MCP tool."""

    @pytest.mark.asyncio
    async def test_delete_success(self):
        """Deleting by source path reports how many chunks were removed."""
        import server
        from models import DeleteResult

        # The mocked collection manager reports five deleted chunks.
        canned = DeleteResult(success=True, chunks_deleted=5)
        fake_collections = MagicMock()
        fake_collections.delete = AsyncMock(return_value=canned)
        server._collections = fake_collections

        response = await server.delete_content.fn(
            project_id="proj-123",
            agent_id="agent-456",
            source_path="/test/file.py",
            collection=None,
            chunk_ids=None,
        )

        assert response["success"] is True
        assert response["chunks_deleted"] == 5
class TestListCollectionsTool:
    """Tests covering the list_collections MCP tool."""

    @pytest.mark.asyncio
    async def test_list_collections_success(self):
        """Listing a project's collections surfaces the mocked single entry."""
        import server
        from models import CollectionInfo, ListCollectionsResponse

        only_collection = CollectionInfo(
            name="collection-1",
            project_id="proj-123",
            chunk_count=100,
            total_tokens=50000,
            file_types=["python"],
            created_at=datetime.now(UTC),
            updated_at=datetime.now(UTC),
        )
        listing = ListCollectionsResponse(
            project_id="proj-123",
            collections=[only_collection],
            total_collections=1,
        )

        fake_collections = MagicMock()
        fake_collections.list_collections = AsyncMock(return_value=listing)
        server._collections = fake_collections

        response = await server.list_collections.fn(
            project_id="proj-123",
            agent_id="agent-456",
        )

        assert response["success"] is True
        assert response["total_collections"] == 1
        assert len(response["collections"]) == 1
class TestGetCollectionStatsTool:
    """Tests covering the get_collection_stats MCP tool."""

    @pytest.mark.asyncio
    async def test_get_stats_success(self):
        """Statistics for a named collection are passed through to the caller."""
        import server
        from models import CollectionStats

        stats = CollectionStats(
            collection="test-collection",
            project_id="proj-123",
            chunk_count=100,
            unique_sources=10,
            total_tokens=50000,
            avg_chunk_size=500.0,
            chunk_types={"code": 60, "text": 40},
            file_types={"python": 50, "javascript": 10},
        )
        fake_collections = MagicMock()
        fake_collections.get_collection_stats = AsyncMock(return_value=stats)
        server._collections = fake_collections

        response = await server.get_collection_stats.fn(
            project_id="proj-123",
            agent_id="agent-456",
            collection="test-collection",
        )

        assert response["success"] is True
        assert response["chunk_count"] == 100
        assert response["unique_sources"] == 10
class TestUpdateDocumentTool:
    """Tests for the update_document MCP tool.

    Every test calls ``.fn`` with all parameters passed explicitly, so no
    parameter falls back to an unresolved ``Field`` default object.
    """

    @pytest.mark.asyncio
    async def test_update_success(self):
        """Test updating a document against a mocked collection manager."""
        import server
        from models import IngestResult

        mock_collections = MagicMock()
        mock_collections.update_document = AsyncMock(
            return_value=IngestResult(
                success=True,
                chunks_created=2,
                embeddings_generated=2,
                source_path="/test/file.py",
                collection="default",
                chunk_ids=["id-1", "id-2"],
            )
        )
        server._collections = mock_collections

        result = await server.update_document.fn(
            project_id="proj-123",
            agent_id="agent-456",
            source_path="/test/file.py",
            content="def updated(): pass",
            collection="default",
            chunk_type="text",
            file_type=None,
            metadata=None,
        )

        assert result["success"] is True
        assert result["chunks_created"] == 2

    @pytest.mark.asyncio
    async def test_update_invalid_chunk_type(self):
        """Test that an unknown chunk type is rejected with an error."""
        import server

        # All other parameters are valid, so only chunk_type can fail.
        result = await server.update_document.fn(
            project_id="proj-123",
            agent_id="agent-456",
            source_path="/test/file.py",
            content="test",
            collection="default",
            chunk_type="invalid",
            file_type=None,
            metadata=None,
        )

        assert result["success"] is False
        assert "Invalid chunk type" in result["error"]
class TestMCPToolsEndpoint:
    """Tests covering the ``GET /mcp/tools`` discovery endpoint."""

    # Tool names that must appear in the discovery listing.
    EXPECTED_TOOLS = (
        "search_knowledge",
        "ingest_content",
        "delete_content",
        "list_collections",
        "get_collection_stats",
        "update_document",
    )

    @staticmethod
    def _fetch_tools():
        """Issue ``GET /mcp/tools`` against the app and return the response."""
        import server

        return TestClient(server.app).get("/mcp/tools")

    def test_list_mcp_tools(self):
        """The endpoint lists exactly the six registered tools by name."""
        response = self._fetch_tools()
        assert response.status_code == 200

        body = response.json()
        assert "tools" in body

        tools = body["tools"]
        assert len(tools) == 6  # 6 tools registered

        registered = {entry["name"] for entry in tools}
        for name in self.EXPECTED_TOOLS:
            assert name in registered

    def test_tool_has_schema(self):
        """Every listed tool carries an object-typed ``inputSchema``."""
        tools = self._fetch_tools().json()["tools"]
        for entry in tools:
            assert "inputSchema" in entry
            schema = entry["inputSchema"]
            assert "type" in schema
            assert schema["type"] == "object"
class TestMCPRPCEndpoint:
    """Tests covering the ``POST /mcp`` JSON-RPC endpoint."""

    @staticmethod
    def _post_rpc(payload):
        """POST ``payload`` as JSON to ``/mcp`` and return the HTTP response."""
        import server

        return TestClient(server.app).post("/mcp", json=payload)

    def test_valid_jsonrpc_request(self):
        """A well-formed search_knowledge call returns a JSON-RPC result."""
        import server
        from models import SearchResponse, SearchResult

        hit = SearchResult(
            id="id-1",
            content="Test",
            score=0.9,
            source_path="/test.py",
            chunk_type="code",
            collection="default",
        )
        canned = SearchResponse(
            query="test",
            search_type="hybrid",
            results=[hit],
            total_results=1,
            search_time_ms=5.0,
        )
        fake_search = MagicMock()
        fake_search.search = AsyncMock(return_value=canned)
        server._search = fake_search

        response = self._post_rpc(
            {
                "jsonrpc": "2.0",
                "method": "search_knowledge",
                "params": {
                    "project_id": "proj-123",
                    "agent_id": "agent-456",
                    "query": "test",
                },
                "id": 1,
            }
        )

        assert response.status_code == 200
        body = response.json()
        assert body["jsonrpc"] == "2.0"
        assert body["id"] == 1
        assert "result" in body
        assert body["result"]["success"] is True

    def test_invalid_jsonrpc_version(self):
        """A non-2.0 protocol version is rejected as an invalid request."""
        response = self._post_rpc(
            {
                "jsonrpc": "1.0",
                "method": "search_knowledge",
                "params": {},
                "id": 1,
            }
        )

        assert response.status_code == 400
        failure = response.json()["error"]
        assert failure["code"] == -32600
        assert "jsonrpc must be '2.0'" in failure["message"]

    def test_missing_method(self):
        """A request lacking ``method`` is rejected as an invalid request."""
        response = self._post_rpc(
            {
                "jsonrpc": "2.0",
                "params": {},
                "id": 1,
            }
        )

        assert response.status_code == 400
        failure = response.json()["error"]
        assert failure["code"] == -32600
        assert "method is required" in failure["message"]

    def test_unknown_method(self):
        """An unregistered method name yields a method-not-found error."""
        response = self._post_rpc(
            {
                "jsonrpc": "2.0",
                "method": "unknown_method",
                "params": {},
                "id": 1,
            }
        )

        assert response.status_code == 404
        failure = response.json()["error"]
        assert failure["code"] == -32601
        assert "Method not found" in failure["message"]

    def test_invalid_params(self):
        """Params missing the required fields yield an invalid-params error."""
        response = self._post_rpc(
            {
                "jsonrpc": "2.0",
                "method": "search_knowledge",
                "params": {"invalid_param": "value"},  # Missing required params
                "id": 1,
            }
        )

        assert response.status_code == 400
        assert response.json()["error"]["code"] == -32602
class TestContentSizeLimits:
    """Tests for content size validation.

    Every test calls ``.fn`` with all parameters passed explicitly, so no
    parameter falls back to an unresolved ``Field`` default object. This keeps
    the size limit as the only thing that can fail.
    """

    @pytest.mark.asyncio
    async def test_ingest_rejects_oversized_content(self):
        """Test that ingest rejects content exceeding size limit."""
        import server
        from config import get_settings

        settings = get_settings()
        # One character past the configured maximum document size.
        oversized_content = "x" * (settings.max_document_size + 1)

        # Pass all parameters to avoid Field default resolution issues
        result = await server.ingest_content.fn(
            project_id="proj-123",
            agent_id="agent-456",
            content=oversized_content,
            source_path=None,
            collection="default",
            chunk_type="text",
            file_type=None,
            metadata=None,
        )

        assert result["success"] is False
        assert "exceeds maximum" in result["error"]

    @pytest.mark.asyncio
    async def test_update_rejects_oversized_content(self):
        """Test that update rejects content exceeding size limit."""
        import server
        from config import get_settings

        settings = get_settings()
        # One character past the configured maximum document size.
        oversized_content = "x" * (settings.max_document_size + 1)

        # Pass all parameters to avoid Field default resolution issues
        result = await server.update_document.fn(
            project_id="proj-123",
            agent_id="agent-456",
            source_path="/test.py",
            content=oversized_content,
            collection="default",
            chunk_type="text",
            file_type=None,
            metadata=None,
        )

        assert result["success"] is False
        assert "exceeds maximum" in result["error"]

    @pytest.mark.asyncio
    async def test_ingest_accepts_valid_size_content(self):
        """Test that ingest accepts content within size limit."""
        import server
        from models import IngestResult

        mock_collections = MagicMock()
        mock_collections.ingest = AsyncMock(
            return_value=IngestResult(
                success=True,
                chunks_created=1,
                embeddings_generated=1,
                source_path="/test.py",
                collection="default",
                chunk_ids=["id-1"],
            )
        )
        server._collections = mock_collections

        # Small content that's within limits
        # Pass all parameters to avoid Field default resolution issues
        result = await server.ingest_content.fn(
            project_id="proj-123",
            agent_id="agent-456",
            content="def hello(): pass",
            source_path="/test.py",
            collection="default",
            chunk_type="text",
            file_type=None,
            metadata=None,
        )

        assert result["success"] is True