forked from cardosofelipe/fast-next-template
- Fix SQL HAVING clause bug by using CTE approach (closes #73)
- Add /mcp JSON-RPC 2.0 endpoint for tool execution (closes #74)
- Add /mcp/tools endpoint for tool discovery (closes #75)
- Add content size limits to prevent DoS attacks (closes #78)
- Add comprehensive tests for new endpoints

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
605 lines · 18 KiB · Python
"""Tests for server module and MCP tools."""
|
|
|
|
import json
|
|
from datetime import UTC, datetime
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
|
|
|
|
class TestHealthCheck:
    """Tests for health check endpoint.

    The module global ``server._database`` is replaced with ``patch.object``
    so its original value is restored after each test and no state leaks
    into other tests.
    """

    @pytest.mark.asyncio
    async def test_health_check_healthy(self):
        """Test health check when healthy."""
        import server

        # Connection object yielded by the acquire() context manager;
        # fetchval() resolves to 1.
        mock_conn = AsyncMock()
        mock_conn.fetchval = AsyncMock(return_value=1)

        # acquire() returns an async context manager yielding mock_conn
        # (assumed usage: ``async with _database.acquire() as conn``).
        mock_cm = AsyncMock()
        mock_cm.__aenter__.return_value = mock_conn
        mock_cm.__aexit__.return_value = None

        mock_db = MagicMock()
        mock_db._pool = MagicMock()
        mock_db.acquire.return_value = mock_cm

        with patch.object(server, "_database", mock_db):
            result = await server.health_check()

        assert result["status"] == "healthy"
        assert result["service"] == "knowledge-base"
        assert result["database"] == "connected"

    @pytest.mark.asyncio
    async def test_health_check_no_database(self):
        """Test health check without database."""
        import server

        with patch.object(server, "_database", None):
            result = await server.health_check()

        assert result["database"] == "not initialized"
|
|
|
|
|
|
class TestSearchKnowledgeTool:
    """Tests for search_knowledge MCP tool.

    Every test swaps ``server._search`` via ``patch.object`` so the module
    global is restored afterwards and results do not depend on test order.
    All tool parameters are passed explicitly: calling ``.fn`` directly may
    leave omitted arguments bound to their ``Field(...)`` declarations
    instead of real default values.
    """

    @pytest.mark.asyncio
    async def test_search_success(self):
        """Test successful search."""
        import server
        from models import SearchResponse, SearchResult

        mock_search = MagicMock()
        mock_search.search = AsyncMock(
            return_value=SearchResponse(
                query="test query",
                search_type="hybrid",
                results=[
                    SearchResult(
                        id="id-1",
                        content="Test content",
                        score=0.95,
                        source_path="/test/file.py",
                        chunk_type="code",
                        collection="default",
                    )
                ],
                total_results=1,
                search_time_ms=10.5,
            )
        )

        # Call the wrapped function via .fn
        with patch.object(server, "_search", mock_search):
            result = await server.search_knowledge.fn(
                project_id="proj-123",
                agent_id="agent-456",
                query="test query",
                search_type="hybrid",
                collection=None,
                limit=10,
                threshold=0.7,
                file_types=None,
            )

        assert result["success"] is True
        assert len(result["results"]) == 1
        assert result["results"][0]["score"] == 0.95
        mock_search.search.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_search_invalid_type(self):
        """Test search with invalid search type."""
        import server

        mock_search = MagicMock()
        mock_search.search = AsyncMock()

        with patch.object(server, "_search", mock_search):
            result = await server.search_knowledge.fn(
                project_id="proj-123",
                agent_id="agent-456",
                query="test",
                search_type="invalid",
                collection=None,
                limit=10,
                threshold=0.7,
                file_types=None,
            )

        assert result["success"] is False
        assert "Invalid search type" in result["error"]
        # Validation should reject the request before the backend is used.
        mock_search.search.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_search_invalid_file_type(self):
        """Test search with invalid file type."""
        import server

        mock_search = MagicMock()
        mock_search.search = AsyncMock()

        with patch.object(server, "_search", mock_search):
            result = await server.search_knowledge.fn(
                project_id="proj-123",
                agent_id="agent-456",
                query="test",
                search_type="hybrid",
                collection=None,
                limit=10,
                threshold=0.7,
                file_types=["invalid_type"],
            )

        assert result["success"] is False
        assert "Invalid file type" in result["error"]
        # Validation should reject the request before the backend is used.
        mock_search.search.assert_not_awaited()
|
|
|
|
|
|
class TestIngestContentTool:
    """Tests for ingest_content MCP tool.

    Every test swaps ``server._collections`` via ``patch.object`` so the
    module global is restored afterwards and results do not depend on test
    order. All tool parameters are passed explicitly: calling ``.fn``
    directly may leave omitted arguments bound to their ``Field(...)``
    declarations instead of real default values.
    """

    @pytest.mark.asyncio
    async def test_ingest_success(self):
        """Test successful ingestion."""
        import server
        from models import IngestResult

        mock_collections = MagicMock()
        mock_collections.ingest = AsyncMock(
            return_value=IngestResult(
                success=True,
                chunks_created=3,
                embeddings_generated=3,
                source_path="/test/file.py",
                collection="default",
                chunk_ids=["id-1", "id-2", "id-3"],
            )
        )

        with patch.object(server, "_collections", mock_collections):
            result = await server.ingest_content.fn(
                project_id="proj-123",
                agent_id="agent-456",
                content="def hello(): pass",
                source_path="/test/file.py",
                collection="default",
                chunk_type="text",
                file_type=None,
                metadata=None,
            )

        assert result["success"] is True
        assert result["chunks_created"] == 3
        assert len(result["chunk_ids"]) == 3
        mock_collections.ingest.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_ingest_invalid_chunk_type(self):
        """Test ingest with invalid chunk type."""
        import server

        mock_collections = MagicMock()
        mock_collections.ingest = AsyncMock()

        with patch.object(server, "_collections", mock_collections):
            result = await server.ingest_content.fn(
                project_id="proj-123",
                agent_id="agent-456",
                content="test content",
                source_path=None,
                collection="default",
                chunk_type="invalid",
                file_type=None,
                metadata=None,
            )

        assert result["success"] is False
        assert "Invalid chunk type" in result["error"]
        # Validation should reject the request before anything is stored.
        mock_collections.ingest.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_ingest_invalid_file_type(self):
        """Test ingest with invalid file type."""
        import server

        mock_collections = MagicMock()
        mock_collections.ingest = AsyncMock()

        with patch.object(server, "_collections", mock_collections):
            result = await server.ingest_content.fn(
                project_id="proj-123",
                agent_id="agent-456",
                content="test content",
                source_path=None,
                collection="default",
                chunk_type="text",
                file_type="invalid",
                metadata=None,
            )

        assert result["success"] is False
        assert "Invalid file type" in result["error"]
        # Validation should reject the request before anything is stored.
        mock_collections.ingest.assert_not_awaited()
|
|
|
|
|
|
class TestDeleteContentTool:
    """Tests for delete_content MCP tool.

    ``server._collections`` is swapped via ``patch.object`` so the module
    global is restored after the test.
    """

    @pytest.mark.asyncio
    async def test_delete_success(self):
        """Test successful deletion."""
        import server
        from models import DeleteResult

        mock_collections = MagicMock()
        mock_collections.delete = AsyncMock(
            return_value=DeleteResult(
                success=True,
                chunks_deleted=5,
            )
        )

        with patch.object(server, "_collections", mock_collections):
            result = await server.delete_content.fn(
                project_id="proj-123",
                agent_id="agent-456",
                source_path="/test/file.py",
                collection=None,
                chunk_ids=None,
            )

        assert result["success"] is True
        assert result["chunks_deleted"] == 5
        mock_collections.delete.assert_awaited_once()
|
|
|
|
|
|
class TestListCollectionsTool:
    """Tests for list_collections MCP tool.

    ``server._collections`` is swapped via ``patch.object`` so the module
    global is restored after the test.
    """

    @pytest.mark.asyncio
    async def test_list_collections_success(self):
        """Test listing collections."""
        import server
        from models import CollectionInfo, ListCollectionsResponse

        # One timestamp for both fields keeps created_at/updated_at consistent.
        now = datetime.now(UTC)

        mock_collections = MagicMock()
        mock_collections.list_collections = AsyncMock(
            return_value=ListCollectionsResponse(
                project_id="proj-123",
                collections=[
                    CollectionInfo(
                        name="collection-1",
                        project_id="proj-123",
                        chunk_count=100,
                        total_tokens=50000,
                        file_types=["python"],
                        created_at=now,
                        updated_at=now,
                    )
                ],
                total_collections=1,
            )
        )

        with patch.object(server, "_collections", mock_collections):
            result = await server.list_collections.fn(
                project_id="proj-123",
                agent_id="agent-456",
            )

        assert result["success"] is True
        assert result["total_collections"] == 1
        assert len(result["collections"]) == 1
        mock_collections.list_collections.assert_awaited_once()
|
|
|
|
|
|
class TestGetCollectionStatsTool:
    """Tests for get_collection_stats MCP tool.

    ``server._collections`` is swapped via ``patch.object`` so the module
    global is restored after the test.
    """

    @pytest.mark.asyncio
    async def test_get_stats_success(self):
        """Test getting collection stats."""
        import server
        from models import CollectionStats

        mock_collections = MagicMock()
        mock_collections.get_collection_stats = AsyncMock(
            return_value=CollectionStats(
                collection="test-collection",
                project_id="proj-123",
                chunk_count=100,
                unique_sources=10,
                total_tokens=50000,
                avg_chunk_size=500.0,
                chunk_types={"code": 60, "text": 40},
                file_types={"python": 50, "javascript": 10},
            )
        )

        with patch.object(server, "_collections", mock_collections):
            result = await server.get_collection_stats.fn(
                project_id="proj-123",
                agent_id="agent-456",
                collection="test-collection",
            )

        assert result["success"] is True
        assert result["chunk_count"] == 100
        assert result["unique_sources"] == 10
        mock_collections.get_collection_stats.assert_awaited_once()
|
|
|
|
|
|
class TestUpdateDocumentTool:
    """Tests for update_document MCP tool.

    Every test swaps ``server._collections`` via ``patch.object`` so the
    module global is restored afterwards and results do not depend on test
    order. All tool parameters are passed explicitly: calling ``.fn``
    directly may leave omitted arguments bound to their ``Field(...)``
    declarations instead of real default values.
    """

    @pytest.mark.asyncio
    async def test_update_success(self):
        """Test updating a document."""
        import server
        from models import IngestResult

        mock_collections = MagicMock()
        mock_collections.update_document = AsyncMock(
            return_value=IngestResult(
                success=True,
                chunks_created=2,
                embeddings_generated=2,
                source_path="/test/file.py",
                collection="default",
                chunk_ids=["id-1", "id-2"],
            )
        )

        with patch.object(server, "_collections", mock_collections):
            result = await server.update_document.fn(
                project_id="proj-123",
                agent_id="agent-456",
                source_path="/test/file.py",
                content="def updated(): pass",
                collection="default",
                chunk_type="text",
                file_type=None,
                metadata=None,
            )

        assert result["success"] is True
        assert result["chunks_created"] == 2
        mock_collections.update_document.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_update_invalid_chunk_type(self):
        """Test update with invalid chunk type."""
        import server

        mock_collections = MagicMock()
        mock_collections.update_document = AsyncMock()

        with patch.object(server, "_collections", mock_collections):
            result = await server.update_document.fn(
                project_id="proj-123",
                agent_id="agent-456",
                source_path="/test/file.py",
                content="test",
                collection="default",
                chunk_type="invalid",
                file_type=None,
                metadata=None,
            )

        assert result["success"] is False
        assert "Invalid chunk type" in result["error"]
        # Validation should reject the request before anything is stored.
        mock_collections.update_document.assert_not_awaited()
|
|
|
|
|
|
class TestMCPToolsEndpoint:
    """Tests for /mcp/tools endpoint."""

    # Exact set of tools the server is expected to register.
    _EXPECTED_TOOLS = frozenset(
        {
            "search_knowledge",
            "ingest_content",
            "delete_content",
            "list_collections",
            "get_collection_stats",
            "update_document",
        }
    )

    def test_list_mcp_tools(self):
        """Test listing available MCP tools."""
        import server

        client = TestClient(server.app)
        response = client.get("/mcp/tools")

        assert response.status_code == 200
        data = response.json()
        assert "tools" in data

        tool_names = [t["name"] for t in data["tools"]]
        # Six entries that together form exactly the expected set means
        # no tool is missing, unexpected, or duplicated.
        assert len(tool_names) == 6  # 6 tools registered
        assert set(tool_names) == self._EXPECTED_TOOLS

    def test_tool_has_schema(self):
        """Test that each tool has an object-typed input schema."""
        import server

        client = TestClient(server.app)
        response = client.get("/mcp/tools")

        assert response.status_code == 200
        data = response.json()
        # Guard against a vacuous pass: the loop below checks nothing if
        # the tool list is empty.
        assert data["tools"]
        for tool in data["tools"]:
            assert "inputSchema" in tool
            assert "type" in tool["inputSchema"]
            assert tool["inputSchema"]["type"] == "object"
|
|
|
|
|
|
class TestMCPRPCEndpoint:
    """Tests for /mcp JSON-RPC endpoint.

    The error tests pin the JSON-RPC 2.0 error codes the endpoint returns:
    -32600 (invalid request), -32601 (method not found) and -32602
    (invalid params).
    """

    def test_valid_jsonrpc_request(self):
        """Test valid JSON-RPC request."""
        import server
        from models import SearchResponse, SearchResult

        mock_search = MagicMock()
        mock_search.search = AsyncMock(
            return_value=SearchResponse(
                query="test",
                search_type="hybrid",
                results=[
                    SearchResult(
                        id="id-1",
                        content="Test",
                        score=0.9,
                        source_path="/test.py",
                        chunk_type="code",
                        collection="default",
                    )
                ],
                total_results=1,
                search_time_ms=5.0,
            )
        )

        client = TestClient(server.app)
        # patch.object restores server._search afterwards so the mock does
        # not leak into other tests.
        with patch.object(server, "_search", mock_search):
            response = client.post(
                "/mcp",
                json={
                    "jsonrpc": "2.0",
                    "method": "search_knowledge",
                    "params": {
                        "project_id": "proj-123",
                        "agent_id": "agent-456",
                        "query": "test",
                    },
                    "id": 1,
                },
            )

        assert response.status_code == 200
        data = response.json()
        assert data["jsonrpc"] == "2.0"
        assert data["id"] == 1
        assert "result" in data
        assert data["result"]["success"] is True
        mock_search.search.assert_awaited_once()

    def test_invalid_jsonrpc_version(self):
        """Test request with invalid JSON-RPC version."""
        import server

        client = TestClient(server.app)
        response = client.post(
            "/mcp",
            json={
                "jsonrpc": "1.0",
                "method": "search_knowledge",
                "params": {},
                "id": 1,
            },
        )

        assert response.status_code == 400
        data = response.json()
        assert data["error"]["code"] == -32600
        assert "jsonrpc must be '2.0'" in data["error"]["message"]

    def test_missing_method(self):
        """Test request without method."""
        import server

        client = TestClient(server.app)
        response = client.post(
            "/mcp",
            json={
                "jsonrpc": "2.0",
                "params": {},
                "id": 1,
            },
        )

        assert response.status_code == 400
        data = response.json()
        assert data["error"]["code"] == -32600
        assert "method is required" in data["error"]["message"]

    def test_unknown_method(self):
        """Test request with unknown method."""
        import server

        client = TestClient(server.app)
        response = client.post(
            "/mcp",
            json={
                "jsonrpc": "2.0",
                "method": "unknown_method",
                "params": {},
                "id": 1,
            },
        )

        assert response.status_code == 404
        data = response.json()
        assert data["error"]["code"] == -32601
        assert "Method not found" in data["error"]["message"]

    def test_invalid_params(self):
        """Test request with invalid params."""
        import server

        client = TestClient(server.app)
        response = client.post(
            "/mcp",
            json={
                "jsonrpc": "2.0",
                "method": "search_knowledge",
                "params": {"invalid_param": "value"},  # Missing required params
                "id": 1,
            },
        )

        assert response.status_code == 400
        data = response.json()
        assert data["error"]["code"] == -32602
|
|
|
|
|
|
class TestContentSizeLimits:
    """Tests for content size validation.

    Every test swaps ``server._collections`` via ``patch.object`` so the
    module global is restored afterwards and results do not depend on a
    mock leaked by another test. All tool parameters are passed explicitly:
    calling ``.fn`` directly may leave omitted arguments bound to their
    ``Field(...)`` declarations instead of real default values.
    """

    @pytest.mark.asyncio
    async def test_ingest_rejects_oversized_content(self):
        """Test that ingest rejects content exceeding size limit."""
        import server
        from config import get_settings

        settings = get_settings()
        # One character over the configured limit.
        oversized_content = "x" * (settings.max_document_size + 1)

        mock_collections = MagicMock()
        mock_collections.ingest = AsyncMock()

        with patch.object(server, "_collections", mock_collections):
            result = await server.ingest_content.fn(
                project_id="proj-123",
                agent_id="agent-456",
                content=oversized_content,
                source_path=None,
                collection="default",
                chunk_type="text",
                file_type=None,
                metadata=None,
            )

        assert result["success"] is False
        assert "exceeds maximum" in result["error"]
        # The size check should reject the content before anything is stored.
        mock_collections.ingest.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_update_rejects_oversized_content(self):
        """Test that update rejects content exceeding size limit."""
        import server
        from config import get_settings

        settings = get_settings()
        oversized_content = "x" * (settings.max_document_size + 1)

        mock_collections = MagicMock()
        mock_collections.update_document = AsyncMock()

        with patch.object(server, "_collections", mock_collections):
            result = await server.update_document.fn(
                project_id="proj-123",
                agent_id="agent-456",
                source_path="/test.py",
                content=oversized_content,
                collection="default",
                chunk_type="text",
                file_type=None,
                metadata=None,
            )

        assert result["success"] is False
        assert "exceeds maximum" in result["error"]
        # The size check should reject the content before anything is stored.
        mock_collections.update_document.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_ingest_accepts_valid_size_content(self):
        """Test that ingest accepts content within size limit."""
        import server
        from models import IngestResult

        mock_collections = MagicMock()
        mock_collections.ingest = AsyncMock(
            return_value=IngestResult(
                success=True,
                chunks_created=1,
                embeddings_generated=1,
                source_path="/test.py",
                collection="default",
                chunk_ids=["id-1"],
            )
        )

        # Small content that's within limits
        with patch.object(server, "_collections", mock_collections):
            result = await server.ingest_content.fn(
                project_id="proj-123",
                agent_id="agent-456",
                content="def hello(): pass",
                source_path="/test.py",
                collection="default",
                chunk_type="text",
                file_type=None,
                metadata=None,
            )

        assert result["success"] is True
        mock_collections.ingest.assert_awaited_once()
|