4 Commits

Author SHA1 Message Date
Felipe Cardoso
cd7a9ccbdf fix(mcp-kb): add transactional batch insert and atomic document update
- Wrap store_embeddings_batch in transaction for all-or-nothing semantics
- Add replace_source_embeddings method for atomic document updates
- Update collection_manager to use transactional replace
- Prevents race conditions and data inconsistency (closes #77)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-04 01:07:40 +01:00
Felipe Cardoso
953af52d0e fix(mcp-kb): address critical issues from deep review
- Fix SQL HAVING clause bug by using CTE approach (closes #73)
- Add /mcp JSON-RPC 2.0 endpoint for tool execution (closes #74)
- Add /mcp/tools endpoint for tool discovery (closes #75)
- Add content size limits to prevent DoS attacks (closes #78)
- Add comprehensive tests for new endpoints

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-04 01:03:58 +01:00
Felipe Cardoso
e6e98d4ed1 docs(workflow): enforce stack verification as mandatory step
- Added "Stack Verification" section to CLAUDE.md with detailed steps.
- Updated WORKFLOW.md to mandate running the full stack before marking work as complete.
- Prevents issues where high test coverage masks application startup failures.
2026-01-04 00:58:31 +01:00
Felipe Cardoso
ca5f5e3383 refactor(environment): update virtualenv path to /opt/venv in Docker setup
- Adjusted `docker-compose.dev.yml` to reflect the new venv location.
- Modified entrypoint script and Dockerfile to reference `/opt/venv` for isolated dependencies.
- Improved bind mount setup to prevent venv overwrites during development.
2026-01-04 00:58:24 +01:00
12 changed files with 835 additions and 80 deletions

View File

@@ -83,6 +83,37 @@ docs/
3. **Testing Required**: All code must be tested; aim for >90% coverage
4. **Code Review**: Must pass multi-agent review before merge
5. **No Direct Commits**: Never commit directly to `main` or `dev`
6. **Stack Verification**: ALWAYS run the full stack before considering work done (see below)
### CRITICAL: Stack Verification Before Merge
**This is NON-NEGOTIABLE. A feature with 100% test coverage that crashes on startup is WORTHLESS.**
Before considering ANY issue complete:
```bash
# 1. Start the dev stack
make dev
# 2. Wait for backend to be healthy, check logs
docker compose -f docker-compose.dev.yml logs backend --tail=100
# 3. Start frontend
cd frontend && npm run dev
# 4. Verify both are running without errors
```
**The issue is NOT done if:**
- Backend crashes on startup (import errors, missing dependencies)
- Frontend fails to compile or render
- Health checks fail
- Any error appears in logs
**Why this matters:**
- Tests run in isolation and may pass despite broken imports
- Docker builds cache layers and may hide dependency issues
- A single `ModuleNotFoundError` renders all test coverage meaningless
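One way to script the "wait for backend to be healthy" step (a minimal sketch, not part of the committed docs; it assumes the backend exposes `/health` on port 8000, matching the Dockerfile `HEALTHCHECK`):
```python
# Illustrative helper: poll the backend health endpoint until it answers or a timeout expires.
import sys
import time
import urllib.error
import urllib.request

HEALTH_URL = "http://localhost:8000/health"  # same URL the Dockerfile HEALTHCHECK curls

def wait_for_backend(timeout_s: float = 120.0, interval_s: float = 5.0) -> None:
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(HEALTH_URL, timeout=5) as resp:
                if resp.status == 200:
                    print("backend healthy")
                    return
        except (urllib.error.URLError, OSError):
            pass  # backend not accepting connections yet; retry
        time.sleep(interval_s)
    sys.exit("backend did not become healthy in time; inspect docker compose logs")

if __name__ == "__main__":
    wait_for_backend()
```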
### Common Commands

View File

@@ -7,7 +7,10 @@ ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONPATH=/app \
UV_COMPILE_BYTECODE=1 \
UV_LINK_MODE=copy \
UV_NO_CACHE=1
UV_NO_CACHE=1 \
UV_PROJECT_ENVIRONMENT=/opt/venv \
VIRTUAL_ENV=/opt/venv \
PATH="/opt/venv/bin:$PATH"
# Install system dependencies and uv
RUN apt-get update && \
@@ -20,7 +23,7 @@ RUN apt-get update && \
# Copy dependency files
COPY pyproject.toml uv.lock ./
# Install dependencies using uv (development mode with dev dependencies)
# Install dependencies using uv into /opt/venv (outside /app to survive bind mounts)
RUN uv sync --extra dev --frozen
# Copy application code
@@ -45,7 +48,10 @@ ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONPATH=/app \
UV_COMPILE_BYTECODE=1 \
UV_LINK_MODE=copy \
UV_NO_CACHE=1
UV_NO_CACHE=1 \
UV_PROJECT_ENVIRONMENT=/opt/venv \
VIRTUAL_ENV=/opt/venv \
PATH="/opt/venv/bin:$PATH"
# Install system dependencies and uv
RUN apt-get update && \
@@ -58,7 +64,7 @@ RUN apt-get update && \
# Copy dependency files
COPY pyproject.toml uv.lock ./
# Install only production dependencies using uv (no dev dependencies)
# Install only production dependencies using uv into /opt/venv
RUN uv sync --frozen --no-dev
# Copy application code
@@ -67,7 +73,7 @@ COPY entrypoint.sh /usr/local/bin/
RUN chmod +x /usr/local/bin/entrypoint.sh
# Set ownership to non-root user
RUN chown -R appuser:appuser /app
RUN chown -R appuser:appuser /app /opt/venv
# Switch to non-root user
USER appuser
@@ -77,4 +83,4 @@ HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
CMD curl -f http://localhost:8000/health || exit 1
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
CMD ["uv", "run", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

View File

@@ -1,11 +1,11 @@
#!/bin/bash
set -e
# Ensure the project's virtualenv binaries are on PATH so commands like
# 'uvicorn' work even when not prefixed by 'uv run'. This matches how uv
# installs the env into /app/.venv in our containers.
if [ -d "/app/.venv/bin" ]; then
export PATH="/app/.venv/bin:$PATH"
# Ensure the virtualenv binaries are on PATH. Dependencies are installed
# to /opt/venv (not /app/.venv) to survive bind mounts in development.
if [ -d "/opt/venv/bin" ]; then
export PATH="/opt/venv/bin:$PATH"
export VIRTUAL_ENV="/opt/venv"
fi
# Only the backend service should run migrations and init_db

View File

@@ -40,8 +40,7 @@ services:
volumes:
- ./backend:/app
- ./uploads:/app/uploads
# Exclude local .venv from bind mount to use container's .venv
- /app/.venv
# Note: venv is at /opt/venv (not /app/.venv) so bind mount doesn't affect it
ports:
- "8000:8000"
env_file:
@@ -76,7 +75,6 @@ services:
target: development
volumes:
- ./backend:/app
- /app/.venv
env_file:
- .env
environment:
@@ -99,7 +97,6 @@ services:
target: development
volumes:
- ./backend:/app
- /app/.venv
env_file:
- .env
environment:
@@ -122,7 +119,6 @@ services:
target: development
volumes:
- ./backend:/app
- /app/.venv
env_file:
- .env
environment:
@@ -145,7 +141,6 @@ services:
target: development
volumes:
- ./backend:/app
- /app/.venv
env_file:
- .env
environment:

View File

@@ -214,9 +214,9 @@ test(frontend): add unit tests for ProjectDashboard
---
## Phase 2+ Implementation Workflow
## Rigorous Implementation Workflow
**For complex infrastructure issues (Phase 2 MCP, core systems), follow this rigorous process:**
**This workflow applies to ALL feature implementations. Follow this process rigorously:**
### 1. Branch Setup
```bash
@@ -257,12 +257,42 @@ Before closing an issue, perform deep review from multiple angles:
**No stone unturned. No sloppy results. No unreviewed work.**
### 5. Final Validation
### 5. Stack Verification (CRITICAL - NON-NEGOTIABLE)
**ALWAYS run the full stack and verify it boots correctly before considering ANY work done.**
```bash
# Start the full development stack
make dev
# Check backend logs for startup errors
docker compose -f docker-compose.dev.yml logs backend --tail=100
# Start frontend separately
cd frontend && npm run dev
# Check frontend console for errors
```
**A feature is NOT complete if:**
- The stack doesn't boot
- There are import errors in logs
- Health checks fail
- Any component crashes on startup
**This rule exists because:**
- Tests can pass but the application won't start (import errors, missing deps)
- 90% test coverage is worthless if the app crashes on boot
- Docker builds can mask local issues
### 6. Final Validation Checklist
- [ ] All tests pass (unit, integration, E2E)
- [ ] Type checking passes
- [ ] Linting passes
- [ ] **Stack boots successfully** (backend + frontend)
- [ ] **Logs show no errors**
- [ ] Coverage meets threshold (>90% backend, >90% frontend)
- [ ] Documentation updated
- [ ] Coverage meets threshold
- [ ] Issue checklist 100% complete
- [ ] Multi-agent review passed

View File

@@ -265,9 +265,10 @@ class CollectionManager:
metadata: dict[str, Any] | None = None,
) -> IngestResult:
"""
Update a document by replacing existing chunks.
Update a document by atomically replacing existing chunks.
Deletes existing chunks for the source path and ingests new content.
Uses a database transaction to delete existing chunks and insert new ones
atomically, preventing race conditions during concurrent updates.
Args:
project_id: Project ID
@@ -282,26 +283,76 @@ class CollectionManager:
Returns:
Ingest result
"""
# First delete existing chunks for this source
request_metadata = metadata or {}
# Chunk the content
chunks = self.chunker_factory.chunk_content(
content=content,
source_path=source_path,
file_type=file_type,
chunk_type=chunk_type,
metadata=request_metadata,
)
if not chunks:
# No chunks = delete existing and return empty result
await self.database.delete_by_source(
project_id=project_id,
source_path=source_path,
collection=collection,
)
# Then ingest new content
request = IngestRequest(
project_id=project_id,
agent_id=agent_id,
content=content,
return IngestResult(
success=True,
chunks_created=0,
embeddings_generated=0,
source_path=source_path,
collection=collection,
chunk_type=chunk_type,
file_type=file_type,
metadata=metadata or {},
chunk_ids=[],
)
return await self.ingest(request)
# Generate embeddings for new chunks
chunk_texts = [chunk.content for chunk in chunks]
embeddings_list = await self.embeddings.generate_batch(
texts=chunk_texts,
project_id=project_id,
agent_id=agent_id,
)
# Build embeddings data for transactional replace
embeddings_data = []
for chunk, embedding in zip(chunks, embeddings_list, strict=True):
chunk_metadata = {
**request_metadata,
**chunk.metadata,
"token_count": chunk.token_count,
"source_path": chunk.source_path or source_path,
"start_line": chunk.start_line,
"end_line": chunk.end_line,
"file_type": (chunk.file_type or file_type).value if (chunk.file_type or file_type) else None,
}
embeddings_data.append((
chunk.content,
embedding,
chunk.chunk_type,
chunk_metadata,
))
# Atomically replace old embeddings with new ones
_, chunk_ids = await self.database.replace_source_embeddings(
project_id=project_id,
source_path=source_path,
collection=collection,
embeddings=embeddings_data,
)
return IngestResult(
success=True,
chunks_created=len(chunk_ids),
embeddings_generated=len(embeddings_list),
source_path=source_path,
collection=collection,
chunk_ids=chunk_ids,
)
async def cleanup_expired(self) -> int:
"""

View File

@@ -112,6 +112,20 @@ class Settings(BaseSettings):
description="TTL for embedding records in days (0 = no expiry)",
)
# Content size limits (DoS prevention)
max_document_size: int = Field(
default=10 * 1024 * 1024, # 10 MB
description="Maximum size of a single document in bytes",
)
max_batch_size: int = Field(
default=100,
description="Maximum number of documents in a batch operation",
)
max_batch_total_size: int = Field(
default=50 * 1024 * 1024, # 50 MB
description="Maximum total size of all documents in a batch",
)
model_config = {"env_prefix": "KB_", "env_file": ".env", "extra": "ignore"}
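# --- Illustrative sketch, not part of this diff: the new limits can be overridden through
# --- KB_-prefixed environment variables, per the env_prefix declared in model_config above.
# --- For example, in the service environment or .env file:
#
#   KB_MAX_DOCUMENT_SIZE=5242880      # lower the per-document cap to 5 MB
#   KB_MAX_BATCH_SIZE=25              # allow at most 25 documents per batch
#   KB_MAX_BATCH_TOTAL_SIZE=26214400  # cap batches at 25 MB total
#
# A fresh Settings instance then picks these values up:
#
#   from config import Settings       # module path as used by the test suite
#   assert Settings().max_document_size == 5 * 1024 * 1024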

View File

@@ -285,6 +285,8 @@ class DatabaseManager:
try:
async with self.acquire() as conn:
# Wrap in transaction for all-or-nothing batch semantics
async with conn.transaction():
for project_id, collection, content, embedding, chunk_type, metadata in embeddings:
content_hash = self.compute_content_hash(content)
source_path = metadata.get("source_path")
@@ -345,8 +347,9 @@ class DatabaseManager:
"""
try:
async with self.acquire() as conn:
# Build query with optional filters
query = """
# Build query with optional filters using CTE to filter by similarity
# We use a CTE to compute similarity once, then filter in outer query
inner_query = """
SELECT
id, project_id, collection, content, embedding,
chunk_type, source_path, start_line, end_line,
@@ -361,18 +364,21 @@ class DatabaseManager:
param_idx = 3
if collection:
query += f" AND collection = ${param_idx}"
inner_query += f" AND collection = ${param_idx}"
params.append(collection)
param_idx += 1
if file_types:
file_type_values = [ft.value for ft in file_types]
query += f" AND file_type = ANY(${param_idx})"
inner_query += f" AND file_type = ANY(${param_idx})"
params.append(file_type_values)
param_idx += 1
query += f"""
HAVING 1 - (embedding <=> $1) >= ${param_idx}
# Wrap in CTE and filter by threshold in outer query
query = f"""
WITH scored AS ({inner_query})
SELECT * FROM scored
WHERE similarity >= ${param_idx}
ORDER BY similarity DESC
LIMIT ${param_idx + 1}
"""
@@ -531,6 +537,96 @@ class DatabaseManager:
cause=e,
)
async def replace_source_embeddings(
self,
project_id: str,
source_path: str,
collection: str,
embeddings: list[tuple[str, list[float], ChunkType, dict[str, Any]]],
) -> tuple[int, list[str]]:
"""
Atomically replace all embeddings for a source path.
Deletes existing embeddings and inserts new ones in a single transaction,
preventing race conditions during document updates.
Args:
project_id: Project ID
source_path: Source file path being updated
collection: Collection name
embeddings: List of (content, embedding, chunk_type, metadata)
Returns:
Tuple of (deleted_count, new_embedding_ids)
"""
expires_at = None
if self._settings.embedding_ttl_days > 0:
expires_at = datetime.now(UTC) + timedelta(
days=self._settings.embedding_ttl_days
)
try:
async with self.acquire() as conn:
# Use transaction for atomic replace
async with conn.transaction():
# First, delete existing embeddings for this source
delete_result = await conn.execute(
"""
DELETE FROM knowledge_embeddings
WHERE project_id = $1 AND source_path = $2 AND collection = $3
""",
project_id,
source_path,
collection,
)
deleted_count = int(delete_result.split()[-1])
# Then insert new embeddings
new_ids = []
for content, embedding, chunk_type, metadata in embeddings:
content_hash = self.compute_content_hash(content)
start_line = metadata.get("start_line")
end_line = metadata.get("end_line")
file_type = metadata.get("file_type")
embedding_id = await conn.fetchval(
"""
INSERT INTO knowledge_embeddings
(project_id, collection, content, embedding, chunk_type,
source_path, start_line, end_line, file_type, metadata,
content_hash, expires_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
RETURNING id
""",
project_id,
collection,
content,
embedding,
chunk_type.value,
source_path,
start_line,
end_line,
file_type,
metadata,
content_hash,
expires_at,
)
if embedding_id:
new_ids.append(str(embedding_id))
logger.info(
f"Replaced source {source_path}: deleted {deleted_count}, "
f"inserted {len(new_ids)} embeddings"
)
return deleted_count, new_ids
except asyncpg.PostgresError as e:
logger.error(f"Replace source error: {e}")
raise DatabaseQueryError(
message=f"Failed to replace source embeddings: {e}",
cause=e,
)
async def delete_collection(
self,
project_id: str,

View File

@@ -5,11 +5,13 @@ Provides RAG capabilities with pgvector for semantic search,
intelligent chunking, and collection management.
"""
import inspect
import logging
from contextlib import asynccontextmanager
from typing import Any
from typing import Any, get_type_hints
from fastapi import FastAPI
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastmcp import FastMCP
from pydantic import Field
@@ -116,6 +118,259 @@ async def health_check() -> dict[str, Any]:
return status
# Tool registry for JSON-RPC
_tool_registry: dict[str, Any] = {}
def _python_type_to_json_schema(python_type: Any) -> dict[str, Any]:
"""Convert Python type annotation to JSON Schema."""
type_name = getattr(python_type, "__name__", str(python_type))
if python_type is str or type_name == "str":
return {"type": "string"}
elif python_type is int or type_name == "int":
return {"type": "integer"}
elif python_type is float or type_name == "float":
return {"type": "number"}
elif python_type is bool or type_name == "bool":
return {"type": "boolean"}
elif type_name == "NoneType":
return {"type": "null"}
elif hasattr(python_type, "__origin__"):
origin = python_type.__origin__
args = getattr(python_type, "__args__", ())
if origin is list:
item_type = args[0] if args else Any
return {"type": "array", "items": _python_type_to_json_schema(item_type)}
elif origin is dict:
return {"type": "object"}
elif origin is type(None) or str(origin) == "typing.Union":
# Handle Optional types (Union with None)
non_none_args = [a for a in args if a is not type(None)]
if len(non_none_args) == 1:
schema = _python_type_to_json_schema(non_none_args[0])
schema["nullable"] = True
return schema
return {"type": "object"}
return {"type": "object"}
def _get_tool_schema(func: Any) -> dict[str, Any]:
"""Extract JSON Schema from a tool function."""
sig = inspect.signature(func)
hints = get_type_hints(func) if hasattr(func, "__annotations__") else {}
properties: dict[str, Any] = {}
required: list[str] = []
for name, param in sig.parameters.items():
if name in ("self", "cls"):
continue
prop: dict[str, Any] = {}
# Get type from hints
if name in hints:
prop = _python_type_to_json_schema(hints[name])
# Get description and constraints from Field default (FieldInfo object)
default_val = param.default
if hasattr(default_val, "description") and default_val.description:
prop["description"] = default_val.description
if hasattr(default_val, "ge") and default_val.ge is not None:
prop["minimum"] = default_val.ge
if hasattr(default_val, "le") and default_val.le is not None:
prop["maximum"] = default_val.le
# Handle Field default value (check for PydanticUndefined)
if hasattr(default_val, "default"):
field_default = default_val.default
# Check if it's the "required" sentinel (...)
if field_default is not ... and not (
hasattr(field_default, "__class__")
and "PydanticUndefined" in field_default.__class__.__name__
):
prop["default"] = field_default
# Determine if required
if param.default is inspect.Parameter.empty:
required.append(name)
elif hasattr(default_val, "default"):
field_default = default_val.default
# Required if default is ellipsis or PydanticUndefined
if field_default is ... or (
hasattr(field_default, "__class__")
and "PydanticUndefined" in field_default.__class__.__name__
):
required.append(name)
properties[name] = prop
return {
"type": "object",
"properties": properties,
"required": required,
}
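# --- Illustrative sketch, not part of this diff: a hypothetical tool signature and the
# --- schema _get_tool_schema() derives from it. The function below is for illustration
# --- only and is never registered.
async def _demo_tool(
    project_id: str,
    limit: int = Field(default=5, description="Maximum results to return"),
) -> dict[str, Any]:
    return {}

# _get_tool_schema(_demo_tool) yields roughly:
#   {
#       "type": "object",
#       "properties": {
#           "project_id": {"type": "string"},
#           "limit": {"type": "integer", "description": "Maximum results to return", "default": 5},
#       },
#       "required": ["project_id"],
#   }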
def _register_tool(name: str, tool_or_func: Any, description: str | None = None) -> None:
"""Register a tool in the registry.
Handles both raw functions and FastMCP FunctionTool objects.
"""
# Extract the underlying function from FastMCP FunctionTool if needed
if hasattr(tool_or_func, "fn"):
func = tool_or_func.fn
# Use FunctionTool's description if available
if not description and hasattr(tool_or_func, "description") and tool_or_func.description:
description = tool_or_func.description
else:
func = tool_or_func
_tool_registry[name] = {
"func": func,
"description": description or (func.__doc__ or "").strip(),
"schema": _get_tool_schema(func),
}
@app.get("/mcp/tools")
async def list_mcp_tools() -> dict[str, Any]:
"""
Return list of available MCP tools with their schemas.
This endpoint enables tool discovery for the backend MCP client.
"""
tools = []
for name, info in _tool_registry.items():
tools.append({
"name": name,
"description": info["description"],
"inputSchema": info["schema"],
})
return {"tools": tools}
@app.post("/mcp")
async def mcp_rpc(request: Request) -> JSONResponse:
"""
JSON-RPC 2.0 endpoint for MCP tool execution.
Request format:
{
"jsonrpc": "2.0",
"method": "<tool_name>",
"params": {...},
"id": <request_id>
}
Response format:
{
"jsonrpc": "2.0",
"result": {...},
"id": <request_id>
}
"""
try:
body = await request.json()
except Exception as e:
return JSONResponse(
status_code=400,
content={
"jsonrpc": "2.0",
"error": {"code": -32700, "message": f"Parse error: {e}"},
"id": None,
},
)
# Validate JSON-RPC structure
jsonrpc = body.get("jsonrpc")
method = body.get("method")
params = body.get("params", {})
request_id = body.get("id")
if jsonrpc != "2.0":
return JSONResponse(
status_code=400,
content={
"jsonrpc": "2.0",
"error": {"code": -32600, "message": "Invalid Request: jsonrpc must be '2.0'"},
"id": request_id,
},
)
if not method:
return JSONResponse(
status_code=400,
content={
"jsonrpc": "2.0",
"error": {"code": -32600, "message": "Invalid Request: method is required"},
"id": request_id,
},
)
# Look up tool
tool_info = _tool_registry.get(method)
if not tool_info:
return JSONResponse(
status_code=404,
content={
"jsonrpc": "2.0",
"error": {"code": -32601, "message": f"Method not found: {method}"},
"id": request_id,
},
)
# Execute tool
try:
func = tool_info["func"]
# Resolve Field defaults for missing parameters
sig = inspect.signature(func)
resolved_params = dict(params)
for name, param in sig.parameters.items():
if name not in resolved_params:
default_val = param.default
# Check if it's a FieldInfo with a default value
if hasattr(default_val, "default"):
field_default = default_val.default
# Only use if it has an actual default (not required)
if field_default is not ... and not (
hasattr(field_default, "__class__")
and "PydanticUndefined" in field_default.__class__.__name__
):
resolved_params[name] = field_default
result = await func(**resolved_params)
return JSONResponse(
content={
"jsonrpc": "2.0",
"result": result,
"id": request_id,
}
)
except TypeError as e:
return JSONResponse(
status_code=400,
content={
"jsonrpc": "2.0",
"error": {"code": -32602, "message": f"Invalid params: {e}"},
"id": request_id,
},
)
except Exception as e:
logger.error(f"Tool execution error: {e}")
return JSONResponse(
status_code=500,
content={
"jsonrpc": "2.0",
"error": {"code": -32000, "message": f"Server error: {e}"},
"id": request_id,
},
)
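# --- Illustrative sketch, not part of this diff: invoking the endpoint above from a
# --- client. The host/port are assumptions for the local dev stack, and httpx is just
# --- one example HTTP client.
#
#   import httpx
#
#   payload = {
#       "jsonrpc": "2.0",
#       "method": "search_knowledge",
#       "params": {"project_id": "proj-123", "agent_id": "agent-456", "query": "retry policy"},
#       "id": 1,
#   }
#   response = httpx.post("http://localhost:8000/mcp", json=payload, timeout=30.0)
#   response.raise_for_status()
#   print(response.json()["result"])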
# MCP Tools
@@ -261,6 +516,15 @@ async def ingest_content(
the LLM Gateway, and stored in pgvector for search.
"""
try:
# Validate content size to prevent DoS
settings = get_settings()
content_size = len(content.encode("utf-8"))
if content_size > settings.max_document_size:
return {
"success": False,
"error": f"Content size ({content_size} bytes) exceeds maximum allowed ({settings.max_document_size} bytes)",
}
# Parse chunk type
try:
chunk_type_enum = ChunkType(chunk_type.lower())
@@ -492,6 +756,15 @@ async def update_document(
Replaces all existing chunks for the source path with new content.
"""
try:
# Validate content size to prevent DoS
settings = get_settings()
content_size = len(content.encode("utf-8"))
if content_size > settings.max_document_size:
return {
"success": False,
"error": f"Content size ({content_size} bytes) exceeds maximum allowed ({settings.max_document_size} bytes)",
}
# Parse chunk type
try:
chunk_type_enum = ChunkType(chunk_type.lower())
@@ -550,6 +823,16 @@ async def update_document(
}
# Register tools in the JSON-RPC registry
# This must happen after tool functions are defined
_register_tool("search_knowledge", search_knowledge)
_register_tool("ingest_content", ingest_content)
_register_tool("delete_content", delete_content)
_register_tool("list_collections", list_collections)
_register_tool("get_collection_stats", get_collection_stats)
_register_tool("update_document", update_document)
def main() -> None:
"""Run the server."""
import uvicorn

View File

@@ -61,6 +61,7 @@ def mock_database():
mock_db.delete_by_source = AsyncMock(return_value=1)
mock_db.delete_collection = AsyncMock(return_value=5)
mock_db.delete_by_ids = AsyncMock(return_value=2)
mock_db.replace_source_embeddings = AsyncMock(return_value=(1, ["new-id-1"]))
mock_db.list_collections = AsyncMock(return_value=[])
mock_db.get_collection_stats = AsyncMock()
mock_db.cleanup_expired = AsyncMock(return_value=0)

View File

@@ -192,7 +192,7 @@ class TestCollectionManager:
@pytest.mark.asyncio
async def test_update_document(self, collection_manager):
"""Test updating a document."""
"""Test updating a document with atomic replace."""
result = await collection_manager.update_document(
project_id="proj-123",
agent_id="agent-456",
@@ -201,9 +201,10 @@ class TestCollectionManager:
collection="default",
)
# Should delete first, then ingest
collection_manager._database.delete_by_source.assert_called_once()
# Should use atomic replace (delete + insert in transaction)
collection_manager._database.replace_source_embeddings.assert_called_once()
assert result.success is True
assert len(result.chunk_ids) == 1
@pytest.mark.asyncio
async def test_cleanup_expired(self, collection_manager):

View File

@@ -1,9 +1,11 @@
"""Tests for server module and MCP tools."""
import json
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi.testclient import TestClient
class TestHealthCheck:
@@ -355,3 +357,248 @@ class TestUpdateDocumentTool:
assert result["success"] is False
assert "Invalid chunk type" in result["error"]
class TestMCPToolsEndpoint:
"""Tests for /mcp/tools endpoint."""
def test_list_mcp_tools(self):
"""Test listing available MCP tools."""
import server
client = TestClient(server.app)
response = client.get("/mcp/tools")
assert response.status_code == 200
data = response.json()
assert "tools" in data
assert len(data["tools"]) == 6 # 6 tools registered
tool_names = [t["name"] for t in data["tools"]]
assert "search_knowledge" in tool_names
assert "ingest_content" in tool_names
assert "delete_content" in tool_names
assert "list_collections" in tool_names
assert "get_collection_stats" in tool_names
assert "update_document" in tool_names
def test_tool_has_schema(self):
"""Test that each tool has input schema."""
import server
client = TestClient(server.app)
response = client.get("/mcp/tools")
data = response.json()
for tool in data["tools"]:
assert "inputSchema" in tool
assert "type" in tool["inputSchema"]
assert tool["inputSchema"]["type"] == "object"
class TestMCPRPCEndpoint:
"""Tests for /mcp JSON-RPC endpoint."""
def test_valid_jsonrpc_request(self):
"""Test valid JSON-RPC request."""
import server
from models import SearchResponse, SearchResult
mock_search = MagicMock()
mock_search.search = AsyncMock(
return_value=SearchResponse(
query="test",
search_type="hybrid",
results=[
SearchResult(
id="id-1",
content="Test",
score=0.9,
source_path="/test.py",
chunk_type="code",
collection="default",
)
],
total_results=1,
search_time_ms=5.0,
)
)
server._search = mock_search
client = TestClient(server.app)
response = client.post(
"/mcp",
json={
"jsonrpc": "2.0",
"method": "search_knowledge",
"params": {
"project_id": "proj-123",
"agent_id": "agent-456",
"query": "test",
},
"id": 1,
},
)
assert response.status_code == 200
data = response.json()
assert data["jsonrpc"] == "2.0"
assert data["id"] == 1
assert "result" in data
assert data["result"]["success"] is True
def test_invalid_jsonrpc_version(self):
"""Test request with invalid JSON-RPC version."""
import server
client = TestClient(server.app)
response = client.post(
"/mcp",
json={
"jsonrpc": "1.0",
"method": "search_knowledge",
"params": {},
"id": 1,
},
)
assert response.status_code == 400
data = response.json()
assert data["error"]["code"] == -32600
assert "jsonrpc must be '2.0'" in data["error"]["message"]
def test_missing_method(self):
"""Test request without method."""
import server
client = TestClient(server.app)
response = client.post(
"/mcp",
json={
"jsonrpc": "2.0",
"params": {},
"id": 1,
},
)
assert response.status_code == 400
data = response.json()
assert data["error"]["code"] == -32600
assert "method is required" in data["error"]["message"]
def test_unknown_method(self):
"""Test request with unknown method."""
import server
client = TestClient(server.app)
response = client.post(
"/mcp",
json={
"jsonrpc": "2.0",
"method": "unknown_method",
"params": {},
"id": 1,
},
)
assert response.status_code == 404
data = response.json()
assert data["error"]["code"] == -32601
assert "Method not found" in data["error"]["message"]
def test_invalid_params(self):
"""Test request with invalid params."""
import server
client = TestClient(server.app)
response = client.post(
"/mcp",
json={
"jsonrpc": "2.0",
"method": "search_knowledge",
"params": {"invalid_param": "value"}, # Missing required params
"id": 1,
},
)
assert response.status_code == 400
data = response.json()
assert data["error"]["code"] == -32602
class TestContentSizeLimits:
"""Tests for content size validation."""
@pytest.mark.asyncio
async def test_ingest_rejects_oversized_content(self):
"""Test that ingest rejects content exceeding size limit."""
import server
from config import get_settings
settings = get_settings()
# Create content larger than max size
oversized_content = "x" * (settings.max_document_size + 1)
result = await server.ingest_content.fn(
project_id="proj-123",
agent_id="agent-456",
content=oversized_content,
chunk_type="text",
)
assert result["success"] is False
assert "exceeds maximum" in result["error"]
@pytest.mark.asyncio
async def test_update_rejects_oversized_content(self):
"""Test that update rejects content exceeding size limit."""
import server
from config import get_settings
settings = get_settings()
oversized_content = "x" * (settings.max_document_size + 1)
result = await server.update_document.fn(
project_id="proj-123",
agent_id="agent-456",
source_path="/test.py",
content=oversized_content,
chunk_type="text",
)
assert result["success"] is False
assert "exceeds maximum" in result["error"]
@pytest.mark.asyncio
async def test_ingest_accepts_valid_size_content(self):
"""Test that ingest accepts content within size limit."""
import server
from models import IngestResult
mock_collections = MagicMock()
mock_collections.ingest = AsyncMock(
return_value=IngestResult(
success=True,
chunks_created=1,
embeddings_generated=1,
source_path="/test.py",
collection="default",
chunk_ids=["id-1"],
)
)
server._collections = mock_collections
# Small content that's within limits
# Pass all parameters to avoid Field default resolution issues
result = await server.ingest_content.fn(
project_id="proj-123",
agent_id="agent-456",
content="def hello(): pass",
source_path="/test.py",
collection="default",
chunk_type="text",
file_type=None,
metadata=None,
)
assert result["success"] is True