forked from cardosofelipe/pragma-stack
Compare commits
14 Commits
2bea057fb1
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4b149b8a52 | ||
|
|
ad0c06851d | ||
|
|
49359b1416 | ||
|
|
911d950c15 | ||
|
|
b2a3ac60e0 | ||
|
|
dea092e1bb | ||
|
|
4154dd5268 | ||
|
|
db12937495 | ||
|
|
81e1456631 | ||
|
|
58e78d8700 | ||
|
|
5e80139afa | ||
|
|
60ebeaa582 | ||
|
|
758052dcff | ||
|
|
1628eacf2b |
61
.githooks/pre-commit
Executable file
61
.githooks/pre-commit
Executable file
@@ -0,0 +1,61 @@
|
||||
#!/bin/bash
# Pre-commit hook to enforce validation before commits on protected branches
# Install: git config core.hooksPath .githooks

set -e

# Get the current branch name
BRANCH=$(git rev-parse --abbrev-ref HEAD)

# Protected branches that require validation
PROTECTED_BRANCHES="main dev"

# Return 0 (success) when the current branch is in PROTECTED_BRANCHES.
is_protected() {
    for branch in $PROTECTED_BRANCHES; do
        if [ "$BRANCH" = "$branch" ]; then
            return 0
        fi
    done
    return 1
}

if is_protected; then
    echo "🔒 Committing to protected branch '$BRANCH' - running validation..."

    # Check if we have backend changes (any staged path under backend/)
    if git diff --cached --name-only | grep -q "^backend/"; then
        echo "📦 Backend changes detected - running make validate..."
        # Run in a subshell so the hook's working directory is never left
        # changed — the original cd/cd.. pair could strand us in backend/
        # if anything between aborted under `set -e`.
        if ! (cd backend && make validate); then
            echo ""
            echo "❌ Backend validation failed!"
            echo "   Please fix the issues and try again."
            echo "   Run 'cd backend && make validate' to see errors."
            exit 1
        fi
        echo "✅ Backend validation passed!"
    fi

    # Check if we have frontend changes (any staged path under frontend/)
    if git diff --cached --name-only | grep -q "^frontend/"; then
        echo "🎨 Frontend changes detected - running npm run validate..."
        # stderr is suppressed to keep hook output short; the message below
        # tells the user how to reproduce with full output.
        if ! (cd frontend && npm run validate 2>/dev/null); then
            echo ""
            echo "❌ Frontend validation failed!"
            echo "   Please fix the issues and try again."
            echo "   Run 'cd frontend && npm run validate' to see errors."
            exit 1
        fi
        echo "✅ Frontend validation passed!"
    fi

    echo "🎉 All validations passed! Proceeding with commit..."
else
    echo "📝 Committing to feature branch '$BRANCH' - skipping validation (run manually if needed)"
fi

exit 0
|
||||
92
Makefile
92
Makefile
@@ -1,18 +1,31 @@
|
||||
.PHONY: help dev dev-full prod down logs logs-dev clean clean-slate drop-db reset-db push-images deploy
|
||||
.PHONY: test test-backend test-mcp test-frontend test-all test-cov test-integration validate validate-all
|
||||
|
||||
VERSION ?= latest
|
||||
REGISTRY ?= ghcr.io/cardosofelipe/pragma-stack
|
||||
|
||||
# Default target
|
||||
help:
|
||||
@echo "FastAPI + Next.js Full-Stack Template"
|
||||
@echo "Syndarix - AI-Powered Software Consulting Agency"
|
||||
@echo ""
|
||||
@echo "Development:"
|
||||
@echo " make dev - Start backend + db (frontend runs separately)"
|
||||
@echo " make dev - Start backend + db + MCP servers (frontend runs separately)"
|
||||
@echo " make dev-full - Start all services including frontend"
|
||||
@echo " make down - Stop all services"
|
||||
@echo " make logs-dev - Follow dev container logs"
|
||||
@echo ""
|
||||
@echo "Testing:"
|
||||
@echo " make test - Run all tests (backend + MCP servers)"
|
||||
@echo " make test-backend - Run backend tests only"
|
||||
@echo " make test-mcp - Run MCP server tests only"
|
||||
@echo " make test-frontend - Run frontend tests only"
|
||||
@echo " make test-cov - Run all tests with coverage reports"
|
||||
@echo " make test-integration - Run MCP integration tests (requires running stack)"
|
||||
@echo ""
|
||||
@echo "Validation:"
|
||||
@echo " make validate - Validate backend + MCP servers (lint, type-check, test)"
|
||||
@echo " make validate-all - Validate everything including frontend"
|
||||
@echo ""
|
||||
@echo "Database:"
|
||||
@echo " make drop-db - Drop and recreate empty database"
|
||||
@echo " make reset-db - Drop database and apply all migrations"
|
||||
@@ -28,8 +41,10 @@ help:
|
||||
@echo " make clean-slate - Stop containers AND delete volumes (DATA LOSS!)"
|
||||
@echo ""
|
||||
@echo "Subdirectory commands:"
|
||||
@echo " cd backend && make help - Backend-specific commands"
|
||||
@echo " cd frontend && npm run - Frontend-specific commands"
|
||||
@echo " cd backend && make help - Backend-specific commands"
|
||||
@echo " cd mcp-servers/llm-gateway && make - LLM Gateway commands"
|
||||
@echo " cd mcp-servers/knowledge-base && make - Knowledge Base commands"
|
||||
@echo " cd frontend && npm run - Frontend-specific commands"
|
||||
|
||||
# ============================================================================
|
||||
# Development
|
||||
@@ -99,3 +114,72 @@ clean:
|
||||
# WARNING! THIS REMOVES CONTAINERS AND VOLUMES AS WELL - DO NOT USE THIS UNLESS YOU WANT TO START OVER WITH DATA AND ALL
clean-slate:
	docker compose -f docker-compose.dev.yml down -v --remove-orphans

# ============================================================================
# Testing
# ============================================================================

# Aggregate target: backend unit tests plus both MCP server suites.
test: test-backend test-mcp
	@echo ""
	@echo "All tests passed!"

# IS_TEST=True switches the backend into test configuration.
test-backend:
	@echo "Running backend tests..."
	@cd backend && IS_TEST=True uv run pytest tests/ -v

test-mcp:
	@echo "Running MCP server tests..."
	@echo ""
	@echo "=== LLM Gateway ==="
	@cd mcp-servers/llm-gateway && uv run pytest tests/ -v
	@echo ""
	@echo "=== Knowledge Base ==="
	@cd mcp-servers/knowledge-base && uv run pytest tests/ -v

test-frontend:
	@echo "Running frontend tests..."
	@cd frontend && npm test

test-all: test test-frontend
	@echo ""
	@echo "All tests (backend + MCP + frontend) passed!"

# Same suites as `test` but with coverage reporting per component.
test-cov:
	@echo "Running all tests with coverage..."
	@echo ""
	@echo "=== Backend Coverage ==="
	@cd backend && IS_TEST=True uv run pytest tests/ -v --cov=app --cov-report=term-missing
	@echo ""
	@echo "=== LLM Gateway Coverage ==="
	@cd mcp-servers/llm-gateway && uv run pytest tests/ -v --cov=. --cov-report=term-missing
	@echo ""
	@echo "=== Knowledge Base Coverage ==="
	@cd mcp-servers/knowledge-base && uv run pytest tests/ -v --cov=. --cov-report=term-missing

# Integration tests talk to the live stack; `make dev` must be running first.
test-integration:
	@echo "Running MCP integration tests..."
	@echo "Note: Requires running stack (make dev first)"
	@cd backend && RUN_INTEGRATION_TESTS=true IS_TEST=True uv run pytest tests/integration/ -v

# ============================================================================
# Validation (lint + type-check + test)
# ============================================================================

# Delegates to each component's own `make validate`.
validate:
	@echo "Validating backend..."
	@cd backend && make validate
	@echo ""
	@echo "Validating LLM Gateway..."
	@cd mcp-servers/llm-gateway && make validate
	@echo ""
	@echo "Validating Knowledge Base..."
	@cd mcp-servers/knowledge-base && make validate
	@echo ""
	@echo "All validations passed!"

validate-all: validate
	@echo ""
	@echo "Validating frontend..."
	@cd frontend && npm run validate
	@echo ""
	@echo "Full validation passed!"
||||
@@ -1,4 +1,4 @@
|
||||
.PHONY: help lint lint-fix format format-check type-check test test-cov validate clean install-dev sync check-docker install-e2e test-e2e test-e2e-schema test-all
|
||||
.PHONY: help lint lint-fix format format-check type-check test test-cov validate clean install-dev sync check-docker install-e2e test-e2e test-e2e-schema test-all test-integration
|
||||
|
||||
# Default target
|
||||
help:
|
||||
@@ -22,6 +22,7 @@ help:
|
||||
@echo " make test-cov - Run pytest with coverage report"
|
||||
@echo " make test-e2e - Run E2E tests (PostgreSQL, requires Docker)"
|
||||
@echo " make test-e2e-schema - Run Schemathesis API schema tests"
|
||||
@echo " make test-integration - Run MCP integration tests (requires running stack)"
|
||||
@echo " make test-all - Run all tests (unit + E2E)"
|
||||
@echo " make check-docker - Check if Docker is available"
|
||||
@echo ""
|
||||
@@ -82,6 +83,15 @@ test-cov:
|
||||
@IS_TEST=True PYTHONPATH=. uv run pytest --cov=app --cov-report=term-missing --cov-report=html -n 16
|
||||
@echo "📊 Coverage report generated in htmlcov/index.html"
|
||||
|
||||
# ============================================================================
|
||||
# Integration Testing (requires running stack: make dev)
|
||||
# ============================================================================
|
||||
|
||||
test-integration:
|
||||
@echo "🧪 Running MCP integration tests..."
|
||||
@echo "Note: Requires running stack (make dev from project root)"
|
||||
@RUN_INTEGRATION_TESTS=true IS_TEST=True PYTHONPATH=. uv run pytest tests/integration/ -v
|
||||
|
||||
# ============================================================================
|
||||
# E2E Testing (requires Docker)
|
||||
# ============================================================================
|
||||
|
||||
@@ -5,6 +5,7 @@ from app.api.routes import (
|
||||
agent_types,
|
||||
agents,
|
||||
auth,
|
||||
context,
|
||||
events,
|
||||
issues,
|
||||
mcp,
|
||||
@@ -35,6 +36,9 @@ api_router.include_router(events.router, tags=["Events"])
|
||||
# MCP (Model Context Protocol) router
|
||||
api_router.include_router(mcp.router, prefix="/mcp", tags=["MCP"])
|
||||
|
||||
# Context Management Engine router
|
||||
api_router.include_router(context.router, prefix="/context", tags=["Context"])
|
||||
|
||||
# Syndarix domain routers
|
||||
api_router.include_router(projects.router, prefix="/projects", tags=["Projects"])
|
||||
api_router.include_router(
|
||||
|
||||
411
backend/app/api/routes/context.py
Normal file
411
backend/app/api/routes/context.py
Normal file
@@ -0,0 +1,411 @@
|
||||
"""
|
||||
Context Management API Endpoints.
|
||||
|
||||
Provides REST endpoints for context assembly and optimization
|
||||
for LLM requests using the ContextEngine.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.api.dependencies.permissions import require_superuser
|
||||
from app.models.user import User
|
||||
from app.services.context import (
|
||||
AssemblyTimeoutError,
|
||||
BudgetExceededError,
|
||||
ContextEngine,
|
||||
ContextSettings,
|
||||
create_context_engine,
|
||||
get_context_settings,
|
||||
)
|
||||
from app.services.mcp import MCPClientManager, get_mcp_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Singleton Engine Management
|
||||
# ============================================================================
|
||||
|
||||
# Module-level singleton; created lazily on first request.
_context_engine: ContextEngine | None = None


def _get_or_create_engine(
    mcp: MCPClientManager,
    settings: ContextSettings | None = None,
) -> ContextEngine:
    """Get or create the singleton ContextEngine.

    On subsequent calls the existing engine is reused; only its MCP
    manager reference is refreshed. Note that ``settings`` is applied
    only on first creation.
    """
    global _context_engine

    if _context_engine is not None:
        # Already built: keep the MCP manager up to date and reuse it.
        _context_engine.set_mcp_manager(mcp)
        return _context_engine

    engine_settings = settings if settings is not None else get_context_settings()
    _context_engine = create_context_engine(
        mcp_manager=mcp,
        redis=None,  # Optional: add Redis caching later
        settings=engine_settings,
    )
    logger.info("ContextEngine initialized")
    return _context_engine
|
||||
|
||||
|
||||
async def get_context_engine(
    mcp: MCPClientManager = Depends(get_mcp_client),
) -> ContextEngine:
    """FastAPI dependency to get the ContextEngine.

    Delegates to the module-level singleton helper; the MCP manager is
    passed on every request so a reconnected client is picked up.
    """
    return _get_or_create_engine(mcp)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Request/Response Schemas
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class ConversationTurn(BaseModel):
    """A single conversation turn (one message of the prior history)."""

    # Expected values are 'user' or 'assistant'; not validated here.
    role: str = Field(..., description="Role: 'user' or 'assistant'")
    content: str = Field(..., description="Message content")
|
||||
|
||||
|
||||
class ToolResult(BaseModel):
    """A tool execution result to be folded into the assembled context."""

    tool_name: str = Field(..., description="Name of the tool")
    # Content may be plain text or a structured payload.
    content: str | dict[str, Any] = Field(..., description="Tool result content")
    status: str = Field(default="success", description="Execution status")
|
||||
|
||||
|
||||
class AssembleContextRequest(BaseModel):
    """Request to assemble context for an LLM request.

    Identifiers scope the assembly; the optional fields supply the raw
    material (system prompt, task, knowledge query, history, tool results)
    and the knobs (compression, caching, token cap).
    """

    project_id: str = Field(..., description="Project identifier")
    agent_id: str = Field(..., description="Agent identifier")
    query: str = Field(..., description="User's query or current request")
    model: str = Field(
        default="claude-3-sonnet",
        description="Target model name",
    )
    max_tokens: int | None = Field(
        None,
        description="Maximum context tokens (uses model default if None)",
    )
    system_prompt: str | None = Field(
        None,
        description="System prompt/instructions",
    )
    task_description: str | None = Field(
        None,
        description="Current task description",
    )
    knowledge_query: str | None = Field(
        None,
        description="Query for knowledge base search",
    )
    # Bounded to keep knowledge retrieval costs predictable.
    knowledge_limit: int = Field(
        default=10,
        ge=1,
        le=50,
        description="Max number of knowledge results",
    )
    conversation_history: list[ConversationTurn] | None = Field(
        None,
        description="Previous conversation turns",
    )
    tool_results: list[ToolResult] | None = Field(
        None,
        description="Tool execution results to include",
    )
    compress: bool = Field(
        default=True,
        description="Whether to apply compression",
    )
    use_cache: bool = Field(
        default=True,
        description="Whether to use caching",
    )
|
||||
|
||||
|
||||
class AssembledContextResponse(BaseModel):
    """Response containing assembled context."""

    content: str = Field(..., description="Assembled context content")
    total_tokens: int = Field(..., description="Total token count")
    context_count: int = Field(..., description="Number of context items included")
    compressed: bool = Field(..., description="Whether compression was applied")
    budget_used_percent: float = Field(
        ...,
        description="Percentage of token budget used",
    )
    # Free-form extras (model, timing, exclusion counts, etc.).
    metadata: dict[str, Any] = Field(
        default_factory=dict,
        description="Additional metadata",
    )
|
||||
|
||||
|
||||
class TokenCountRequest(BaseModel):
    """Request to count tokens in content."""

    content: str = Field(..., description="Content to count tokens in")
    # When None, a model-agnostic tokenization is used downstream.
    model: str | None = Field(
        None,
        description="Model for model-specific tokenization",
    )
|
||||
|
||||
|
||||
class TokenCountResponse(BaseModel):
    """Response containing token count."""

    token_count: int = Field(..., description="Number of tokens")
    # Echoes the request's model so callers can correlate results.
    model: str | None = Field(None, description="Model used for counting")
|
||||
|
||||
|
||||
class BudgetInfoResponse(BaseModel):
    """Response containing budget information for a model.

    The per-section fields partition the total token budget across the
    context sections the engine assembles.
    """

    model: str = Field(..., description="Model name")
    total_tokens: int = Field(..., description="Total token budget")
    system_tokens: int = Field(..., description="Tokens reserved for system")
    knowledge_tokens: int = Field(..., description="Tokens for knowledge")
    conversation_tokens: int = Field(..., description="Tokens for conversation")
    tool_tokens: int = Field(..., description="Tokens for tool results")
    response_reserve: int = Field(..., description="Tokens reserved for response")
|
||||
|
||||
|
||||
class ContextEngineStatsResponse(BaseModel):
    """Response containing engine statistics."""

    # Shapes of these dicts are defined by ContextEngine.get_stats().
    cache: dict[str, Any] = Field(..., description="Cache statistics")
    settings: dict[str, Any] = Field(..., description="Current settings")
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
    """Health check response."""

    status: str = Field(..., description="Health status")
    mcp_connected: bool = Field(..., description="Whether MCP is connected")
    cache_enabled: bool = Field(..., description="Whether caching is enabled")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Endpoints
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get(
    "/health",
    response_model=HealthResponse,
    summary="Context Engine Health",
    description="Check health status of the context engine.",
)
async def health_check(
    engine: ContextEngine = Depends(get_context_engine),
) -> HealthResponse:
    """Check context engine health.

    No user dependency is declared, so this appears to be callable
    without authentication (suitable for liveness probes) — confirm
    against router-level dependencies.
    NOTE(review): reads the private ``engine._mcp`` attribute; consider
    exposing a public accessor on ContextEngine instead.
    """
    stats = await engine.get_stats()
    return HealthResponse(
        status="healthy",
        mcp_connected=engine._mcp is not None,
        cache_enabled=stats.get("settings", {}).get("cache_enabled", False),
    )
|
||||
|
||||
|
||||
@router.post(
    "/assemble",
    response_model=AssembledContextResponse,
    summary="Assemble Context",
    description="Assemble optimized context for an LLM request.",
)
async def assemble_context(
    request: AssembleContextRequest,
    current_user: User = Depends(require_superuser),
    engine: ContextEngine = Depends(get_context_engine),
) -> AssembledContextResponse:
    """
    Assemble optimized context for an LLM request.

    This endpoint gathers context from various sources, scores and ranks them,
    compresses if needed, and formats for the target model.

    Raises:
        HTTPException: 504 on assembly timeout, 413 when the token budget
            is exceeded, 500 on any other assembly failure.
    """
    logger.info(
        "Context assembly for project=%s agent=%s by user=%s",
        request.project_id,
        request.agent_id,
        current_user.id,
    )

    # Convert conversation history to the plain dict format the engine expects
    conversation_history = None
    if request.conversation_history:
        conversation_history = [
            {"role": turn.role, "content": turn.content}
            for turn in request.conversation_history
        ]

    # Convert tool results to dict format
    tool_results = None
    if request.tool_results:
        tool_results = [
            {
                "tool_name": tr.tool_name,
                "content": tr.content,
                "status": tr.status,
            }
            for tr in request.tool_results
        ]

    try:
        result = await engine.assemble_context(
            project_id=request.project_id,
            agent_id=request.agent_id,
            query=request.query,
            model=request.model,
            max_tokens=request.max_tokens,
            system_prompt=request.system_prompt,
            task_description=request.task_description,
            knowledge_query=request.knowledge_query,
            knowledge_limit=request.knowledge_limit,
            conversation_history=conversation_history,
            tool_results=tool_results,
            compress=request.compress,
            use_cache=request.use_cache,
        )

        # Calculate budget usage percentage.
        budget = await engine.get_budget_for_model(request.model, request.max_tokens)
        # Guard against a zero-total budget so a misconfigured model cannot
        # surface as an opaque 500 via ZeroDivisionError.
        if budget.total:
            budget_used_percent = (result.total_tokens / budget.total) * 100
        else:
            budget_used_percent = 0.0

        # Check if compression was applied (from metadata if available)
        was_compressed = result.metadata.get("compressed_contexts", 0) > 0

        return AssembledContextResponse(
            content=result.content,
            total_tokens=result.total_tokens,
            context_count=result.context_count,
            compressed=was_compressed,
            budget_used_percent=round(budget_used_percent, 2),
            metadata={
                "model": request.model,
                "query": request.query,
                "knowledge_included": bool(request.knowledge_query),
                "conversation_turns": len(request.conversation_history or []),
                "excluded_count": result.excluded_count,
                "assembly_time_ms": result.assembly_time_ms,
            },
        )

    except AssemblyTimeoutError as e:
        raise HTTPException(
            status_code=status.HTTP_504_GATEWAY_TIMEOUT,
            detail=f"Context assembly timed out: {e}",
        ) from e
    except BudgetExceededError as e:
        raise HTTPException(
            status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
            detail=f"Token budget exceeded: {e}",
        ) from e
    except Exception as e:
        # Boundary catch-all: log full traceback, return a 500 to the client.
        logger.exception("Context assembly failed")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Context assembly failed: {e}",
        ) from e
|
||||
|
||||
|
||||
@router.post(
    "/count-tokens",
    response_model=TokenCountResponse,
    summary="Count Tokens",
    description="Count tokens in content using the LLM Gateway.",
)
async def count_tokens(
    request: TokenCountRequest,
    engine: ContextEngine = Depends(get_context_engine),
) -> TokenCountResponse:
    """Count tokens in content.

    Raises:
        HTTPException: 500 if the underlying token counter fails.
    """
    try:
        count = await engine.count_tokens(
            content=request.content,
            model=request.model,
        )
        return TokenCountResponse(
            token_count=count,
            model=request.model,
        )
    except Exception as e:
        # Lazy %-style args match the module's logging convention and avoid
        # formatting cost when the warning level is disabled.
        logger.warning("Token counting failed: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Token counting failed: {e}",
        ) from e
|
||||
|
||||
|
||||
@router.get(
    "/budget/{model}",
    response_model=BudgetInfoResponse,
    summary="Get Token Budget",
    description="Get token budget allocation for a specific model.",
)
async def get_budget(
    model: str,
    max_tokens: Annotated[int | None, Query(description="Custom max tokens")] = None,
    engine: ContextEngine = Depends(get_context_engine),
) -> BudgetInfoResponse:
    """Return the per-section token budget the engine allocates for a model."""
    allocation = await engine.get_budget_for_model(model, max_tokens)
    response = BudgetInfoResponse(
        model=model,
        total_tokens=allocation.total,
        system_tokens=allocation.system,
        knowledge_tokens=allocation.knowledge,
        conversation_tokens=allocation.conversation,
        tool_tokens=allocation.tools,
        response_reserve=allocation.response_reserve,
    )
    return response
|
||||
|
||||
|
||||
@router.get(
    "/stats",
    response_model=ContextEngineStatsResponse,
    summary="Engine Statistics",
    description="Get context engine statistics and configuration.",
)
async def get_stats(
    current_user: User = Depends(require_superuser),
    engine: ContextEngine = Depends(get_context_engine),
) -> ContextEngineStatsResponse:
    """Return cache statistics and the engine's current settings (superuser only)."""
    snapshot = await engine.get_stats()
    cache_info = snapshot.get("cache", {})
    settings_info = snapshot.get("settings", {})
    return ContextEngineStatsResponse(cache=cache_info, settings=settings_info)
|
||||
|
||||
|
||||
@router.post(
    "/cache/invalidate",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Invalidate Cache (Admin Only)",
    description="Invalidate context cache entries.",
)
async def invalidate_cache(
    project_id: Annotated[
        str | None, Query(description="Project to invalidate")
    ] = None,
    pattern: Annotated[str | None, Query(description="Pattern to match")] = None,
    current_user: User = Depends(require_superuser),
    engine: ContextEngine = Depends(get_context_engine),
) -> None:
    """Invalidate cache entries.

    Superuser-only. Returns 204 with no body. The request is logged before
    the cache call so the attempt is recorded even if invalidation fails.
    """
    # Audit log: who requested invalidation and with what scope.
    logger.info(
        "Cache invalidation by user %s: project=%s pattern=%s",
        current_user.id,
        project_id,
        pattern,
    )
    await engine.invalidate_cache(project_id=project_id, pattern=pattern)
|
||||
@@ -90,16 +90,19 @@ class ClaudeAdapter(ModelAdapter):
|
||||
elif context_type == ContextType.TOOL:
|
||||
return self._format_tool(contexts)
|
||||
|
||||
return "\n".join(c.content for c in contexts)
|
||||
# Fallback for any unhandled context types - still escape content
|
||||
# to prevent XML injection if new types are added without updating adapter
|
||||
return "\n".join(self._escape_xml_content(c.content) for c in contexts)
|
||||
|
||||
def _format_system(self, contexts: list[BaseContext]) -> str:
|
||||
"""Format system contexts."""
|
||||
content = "\n\n".join(c.content for c in contexts)
|
||||
# System prompts are typically admin-controlled, but escape for safety
|
||||
content = "\n\n".join(self._escape_xml_content(c.content) for c in contexts)
|
||||
return f"<system_instructions>\n{content}\n</system_instructions>"
|
||||
|
||||
def _format_task(self, contexts: list[BaseContext]) -> str:
|
||||
"""Format task contexts."""
|
||||
content = "\n\n".join(c.content for c in contexts)
|
||||
content = "\n\n".join(self._escape_xml_content(c.content) for c in contexts)
|
||||
return f"<current_task>\n{content}\n</current_task>"
|
||||
|
||||
def _format_knowledge(self, contexts: list[BaseContext]) -> str:
|
||||
@@ -107,16 +110,22 @@ class ClaudeAdapter(ModelAdapter):
|
||||
Format knowledge contexts as structured documents.
|
||||
|
||||
Each knowledge context becomes a document with source attribution.
|
||||
All content is XML-escaped to prevent injection attacks.
|
||||
"""
|
||||
parts = ["<reference_documents>"]
|
||||
|
||||
for ctx in contexts:
|
||||
source = self._escape_xml(ctx.source)
|
||||
content = ctx.content
|
||||
# Escape content to prevent XML injection
|
||||
content = self._escape_xml_content(ctx.content)
|
||||
score = ctx.metadata.get("score", ctx.metadata.get("relevance_score", ""))
|
||||
|
||||
if score:
|
||||
parts.append(f'<document source="{source}" relevance="{score}">')
|
||||
# Escape score to prevent XML injection via metadata
|
||||
escaped_score = self._escape_xml(str(score))
|
||||
parts.append(
|
||||
f'<document source="{source}" relevance="{escaped_score}">'
|
||||
)
|
||||
else:
|
||||
parts.append(f'<document source="{source}">')
|
||||
|
||||
@@ -131,13 +140,16 @@ class ClaudeAdapter(ModelAdapter):
|
||||
Format conversation contexts as message history.
|
||||
|
||||
Uses role-based message tags for clear turn delineation.
|
||||
All content is XML-escaped to prevent prompt injection.
|
||||
"""
|
||||
parts = ["<conversation_history>"]
|
||||
|
||||
for ctx in contexts:
|
||||
role = ctx.metadata.get("role", "user")
|
||||
role = self._escape_xml(ctx.metadata.get("role", "user"))
|
||||
# Escape content to prevent prompt injection via fake XML tags
|
||||
content = self._escape_xml_content(ctx.content)
|
||||
parts.append(f'<message role="{role}">')
|
||||
parts.append(ctx.content)
|
||||
parts.append(content)
|
||||
parts.append("</message>")
|
||||
|
||||
parts.append("</conversation_history>")
|
||||
@@ -148,19 +160,23 @@ class ClaudeAdapter(ModelAdapter):
|
||||
Format tool contexts as tool results.
|
||||
|
||||
Each tool result is wrapped with the tool name.
|
||||
All content is XML-escaped to prevent injection.
|
||||
"""
|
||||
parts = ["<tool_results>"]
|
||||
|
||||
for ctx in contexts:
|
||||
tool_name = ctx.metadata.get("tool_name", "unknown")
|
||||
tool_name = self._escape_xml(ctx.metadata.get("tool_name", "unknown"))
|
||||
status = ctx.metadata.get("status", "")
|
||||
|
||||
if status:
|
||||
parts.append(f'<tool_result name="{tool_name}" status="{status}">')
|
||||
parts.append(
|
||||
f'<tool_result name="{tool_name}" status="{self._escape_xml(status)}">'
|
||||
)
|
||||
else:
|
||||
parts.append(f'<tool_result name="{tool_name}">')
|
||||
|
||||
parts.append(ctx.content)
|
||||
# Escape content to prevent injection
|
||||
parts.append(self._escape_xml_content(ctx.content))
|
||||
parts.append("</tool_result>")
|
||||
|
||||
parts.append("</tool_results>")
|
||||
@@ -176,3 +192,21 @@ class ClaudeAdapter(ModelAdapter):
|
||||
.replace('"', """)
|
||||
.replace("'", "'")
|
||||
)
|
||||
|
||||
@staticmethod
def _escape_xml_content(text: str) -> str:
    """
    Escape XML special characters in element content.

    This prevents XML injection attacks where malicious content
    could break out of XML tags or inject fake tags for prompt injection.

    Only escapes &, <, > since quotes don't need escaping in content.
    Ampersand is escaped first so already-escaped entities are not
    double-mangled in the wrong order.

    Args:
        text: Content text to escape

    Returns:
        XML-safe content string
    """
    return text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
|
||||
|
||||
@@ -12,6 +12,7 @@ from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from ..adapters import get_adapter
|
||||
from ..budget import BudgetAllocator, TokenBudget, TokenCalculator
|
||||
from ..compression.truncation import ContextCompressor
|
||||
from ..config import ContextSettings, get_context_settings
|
||||
@@ -156,20 +157,42 @@ class ContextPipeline:
|
||||
else:
|
||||
budget = self._allocator.create_budget_for_model(model)
|
||||
|
||||
# 1. Count tokens for all contexts
|
||||
await self._ensure_token_counts(contexts, model)
|
||||
# 1. Count tokens for all contexts (with timeout enforcement)
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._ensure_token_counts(contexts, model),
|
||||
timeout=self._remaining_timeout(start, timeout),
|
||||
)
|
||||
except TimeoutError:
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
raise AssemblyTimeoutError(
|
||||
message="Context assembly timed out during token counting",
|
||||
elapsed_ms=elapsed_ms,
|
||||
timeout_ms=timeout,
|
||||
)
|
||||
|
||||
# Check timeout
|
||||
# Check timeout (handles edge case where operation finished just at limit)
|
||||
self._check_timeout(start, timeout, "token counting")
|
||||
|
||||
# 2. Score and rank contexts
|
||||
# 2. Score and rank contexts (with timeout enforcement)
|
||||
scoring_start = time.perf_counter()
|
||||
ranking_result = await self._ranker.rank(
|
||||
contexts=contexts,
|
||||
query=query,
|
||||
budget=budget,
|
||||
model=model,
|
||||
)
|
||||
try:
|
||||
ranking_result = await asyncio.wait_for(
|
||||
self._ranker.rank(
|
||||
contexts=contexts,
|
||||
query=query,
|
||||
budget=budget,
|
||||
model=model,
|
||||
),
|
||||
timeout=self._remaining_timeout(start, timeout),
|
||||
)
|
||||
except TimeoutError:
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
raise AssemblyTimeoutError(
|
||||
message="Context assembly timed out during scoring/ranking",
|
||||
elapsed_ms=elapsed_ms,
|
||||
timeout_ms=timeout,
|
||||
)
|
||||
metrics.scoring_time_ms = (time.perf_counter() - scoring_start) * 1000
|
||||
|
||||
selected_contexts = ranking_result.selected_contexts
|
||||
@@ -179,12 +202,23 @@ class ContextPipeline:
|
||||
# Check timeout
|
||||
self._check_timeout(start, timeout, "scoring")
|
||||
|
||||
# 3. Compress if needed and enabled
|
||||
# 3. Compress if needed and enabled (with timeout enforcement)
|
||||
if compress and self._needs_compression(selected_contexts, budget):
|
||||
compression_start = time.perf_counter()
|
||||
selected_contexts = await self._compressor.compress_contexts(
|
||||
selected_contexts, budget, model
|
||||
)
|
||||
try:
|
||||
selected_contexts = await asyncio.wait_for(
|
||||
self._compressor.compress_contexts(
|
||||
selected_contexts, budget, model
|
||||
),
|
||||
timeout=self._remaining_timeout(start, timeout),
|
||||
)
|
||||
except TimeoutError:
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
raise AssemblyTimeoutError(
|
||||
message="Context assembly timed out during compression",
|
||||
elapsed_ms=elapsed_ms,
|
||||
timeout_ms=timeout,
|
||||
)
|
||||
metrics.compression_time_ms = (
|
||||
time.perf_counter() - compression_start
|
||||
) * 1000
|
||||
@@ -280,129 +314,18 @@ class ContextPipeline:
|
||||
"""
|
||||
Format contexts for the target model.
|
||||
|
||||
Groups contexts by type and applies model-specific formatting.
|
||||
Uses model-specific adapters (ClaudeAdapter, OpenAIAdapter, etc.)
|
||||
to format contexts optimally for each model family.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to format
|
||||
model: Target model name
|
||||
|
||||
Returns:
|
||||
Formatted context string
|
||||
"""
|
||||
# Group by type
|
||||
by_type: dict[ContextType, list[BaseContext]] = {}
|
||||
for context in contexts:
|
||||
ct = context.get_type()
|
||||
if ct not in by_type:
|
||||
by_type[ct] = []
|
||||
by_type[ct].append(context)
|
||||
|
||||
# Order types: System -> Task -> Knowledge -> Conversation -> Tool
|
||||
type_order = [
|
||||
ContextType.SYSTEM,
|
||||
ContextType.TASK,
|
||||
ContextType.KNOWLEDGE,
|
||||
ContextType.CONVERSATION,
|
||||
ContextType.TOOL,
|
||||
]
|
||||
|
||||
parts: list[str] = []
|
||||
for ct in type_order:
|
||||
if ct in by_type:
|
||||
formatted = self._format_type(by_type[ct], ct, model)
|
||||
if formatted:
|
||||
parts.append(formatted)
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
||||
def _format_type(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
context_type: ContextType,
|
||||
model: str,
|
||||
) -> str:
|
||||
"""Format contexts of a specific type."""
|
||||
if not contexts:
|
||||
return ""
|
||||
|
||||
# Check if model prefers XML tags (Claude)
|
||||
use_xml = "claude" in model.lower()
|
||||
|
||||
if context_type == ContextType.SYSTEM:
|
||||
return self._format_system(contexts, use_xml)
|
||||
elif context_type == ContextType.TASK:
|
||||
return self._format_task(contexts, use_xml)
|
||||
elif context_type == ContextType.KNOWLEDGE:
|
||||
return self._format_knowledge(contexts, use_xml)
|
||||
elif context_type == ContextType.CONVERSATION:
|
||||
return self._format_conversation(contexts, use_xml)
|
||||
elif context_type == ContextType.TOOL:
|
||||
return self._format_tool(contexts, use_xml)
|
||||
|
||||
return "\n".join(c.content for c in contexts)
|
||||
|
||||
def _format_system(self, contexts: list[BaseContext], use_xml: bool) -> str:
|
||||
"""Format system contexts."""
|
||||
content = "\n\n".join(c.content for c in contexts)
|
||||
if use_xml:
|
||||
return f"<system_instructions>\n{content}\n</system_instructions>"
|
||||
return content
|
||||
|
||||
def _format_task(self, contexts: list[BaseContext], use_xml: bool) -> str:
|
||||
"""Format task contexts."""
|
||||
content = "\n\n".join(c.content for c in contexts)
|
||||
if use_xml:
|
||||
return f"<current_task>\n{content}\n</current_task>"
|
||||
return f"## Current Task\n\n{content}"
|
||||
|
||||
def _format_knowledge(self, contexts: list[BaseContext], use_xml: bool) -> str:
|
||||
"""Format knowledge contexts."""
|
||||
if use_xml:
|
||||
parts = ["<reference_documents>"]
|
||||
for ctx in contexts:
|
||||
parts.append(f'<document source="{ctx.source}">')
|
||||
parts.append(ctx.content)
|
||||
parts.append("</document>")
|
||||
parts.append("</reference_documents>")
|
||||
return "\n".join(parts)
|
||||
else:
|
||||
parts = ["## Reference Documents\n"]
|
||||
for ctx in contexts:
|
||||
parts.append(f"### Source: {ctx.source}\n")
|
||||
parts.append(ctx.content)
|
||||
parts.append("")
|
||||
return "\n".join(parts)
|
||||
|
||||
def _format_conversation(self, contexts: list[BaseContext], use_xml: bool) -> str:
|
||||
"""Format conversation contexts."""
|
||||
if use_xml:
|
||||
parts = ["<conversation_history>"]
|
||||
for ctx in contexts:
|
||||
role = ctx.metadata.get("role", "user")
|
||||
parts.append(f'<message role="{role}">')
|
||||
parts.append(ctx.content)
|
||||
parts.append("</message>")
|
||||
parts.append("</conversation_history>")
|
||||
return "\n".join(parts)
|
||||
else:
|
||||
parts = []
|
||||
for ctx in contexts:
|
||||
role = ctx.metadata.get("role", "user")
|
||||
parts.append(f"**{role.upper()}**: {ctx.content}")
|
||||
return "\n\n".join(parts)
|
||||
|
||||
def _format_tool(self, contexts: list[BaseContext], use_xml: bool) -> str:
|
||||
"""Format tool contexts."""
|
||||
if use_xml:
|
||||
parts = ["<tool_results>"]
|
||||
for ctx in contexts:
|
||||
tool_name = ctx.metadata.get("tool_name", "unknown")
|
||||
parts.append(f'<tool_result name="{tool_name}">')
|
||||
parts.append(ctx.content)
|
||||
parts.append("</tool_result>")
|
||||
parts.append("</tool_results>")
|
||||
return "\n".join(parts)
|
||||
else:
|
||||
parts = ["## Recent Tool Results\n"]
|
||||
for ctx in contexts:
|
||||
tool_name = ctx.metadata.get("tool_name", "unknown")
|
||||
parts.append(f"### Tool: {tool_name}\n")
|
||||
parts.append(f"```\n{ctx.content}\n```")
|
||||
parts.append("")
|
||||
return "\n".join(parts)
|
||||
adapter = get_adapter(model)
|
||||
return adapter.format(contexts)
|
||||
|
||||
def _check_timeout(
|
||||
self,
|
||||
@@ -412,9 +335,28 @@ class ContextPipeline:
|
||||
) -> None:
|
||||
"""Check if timeout exceeded and raise if so."""
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
if elapsed_ms > timeout_ms:
|
||||
if elapsed_ms >= timeout_ms:
|
||||
raise AssemblyTimeoutError(
|
||||
message=f"Context assembly timed out during {phase}",
|
||||
elapsed_ms=elapsed_ms,
|
||||
timeout_ms=timeout_ms,
|
||||
)
|
||||
|
||||
def _remaining_timeout(self, start: float, timeout_ms: int) -> float:
|
||||
"""
|
||||
Calculate remaining timeout in seconds for asyncio.wait_for.
|
||||
|
||||
Returns at least a small positive value to avoid immediate timeout
|
||||
edge cases with wait_for.
|
||||
|
||||
Args:
|
||||
start: Start time from time.perf_counter()
|
||||
timeout_ms: Total timeout in milliseconds
|
||||
|
||||
Returns:
|
||||
Remaining timeout in seconds (minimum 0.001)
|
||||
"""
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
remaining_ms = timeout_ms - elapsed_ms
|
||||
# Return at least 1ms to avoid zero/negative timeout edge cases
|
||||
return max(remaining_ms / 1000.0, 0.001)
|
||||
|
||||
@@ -293,14 +293,18 @@ class BudgetAllocator:
|
||||
if isinstance(context_type, ContextType):
|
||||
context_type = context_type.value
|
||||
|
||||
# Calculate adjustment (limited by buffer)
|
||||
# Calculate adjustment (limited by buffer for increases, by current allocation for decreases)
|
||||
if adjustment > 0:
|
||||
# Taking from buffer
|
||||
# Taking from buffer - limited by available buffer
|
||||
actual_adjustment = min(adjustment, budget.buffer)
|
||||
budget.buffer -= actual_adjustment
|
||||
else:
|
||||
# Returning to buffer
|
||||
actual_adjustment = adjustment
|
||||
# Returning to buffer - limited by current allocation of target type
|
||||
current_allocation = budget.get_allocation(context_type)
|
||||
# Can't return more than current allocation
|
||||
actual_adjustment = max(adjustment, -current_allocation)
|
||||
# Add returned tokens back to buffer (adjustment is negative, so subtract)
|
||||
budget.buffer -= actual_adjustment
|
||||
|
||||
# Apply to target type
|
||||
if context_type == "system":
|
||||
|
||||
@@ -95,19 +95,28 @@ class ContextCache:
|
||||
contexts: list[BaseContext],
|
||||
query: str,
|
||||
model: str,
|
||||
project_id: str | None = None,
|
||||
agent_id: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Compute a fingerprint for a context assembly request.
|
||||
|
||||
The fingerprint is based on:
|
||||
- Project and agent IDs (for tenant isolation)
|
||||
- Context content hash and metadata (not full content for performance)
|
||||
- Query string
|
||||
- Target model
|
||||
|
||||
SECURITY: project_id and agent_id MUST be included to prevent
|
||||
cross-tenant cache pollution. Without these, one tenant could
|
||||
receive cached contexts from another tenant with the same query.
|
||||
|
||||
Args:
|
||||
contexts: List of contexts
|
||||
query: Query string
|
||||
model: Model name
|
||||
project_id: Project ID for tenant isolation
|
||||
agent_id: Agent ID for tenant isolation
|
||||
|
||||
Returns:
|
||||
32-character hex fingerprint
|
||||
@@ -128,6 +137,9 @@ class ContextCache:
|
||||
)
|
||||
|
||||
data = {
|
||||
# CRITICAL: Include tenant identifiers for cache isolation
|
||||
"project_id": project_id or "",
|
||||
"agent_id": agent_id or "",
|
||||
"contexts": context_data,
|
||||
"query": query,
|
||||
"model": model,
|
||||
|
||||
@@ -19,6 +19,40 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _estimate_tokens(text: str, model: str | None = None) -> int:
|
||||
"""
|
||||
Estimate token count using model-specific character ratios.
|
||||
|
||||
Module-level function for reuse across classes. Uses the same ratios
|
||||
as TokenCalculator for consistency.
|
||||
|
||||
Args:
|
||||
text: Text to estimate tokens for
|
||||
model: Optional model name for model-specific ratios
|
||||
|
||||
Returns:
|
||||
Estimated token count (minimum 1)
|
||||
"""
|
||||
# Model-specific character ratios (chars per token)
|
||||
model_ratios = {
|
||||
"claude": 3.5,
|
||||
"gpt-4": 4.0,
|
||||
"gpt-3.5": 4.0,
|
||||
"gemini": 4.0,
|
||||
}
|
||||
default_ratio = 4.0
|
||||
|
||||
ratio = default_ratio
|
||||
if model:
|
||||
model_lower = model.lower()
|
||||
for model_prefix, model_ratio in model_ratios.items():
|
||||
if model_prefix in model_lower:
|
||||
ratio = model_ratio
|
||||
break
|
||||
|
||||
return max(1, int(len(text) / ratio))
|
||||
|
||||
|
||||
@dataclass
|
||||
class TruncationResult:
|
||||
"""Result of truncation operation."""
|
||||
@@ -284,8 +318,8 @@ class TruncationStrategy:
|
||||
if self._calculator is not None:
|
||||
return await self._calculator.count_tokens(text, model)
|
||||
|
||||
# Fallback estimation
|
||||
return max(1, len(text) // 4)
|
||||
# Fallback estimation with model-specific ratios
|
||||
return _estimate_tokens(text, model)
|
||||
|
||||
|
||||
class ContextCompressor:
|
||||
@@ -415,4 +449,5 @@ class ContextCompressor:
|
||||
"""Count tokens using calculator or estimation."""
|
||||
if self._calculator is not None:
|
||||
return await self._calculator.count_tokens(text, model)
|
||||
return max(1, len(text) // 4)
|
||||
# Use model-specific estimation for consistency
|
||||
return _estimate_tokens(text, model)
|
||||
|
||||
@@ -149,10 +149,11 @@ class ContextSettings(BaseSettings):
|
||||
|
||||
# Performance settings
|
||||
max_assembly_time_ms: int = Field(
|
||||
default=100,
|
||||
default=2000,
|
||||
ge=10,
|
||||
le=5000,
|
||||
description="Maximum time for context assembly in milliseconds",
|
||||
le=30000,
|
||||
description="Maximum time for context assembly in milliseconds. "
|
||||
"Should be high enough to accommodate MCP calls for knowledge retrieval.",
|
||||
)
|
||||
parallel_scoring: bool = Field(
|
||||
default=True,
|
||||
|
||||
@@ -212,7 +212,10 @@ class ContextEngine:
|
||||
# Check cache if enabled
|
||||
fingerprint: str | None = None
|
||||
if use_cache and self._cache.is_enabled:
|
||||
fingerprint = self._cache.compute_fingerprint(contexts, query, model)
|
||||
# Include project_id and agent_id for tenant isolation
|
||||
fingerprint = self._cache.compute_fingerprint(
|
||||
contexts, query, model, project_id=project_id, agent_id=agent_id
|
||||
)
|
||||
cached = await self._cache.get_assembled(fingerprint)
|
||||
if cached:
|
||||
logger.debug(f"Cache hit for context assembly: {fingerprint}")
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
from ..budget import TokenBudget, TokenCalculator
|
||||
from ..config import ContextSettings, get_context_settings
|
||||
from ..exceptions import BudgetExceededError
|
||||
from ..scoring.composite import CompositeScorer, ScoredContext
|
||||
from ..types import BaseContext, ContextPriority
|
||||
|
||||
@@ -127,9 +128,25 @@ class ContextRanker:
|
||||
excluded: list[ScoredContext] = []
|
||||
total_tokens = 0
|
||||
|
||||
# Calculate the usable budget (total minus reserved portions)
|
||||
usable_budget = budget.total - budget.response_reserve - budget.buffer
|
||||
|
||||
# Guard against invalid budget configuration
|
||||
if usable_budget <= 0:
|
||||
raise BudgetExceededError(
|
||||
message=(
|
||||
f"Invalid budget configuration: no usable tokens available. "
|
||||
f"total={budget.total}, response_reserve={budget.response_reserve}, "
|
||||
f"buffer={budget.buffer}"
|
||||
),
|
||||
allocated=budget.total,
|
||||
requested=0,
|
||||
context_type="CONFIGURATION_ERROR",
|
||||
)
|
||||
|
||||
# First, try to fit required contexts
|
||||
for sc in required:
|
||||
token_count = sc.context.token_count or 0
|
||||
token_count = self._get_valid_token_count(sc.context)
|
||||
context_type = sc.context.get_type()
|
||||
|
||||
if budget.can_fit(context_type, token_count):
|
||||
@@ -137,7 +154,20 @@ class ContextRanker:
|
||||
selected.append(sc)
|
||||
total_tokens += token_count
|
||||
else:
|
||||
# Force-fit CRITICAL contexts if needed
|
||||
# Force-fit CRITICAL contexts if needed, but check total budget first
|
||||
if total_tokens + token_count > usable_budget:
|
||||
# Even CRITICAL contexts cannot exceed total model context window
|
||||
raise BudgetExceededError(
|
||||
message=(
|
||||
f"CRITICAL contexts exceed total budget. "
|
||||
f"Context '{sc.context.source}' ({token_count} tokens) "
|
||||
f"would exceed usable budget of {usable_budget} tokens."
|
||||
),
|
||||
allocated=usable_budget,
|
||||
requested=total_tokens + token_count,
|
||||
context_type="CRITICAL_OVERFLOW",
|
||||
)
|
||||
|
||||
budget.allocate(context_type, token_count, force=True)
|
||||
selected.append(sc)
|
||||
total_tokens += token_count
|
||||
@@ -148,7 +178,7 @@ class ContextRanker:
|
||||
|
||||
# Then, greedily add optional contexts
|
||||
for sc in optional:
|
||||
token_count = sc.context.token_count or 0
|
||||
token_count = self._get_valid_token_count(sc.context)
|
||||
context_type = sc.context.get_type()
|
||||
|
||||
if budget.can_fit(context_type, token_count):
|
||||
@@ -215,13 +245,43 @@ class ContextRanker:
|
||||
total_tokens = 0
|
||||
|
||||
for sc in scored_contexts:
|
||||
token_count = sc.context.token_count or 0
|
||||
token_count = self._get_valid_token_count(sc.context)
|
||||
if total_tokens + token_count <= max_tokens:
|
||||
selected.append(sc.context)
|
||||
total_tokens += token_count
|
||||
|
||||
return selected
|
||||
|
||||
def _get_valid_token_count(self, context: BaseContext) -> int:
|
||||
"""
|
||||
Get validated token count from a context.
|
||||
|
||||
Ensures token_count is set (not None) and non-negative to prevent
|
||||
budget bypass attacks where:
|
||||
- None would be treated as 0 (allowing huge contexts to slip through)
|
||||
- Negative values would corrupt budget tracking
|
||||
|
||||
Args:
|
||||
context: Context to get token count from
|
||||
|
||||
Returns:
|
||||
Valid non-negative token count
|
||||
|
||||
Raises:
|
||||
ValueError: If token_count is None or negative
|
||||
"""
|
||||
if context.token_count is None:
|
||||
raise ValueError(
|
||||
f"Context '{context.source}' has no token count. "
|
||||
"Ensure _ensure_token_counts() is called before ranking."
|
||||
)
|
||||
if context.token_count < 0:
|
||||
raise ValueError(
|
||||
f"Context '{context.source}' has invalid negative token count: "
|
||||
f"{context.token_count}"
|
||||
)
|
||||
return context.token_count
|
||||
|
||||
async def _ensure_token_counts(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
@@ -266,6 +326,7 @@ class ContextRanker:
|
||||
if type_name not in by_type:
|
||||
by_type[type_name] = {"count": 0, "tokens": 0}
|
||||
by_type[type_name]["count"] += 1
|
||||
# Use validated token count (already validated during ranking)
|
||||
by_type[type_name]["tokens"] += sc.context.token_count or 0
|
||||
|
||||
return by_type
|
||||
|
||||
@@ -6,9 +6,9 @@ Combines multiple scoring strategies with configurable weights.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from ..config import ContextSettings, get_context_settings
|
||||
from ..types import BaseContext
|
||||
@@ -91,11 +91,11 @@ class CompositeScorer:
|
||||
self._priority_scorer = PriorityScorer(weight=self._priority_weight)
|
||||
|
||||
# Per-context locks to prevent race conditions during parallel scoring
|
||||
# Uses WeakValueDictionary so locks are garbage collected when not in use
|
||||
self._context_locks: WeakValueDictionary[str, asyncio.Lock] = (
|
||||
WeakValueDictionary()
|
||||
)
|
||||
# Uses dict with (lock, last_used_time) tuples for cleanup
|
||||
self._context_locks: dict[str, tuple[asyncio.Lock, float]] = {}
|
||||
self._locks_lock = asyncio.Lock() # Lock to protect _context_locks access
|
||||
self._max_locks = 1000 # Maximum locks to keep (prevent memory growth)
|
||||
self._lock_ttl = 60.0 # Seconds before a lock can be cleaned up
|
||||
|
||||
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
|
||||
"""Set MCP manager for semantic scoring."""
|
||||
@@ -141,7 +141,8 @@ class CompositeScorer:
|
||||
Get or create a lock for a specific context.
|
||||
|
||||
Thread-safe access to per-context locks prevents race conditions
|
||||
when the same context is scored concurrently.
|
||||
when the same context is scored concurrently. Includes automatic
|
||||
cleanup of old locks to prevent memory growth.
|
||||
|
||||
Args:
|
||||
context_id: The context ID to get a lock for
|
||||
@@ -149,25 +150,78 @@ class CompositeScorer:
|
||||
Returns:
|
||||
asyncio.Lock for the context
|
||||
"""
|
||||
now = time.time()
|
||||
|
||||
# Fast path: check if lock exists without acquiring main lock
|
||||
if context_id in self._context_locks:
|
||||
lock = self._context_locks.get(context_id)
|
||||
if lock is not None:
|
||||
# NOTE: We only READ here - no writes to avoid race conditions
|
||||
# with cleanup. The timestamp will be updated in the slow path
|
||||
# if the lock is still valid.
|
||||
lock_entry = self._context_locks.get(context_id)
|
||||
if lock_entry is not None:
|
||||
lock, _ = lock_entry
|
||||
# Return the lock but defer timestamp update to avoid race
|
||||
# The lock is still valid; timestamp update is best-effort
|
||||
return lock
|
||||
|
||||
# Slow path: create lock or update timestamp while holding main lock
|
||||
async with self._locks_lock:
|
||||
# Double-check after acquiring lock - entry may have been
|
||||
# created by another coroutine or deleted by cleanup
|
||||
lock_entry = self._context_locks.get(context_id)
|
||||
if lock_entry is not None:
|
||||
lock, _ = lock_entry
|
||||
# Safe to update timestamp here since we hold the lock
|
||||
self._context_locks[context_id] = (lock, now)
|
||||
return lock
|
||||
|
||||
# Slow path: create lock while holding main lock
|
||||
async with self._locks_lock:
|
||||
# Double-check after acquiring lock
|
||||
if context_id in self._context_locks:
|
||||
lock = self._context_locks.get(context_id)
|
||||
if lock is not None:
|
||||
return lock
|
||||
# Cleanup old locks if we have too many
|
||||
if len(self._context_locks) >= self._max_locks:
|
||||
self._cleanup_old_locks(now)
|
||||
|
||||
# Create new lock
|
||||
new_lock = asyncio.Lock()
|
||||
self._context_locks[context_id] = new_lock
|
||||
self._context_locks[context_id] = (new_lock, now)
|
||||
return new_lock
|
||||
|
||||
def _cleanup_old_locks(self, now: float) -> None:
|
||||
"""
|
||||
Remove old locks that haven't been used recently.
|
||||
|
||||
Called while holding _locks_lock. Removes locks older than _lock_ttl,
|
||||
but only if they're not currently held.
|
||||
|
||||
Args:
|
||||
now: Current timestamp for age calculation
|
||||
"""
|
||||
cutoff = now - self._lock_ttl
|
||||
to_remove = []
|
||||
|
||||
for context_id, (lock, last_used) in self._context_locks.items():
|
||||
# Only remove if old AND not currently held
|
||||
if last_used < cutoff and not lock.locked():
|
||||
to_remove.append(context_id)
|
||||
|
||||
# Remove oldest 50% if still over limit after TTL filtering
|
||||
if len(self._context_locks) - len(to_remove) >= self._max_locks:
|
||||
# Sort by last used time and mark oldest for removal
|
||||
sorted_entries = sorted(
|
||||
self._context_locks.items(),
|
||||
key=lambda x: x[1][1], # Sort by last_used time
|
||||
)
|
||||
# Remove oldest 50% that aren't locked
|
||||
target_remove = len(self._context_locks) // 2
|
||||
for context_id, (lock, _) in sorted_entries:
|
||||
if len(to_remove) >= target_remove:
|
||||
break
|
||||
if context_id not in to_remove and not lock.locked():
|
||||
to_remove.append(context_id)
|
||||
|
||||
for context_id in to_remove:
|
||||
del self._context_locks[context_id]
|
||||
|
||||
if to_remove:
|
||||
logger.debug(f"Cleaned up {len(to_remove)} context locks")
|
||||
|
||||
async def score(
|
||||
self,
|
||||
context: BaseContext,
|
||||
|
||||
@@ -24,6 +24,9 @@ from ..models import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Sentinel for distinguishing "no argument passed" from "explicitly passing None"
|
||||
_UNSET = object()
|
||||
|
||||
|
||||
class AuditLogger:
|
||||
"""
|
||||
@@ -142,8 +145,10 @@ class AuditLogger:
|
||||
# Add hash chain for tamper detection
|
||||
if self._enable_hash_chain:
|
||||
event_hash = self._compute_hash(event)
|
||||
sanitized_details["_hash"] = event_hash
|
||||
sanitized_details["_prev_hash"] = self._last_hash
|
||||
# Modify event.details directly (not sanitized_details)
|
||||
# to ensure the hash is stored on the actual event
|
||||
event.details["_hash"] = event_hash
|
||||
event.details["_prev_hash"] = self._last_hash
|
||||
self._last_hash = event_hash
|
||||
|
||||
self._buffer.append(event)
|
||||
@@ -415,7 +420,8 @@ class AuditLogger:
|
||||
)
|
||||
|
||||
if stored_hash:
|
||||
computed = self._compute_hash(event)
|
||||
# Pass prev_hash to compute hash with correct chain position
|
||||
computed = self._compute_hash(event, prev_hash=prev_hash)
|
||||
if computed != stored_hash:
|
||||
issues.append(
|
||||
f"Hash mismatch at event {event.id}: "
|
||||
@@ -462,9 +468,23 @@ class AuditLogger:
|
||||
|
||||
return sanitized
|
||||
|
||||
def _compute_hash(self, event: AuditEvent) -> str:
|
||||
"""Compute hash for an event (excluding hash fields)."""
|
||||
data = {
|
||||
def _compute_hash(
|
||||
self, event: AuditEvent, prev_hash: str | None | object = _UNSET
|
||||
) -> str:
|
||||
"""Compute hash for an event (excluding hash fields).
|
||||
|
||||
Args:
|
||||
event: The audit event to hash.
|
||||
prev_hash: Optional previous hash to use instead of self._last_hash.
|
||||
Pass this during verification to use the correct chain.
|
||||
Use None explicitly to indicate no previous hash.
|
||||
"""
|
||||
# Use passed prev_hash if explicitly provided, otherwise use instance state
|
||||
effective_prev: str | None = (
|
||||
self._last_hash if prev_hash is _UNSET else prev_hash # type: ignore[assignment]
|
||||
)
|
||||
|
||||
data: dict[str, str | dict[str, str] | None] = {
|
||||
"id": event.id,
|
||||
"event_type": event.event_type.value,
|
||||
"timestamp": event.timestamp.isoformat(),
|
||||
@@ -480,8 +500,8 @@ class AuditLogger:
|
||||
"correlation_id": event.correlation_id,
|
||||
}
|
||||
|
||||
if self._last_hash:
|
||||
data["_prev_hash"] = self._last_hash
|
||||
if effective_prev:
|
||||
data["_prev_hash"] = effective_prev
|
||||
|
||||
serialized = json.dumps(data, sort_keys=True, default=str)
|
||||
return hashlib.sha256(serialized.encode()).hexdigest()
|
||||
|
||||
466
backend/tests/api/routes/test_context.py
Normal file
466
backend/tests/api/routes/test_context.py
Normal file
@@ -0,0 +1,466 @@
|
||||
"""
|
||||
Tests for Context Management API Routes.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import status
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.main import app
|
||||
from app.models.user import User
|
||||
from app.services.context import (
|
||||
AssembledContext,
|
||||
AssemblyTimeoutError,
|
||||
BudgetExceededError,
|
||||
ContextEngine,
|
||||
TokenBudget,
|
||||
)
|
||||
from app.services.mcp import MCPClientManager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mcp_client():
|
||||
"""Create a mock MCP client manager."""
|
||||
client = MagicMock(spec=MCPClientManager)
|
||||
client.is_initialized = True
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_context_engine(mock_mcp_client):
|
||||
"""Create a mock ContextEngine."""
|
||||
engine = MagicMock(spec=ContextEngine)
|
||||
engine._mcp = mock_mcp_client
|
||||
return engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_superuser():
|
||||
"""Create a mock superuser."""
|
||||
user = MagicMock(spec=User)
|
||||
user.id = "00000000-0000-0000-0000-000000000001"
|
||||
user.is_superuser = True
|
||||
user.email = "admin@example.com"
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(mock_mcp_client, mock_context_engine, mock_superuser):
|
||||
"""Create a FastAPI test client with mocked dependencies."""
|
||||
from app.api.dependencies.permissions import require_superuser
|
||||
from app.api.routes.context import get_context_engine
|
||||
from app.services.mcp import get_mcp_client
|
||||
|
||||
# Override dependencies
|
||||
async def override_get_mcp_client():
|
||||
return mock_mcp_client
|
||||
|
||||
async def override_get_context_engine():
|
||||
return mock_context_engine
|
||||
|
||||
async def override_require_superuser():
|
||||
return mock_superuser
|
||||
|
||||
app.dependency_overrides[get_mcp_client] = override_get_mcp_client
|
||||
app.dependency_overrides[get_context_engine] = override_get_context_engine
|
||||
app.dependency_overrides[require_superuser] = override_require_superuser
|
||||
|
||||
with patch("app.main.check_database_health", return_value=True):
|
||||
yield TestClient(app)
|
||||
|
||||
# Clean up
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
class TestContextHealth:
|
||||
"""Tests for GET /context/health endpoint."""
|
||||
|
||||
def test_health_check_success(self, client, mock_context_engine, mock_mcp_client):
|
||||
"""Test context engine health check."""
|
||||
mock_context_engine.get_stats = AsyncMock(
|
||||
return_value={
|
||||
"cache": {"hits": 10, "misses": 5},
|
||||
"settings": {"cache_enabled": True},
|
||||
}
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/context/health")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["status"] == "healthy"
|
||||
assert "mcp_connected" in data
|
||||
assert "cache_enabled" in data
|
||||
|
||||
|
||||
class TestAssembleContext:
|
||||
"""Tests for POST /context/assemble endpoint."""
|
||||
|
||||
def test_assemble_context_success(self, client, mock_context_engine):
|
||||
"""Test successful context assembly."""
|
||||
# Create mock assembled context
|
||||
mock_result = MagicMock(spec=AssembledContext)
|
||||
mock_result.content = "Assembled context content"
|
||||
mock_result.total_tokens = 500
|
||||
mock_result.context_count = 2
|
||||
mock_result.excluded_count = 0
|
||||
mock_result.assembly_time_ms = 50.5
|
||||
mock_result.metadata = {}
|
||||
|
||||
mock_context_engine.assemble_context = AsyncMock(return_value=mock_result)
|
||||
mock_context_engine.get_budget_for_model = AsyncMock(
|
||||
return_value=TokenBudget(
|
||||
total=4000,
|
||||
system=500,
|
||||
knowledge=1500,
|
||||
conversation=1000,
|
||||
tools=500,
|
||||
response_reserve=500,
|
||||
)
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/context/assemble",
|
||||
json={
|
||||
"project_id": "test-project",
|
||||
"agent_id": "test-agent",
|
||||
"query": "What is the auth flow?",
|
||||
"model": "claude-3-sonnet",
|
||||
"system_prompt": "You are a helpful assistant.",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["content"] == "Assembled context content"
|
||||
assert data["total_tokens"] == 500
|
||||
assert data["context_count"] == 2
|
||||
assert data["compressed"] is False
|
||||
assert "budget_used_percent" in data
|
||||
|
||||
def test_assemble_context_with_conversation(self, client, mock_context_engine):
|
||||
"""Test context assembly with conversation history."""
|
||||
mock_result = MagicMock(spec=AssembledContext)
|
||||
mock_result.content = "Context with history"
|
||||
mock_result.total_tokens = 800
|
||||
mock_result.context_count = 1
|
||||
mock_result.excluded_count = 0
|
||||
mock_result.assembly_time_ms = 30.0
|
||||
mock_result.metadata = {}
|
||||
|
||||
mock_context_engine.assemble_context = AsyncMock(return_value=mock_result)
|
||||
mock_context_engine.get_budget_for_model = AsyncMock(
|
||||
return_value=TokenBudget(
|
||||
total=4000,
|
||||
system=500,
|
||||
knowledge=1500,
|
||||
conversation=1000,
|
||||
tools=500,
|
||||
response_reserve=500,
|
||||
)
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/context/assemble",
|
||||
json={
|
||||
"project_id": "test-project",
|
||||
"agent_id": "test-agent",
|
||||
"query": "Continue the discussion",
|
||||
"conversation_history": [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
call_args = mock_context_engine.assemble_context.call_args
|
||||
assert call_args.kwargs["conversation_history"] == [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
|
||||
def test_assemble_context_with_tool_results(self, client, mock_context_engine):
|
||||
"""Test context assembly with tool results."""
|
||||
mock_result = MagicMock(spec=AssembledContext)
|
||||
mock_result.content = "Context with tools"
|
||||
mock_result.total_tokens = 600
|
||||
mock_result.context_count = 1
|
||||
mock_result.excluded_count = 0
|
||||
mock_result.assembly_time_ms = 25.0
|
||||
mock_result.metadata = {}
|
||||
|
||||
mock_context_engine.assemble_context = AsyncMock(return_value=mock_result)
|
||||
mock_context_engine.get_budget_for_model = AsyncMock(
|
||||
return_value=TokenBudget(
|
||||
total=4000,
|
||||
system=500,
|
||||
knowledge=1500,
|
||||
conversation=1000,
|
||||
tools=500,
|
||||
response_reserve=500,
|
||||
)
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/context/assemble",
|
||||
json={
|
||||
"project_id": "test-project",
|
||||
"agent_id": "test-agent",
|
||||
"query": "What did the search find?",
|
||||
"tool_results": [
|
||||
{
|
||||
"tool_name": "search_knowledge",
|
||||
"content": {"results": ["item1", "item2"]},
|
||||
"status": "success",
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
call_args = mock_context_engine.assemble_context.call_args
|
||||
assert len(call_args.kwargs["tool_results"]) == 1
|
||||
|
||||
def test_assemble_context_timeout(self, client, mock_context_engine):
|
||||
"""Test context assembly timeout error."""
|
||||
mock_context_engine.assemble_context = AsyncMock(
|
||||
side_effect=AssemblyTimeoutError("Assembly exceeded 5000ms limit")
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/context/assemble",
|
||||
json={
|
||||
"project_id": "test-project",
|
||||
"agent_id": "test-agent",
|
||||
"query": "test",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_504_GATEWAY_TIMEOUT
|
||||
|
||||
def test_assemble_context_budget_exceeded(self, client, mock_context_engine):
|
||||
"""Test context assembly budget exceeded error."""
|
||||
mock_context_engine.assemble_context = AsyncMock(
|
||||
side_effect=BudgetExceededError("Token budget exceeded: 5000 > 4000")
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/context/assemble",
|
||||
json={
|
||||
"project_id": "test-project",
|
||||
"agent_id": "test-agent",
|
||||
"query": "test",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_413_REQUEST_ENTITY_TOO_LARGE
|
||||
|
||||
def test_assemble_context_validation_error(self, client):
|
||||
"""Test context assembly with invalid request."""
|
||||
response = client.post(
|
||||
"/api/v1/context/assemble",
|
||||
json={}, # Missing required fields
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
|
||||
class TestCountTokens:
    """Tests for POST /context/count-tokens endpoint."""

    def test_count_tokens_success(self, client, mock_context_engine):
        """Test successful token counting."""
        mock_context_engine.count_tokens = AsyncMock(return_value=42)

        body = {
            "content": "This is some test content.",
            "model": "claude-3-sonnet",
        }
        response = client.post("/api/v1/context/count-tokens", json=body)

        assert response.status_code == status.HTTP_200_OK
        payload = response.json()
        assert payload["token_count"] == 42
        assert payload["model"] == "claude-3-sonnet"

    def test_count_tokens_without_model(self, client, mock_context_engine):
        """Test token counting without specifying model."""
        mock_context_engine.count_tokens = AsyncMock(return_value=100)

        response = client.post(
            "/api/v1/context/count-tokens",
            json={"content": "Some content to count."},
        )

        assert response.status_code == status.HTTP_200_OK
        payload = response.json()
        assert payload["token_count"] == 100
        # With no model in the request, the response echoes back None.
        assert payload["model"] is None
|
||||
|
||||
|
||||
class TestGetBudget:
    """Tests for GET /context/budget/{model} endpoint."""

    def test_get_budget_success(self, client, mock_context_engine):
        """Test getting token budget for a model."""
        budget = TokenBudget(
            total=100000,
            system=10000,
            knowledge=40000,
            conversation=30000,
            tools=10000,
            response_reserve=10000,
        )
        mock_context_engine.get_budget_for_model = AsyncMock(return_value=budget)

        response = client.get("/api/v1/context/budget/claude-3-opus")

        assert response.status_code == status.HTTP_200_OK
        body = response.json()
        assert body["model"] == "claude-3-opus"
        assert body["total_tokens"] == 100000
        assert body["system_tokens"] == 10000
        assert body["knowledge_tokens"] == 40000

    def test_get_budget_with_max_tokens(self, client, mock_context_engine):
        """Test getting budget with custom max tokens."""
        budget = TokenBudget(
            total=2000,
            system=200,
            knowledge=800,
            conversation=600,
            tools=200,
            response_reserve=200,
        )
        mock_context_engine.get_budget_for_model = AsyncMock(return_value=budget)

        # max_tokens is passed as a query parameter to cap the budget total.
        response = client.get("/api/v1/context/budget/gpt-4?max_tokens=2000")

        assert response.status_code == status.HTTP_200_OK
        assert response.json()["total_tokens"] == 2000
|
||||
|
||||
|
||||
class TestGetStats:
    """Tests for GET /context/stats endpoint."""

    def test_get_stats_success(self, client, mock_context_engine):
        """Test getting engine statistics."""
        stats = {
            "cache": {
                "hits": 100,
                "misses": 25,
                "hit_rate": 0.8,
            },
            "settings": {
                "compression_threshold": 0.9,
                "max_assembly_time_ms": 5000,
                "cache_enabled": True,
            },
        }
        mock_context_engine.get_stats = AsyncMock(return_value=stats)

        response = client.get("/api/v1/context/stats")

        assert response.status_code == status.HTTP_200_OK
        body = response.json()
        # The endpoint should relay the engine's stats payload unchanged.
        assert body["cache"]["hits"] == 100
        assert body["settings"]["cache_enabled"] is True
|
||||
|
||||
|
||||
class TestInvalidateCache:
    """Tests for POST /context/cache/invalidate endpoint."""

    def test_invalidate_cache_by_project(self, client, mock_context_engine):
        """Test cache invalidation by project ID."""
        mock_context_engine.invalidate_cache = AsyncMock(return_value=5)

        response = client.post(
            "/api/v1/context/cache/invalidate?project_id=test-project"
        )

        assert response.status_code == status.HTTP_204_NO_CONTENT
        mock_context_engine.invalidate_cache.assert_called_once()
        kwargs = mock_context_engine.invalidate_cache.call_args.kwargs
        assert kwargs["project_id"] == "test-project"

    def test_invalidate_cache_by_pattern(self, client, mock_context_engine):
        """Test cache invalidation by pattern."""
        mock_context_engine.invalidate_cache = AsyncMock(return_value=10)

        response = client.post("/api/v1/context/cache/invalidate?pattern=*auth*")

        assert response.status_code == status.HTTP_204_NO_CONTENT
        mock_context_engine.invalidate_cache.assert_called_once()
        kwargs = mock_context_engine.invalidate_cache.call_args.kwargs
        assert kwargs["pattern"] == "*auth*"

    def test_invalidate_cache_all(self, client, mock_context_engine):
        """Test invalidating all cache entries."""
        # With neither project_id nor pattern, every entry is flushed.
        mock_context_engine.invalidate_cache = AsyncMock(return_value=100)

        response = client.post("/api/v1/context/cache/invalidate")

        assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
|
||||
|
||||
class TestContextEndpointsEdgeCases:
    """Edge case tests for Context endpoints."""

    def test_context_content_type(self, client, mock_context_engine):
        """Test that endpoints return JSON content type."""
        # NOTE(review): the request below hits /health while get_stats is the
        # method stubbed — presumably health reads stats internally; confirm.
        mock_context_engine.get_stats = AsyncMock(
            return_value={"cache": {}, "settings": {}}
        )

        response = client.get("/api/v1/context/health")

        assert "application/json" in response.headers["content-type"]

    def test_assemble_context_with_knowledge_query(self, client, mock_context_engine):
        """Test context assembly with knowledge base query."""
        assembled = MagicMock(spec=AssembledContext)
        assembled.content = "Context with knowledge"
        assembled.total_tokens = 1000
        assembled.context_count = 3
        assembled.excluded_count = 0
        assembled.assembly_time_ms = 100.0
        # A non-zero compressed_contexts count indicates compression happened.
        assembled.metadata = {"compressed_contexts": 1}

        mock_context_engine.assemble_context = AsyncMock(return_value=assembled)
        mock_context_engine.get_budget_for_model = AsyncMock(
            return_value=TokenBudget(
                total=4000,
                system=500,
                knowledge=1500,
                conversation=1000,
                tools=500,
                response_reserve=500,
            )
        )

        payload = {
            "project_id": "test-project",
            "agent_id": "test-agent",
            "query": "How does authentication work?",
            "knowledge_query": "authentication flow implementation",
            "knowledge_limit": 5,
        }
        response = client.post("/api/v1/context/assemble", json=payload)

        assert response.status_code == status.HTTP_200_OK
        # Knowledge parameters must be forwarded verbatim to the engine.
        call_kwargs = mock_context_engine.assemble_context.call_args.kwargs
        assert call_kwargs["knowledge_query"] == "authentication flow implementation"
        assert call_kwargs["knowledge_limit"] == 5
|
||||
646
backend/tests/e2e/test_agent_workflows.py
Normal file
646
backend/tests/e2e/test_agent_workflows.py
Normal file
@@ -0,0 +1,646 @@
|
||||
"""
|
||||
Agent E2E Workflow Tests.
|
||||
|
||||
Tests complete workflows for AI agents including:
|
||||
- Agent type management (admin-only)
|
||||
- Agent instance spawning and lifecycle
|
||||
- Agent status transitions (pause/resume/terminate)
|
||||
- Authorization and access control
|
||||
|
||||
Usage:
|
||||
make test-e2e # Run all E2E tests
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = [
|
||||
pytest.mark.e2e,
|
||||
pytest.mark.postgres,
|
||||
pytest.mark.asyncio,
|
||||
]
|
||||
|
||||
|
||||
class TestAgentTypesAdminWorkflows:
    """Test agent type management (admin-only operations)."""

    @staticmethod
    def _bearer(token):
        """Return an Authorization header dict for the given access token."""
        return {"Authorization": f"Bearer {token}"}

    async def _register_and_login(self, e2e_client, email, password, first, last):
        """Register a fresh user and return the login token payload."""
        await e2e_client.post(
            "/api/v1/auth/register",
            json={
                "email": email,
                "password": password,
                "first_name": first,
                "last_name": last,
            },
        )
        login_resp = await e2e_client.post(
            "/api/v1/auth/login",
            json={"email": email, "password": password},
        )
        return login_resp.json()

    async def test_create_agent_type_requires_superuser(self, e2e_client):
        """Test that creating agent types requires superuser privileges."""
        email = f"regular-{uuid4().hex[:8]}@example.com"
        password = "RegularPass123!"
        tokens = await self._register_and_login(
            e2e_client, email, password, "Regular", "User"
        )

        # A non-superuser must be rejected when creating an agent type.
        response = await e2e_client.post(
            "/api/v1/agent-types",
            headers=self._bearer(tokens["access_token"]),
            json={
                "name": "Test Agent",
                "slug": f"test-agent-{uuid4().hex[:8]}",
                "personality_prompt": "You are a helpful assistant.",
                "primary_model": "claude-3-sonnet",
            },
        )

        assert response.status_code == 403

    async def test_superuser_can_create_agent_type(self, e2e_client, e2e_superuser):
        """Test that superuser can create and manage agent types."""
        admin = self._bearer(e2e_superuser["tokens"]["access_token"])
        slug = f"test-type-{uuid4().hex[:8]}"

        create_resp = await e2e_client.post(
            "/api/v1/agent-types",
            headers=admin,
            json={
                "name": "Product Owner Agent",
                "slug": slug,
                "description": "A product owner agent for requirements gathering",
                "expertise": ["requirements", "user_stories", "prioritization"],
                "personality_prompt": "You are a product owner focused on delivering value.",
                "primary_model": "claude-3-opus",
                "fallback_models": ["claude-3-sonnet"],
                "model_params": {"temperature": 0.7, "max_tokens": 4000},
                "mcp_servers": ["knowledge-base"],
                "is_active": True,
            },
        )

        assert create_resp.status_code == 201, f"Failed: {create_resp.text}"
        agent_type = create_resp.json()

        assert agent_type["name"] == "Product Owner Agent"
        assert agent_type["slug"] == slug
        assert agent_type["primary_model"] == "claude-3-opus"
        assert agent_type["is_active"] is True
        assert "requirements" in agent_type["expertise"]

    async def test_list_agent_types_public(self, e2e_client, e2e_superuser):
        """Test that any authenticated user can list agent types."""
        # Seed one agent type as superuser so the list is non-empty.
        admin = self._bearer(e2e_superuser["tokens"]["access_token"])
        slug = f"list-test-{uuid4().hex[:8]}"
        await e2e_client.post(
            "/api/v1/agent-types",
            headers=admin,
            json={
                "name": f"List Test Agent {slug}",
                "slug": slug,
                "personality_prompt": "Test agent.",
                "primary_model": "claude-3-sonnet",
            },
        )

        # A plain user can still read the list.
        email = f"lister-{uuid4().hex[:8]}@example.com"
        password = "ListerPass123!"
        tokens = await self._register_and_login(
            e2e_client, email, password, "List", "User"
        )

        list_resp = await e2e_client.get(
            "/api/v1/agent-types",
            headers=self._bearer(tokens["access_token"]),
        )

        assert list_resp.status_code == 200
        data = list_resp.json()
        assert "data" in data
        assert "pagination" in data
        assert data["pagination"]["total"] >= 1

    async def test_get_agent_type_by_slug(self, e2e_client, e2e_superuser):
        """Test getting agent type by slug."""
        admin = self._bearer(e2e_superuser["tokens"]["access_token"])
        slug = f"slug-test-{uuid4().hex[:8]}"

        await e2e_client.post(
            "/api/v1/agent-types",
            headers=admin,
            json={
                "name": f"Slug Test {slug}",
                "slug": slug,
                "personality_prompt": "Test agent.",
                "primary_model": "claude-3-sonnet",
            },
        )

        # Route is /slug/{slug}, not /by-slug/{slug}.
        get_resp = await e2e_client.get(
            f"/api/v1/agent-types/slug/{slug}",
            headers=admin,
        )

        assert get_resp.status_code == 200
        assert get_resp.json()["slug"] == slug

    async def test_update_agent_type(self, e2e_client, e2e_superuser):
        """Test updating an agent type."""
        admin = self._bearer(e2e_superuser["tokens"]["access_token"])
        slug = f"update-test-{uuid4().hex[:8]}"

        create_resp = await e2e_client.post(
            "/api/v1/agent-types",
            headers=admin,
            json={
                "name": "Original Name",
                "slug": slug,
                "personality_prompt": "Original prompt.",
                "primary_model": "claude-3-sonnet",
            },
        )
        agent_type_id = create_resp.json()["id"]

        update_resp = await e2e_client.patch(
            f"/api/v1/agent-types/{agent_type_id}",
            headers=admin,
            json={
                "name": "Updated Name",
                "description": "Added description",
            },
        )

        assert update_resp.status_code == 200
        updated = update_resp.json()
        assert updated["name"] == "Updated Name"
        assert updated["description"] == "Added description"
        # Fields omitted from the PATCH body must be left untouched.
        assert updated["personality_prompt"] == "Original prompt."
|
||||
|
||||
|
||||
class TestAgentInstanceWorkflows:
    """Test agent instance spawning and lifecycle."""

    @staticmethod
    def _bearer(token):
        """Return an Authorization header dict for the given access token."""
        return {"Authorization": f"Bearer {token}"}

    async def test_spawn_agent_workflow(self, e2e_client, e2e_superuser):
        """Test complete workflow: create type -> create project -> spawn agent."""
        auth = self._bearer(e2e_superuser["tokens"]["access_token"])

        # 1. Create agent type as superuser.
        type_resp = await e2e_client.post(
            "/api/v1/agent-types",
            headers=auth,
            json={
                "name": "Spawn Test Agent",
                "slug": f"spawn-test-type-{uuid4().hex[:8]}",
                "personality_prompt": "You are a helpful agent.",
                "primary_model": "claude-3-sonnet",
            },
        )
        assert type_resp.status_code == 201
        agent_type_id = type_resp.json()["id"]

        # 2. Create a project (superuser can create projects too).
        project_resp = await e2e_client.post(
            "/api/v1/projects",
            headers=auth,
            json={
                "name": "Spawn Test Project",
                "slug": f"spawn-test-project-{uuid4().hex[:8]}",
            },
        )
        assert project_resp.status_code == 201
        project_id = project_resp.json()["id"]

        # 3. Spawn the agent instance into the project.
        spawn_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/agents",
            headers=auth,
            json={
                "agent_type_id": agent_type_id,
                "project_id": project_id,
                "name": "My PO Agent",
            },
        )

        assert spawn_resp.status_code == 201, f"Failed: {spawn_resp.text}"
        agent = spawn_resp.json()

        assert agent["name"] == "My PO Agent"
        assert agent["status"] == "idle"
        assert agent["project_id"] == project_id
        assert agent["agent_type_id"] == agent_type_id

    async def test_list_project_agents(self, e2e_client, e2e_superuser):
        """Test listing agents in a project."""
        auth = self._bearer(e2e_superuser["tokens"]["access_token"])

        # Setup: one agent type and one project to spawn into.
        type_resp = await e2e_client.post(
            "/api/v1/agent-types",
            headers=auth,
            json={
                "name": "List Agents Type",
                "slug": f"list-agents-type-{uuid4().hex[:8]}",
                "personality_prompt": "Test agent.",
                "primary_model": "claude-3-sonnet",
            },
        )
        agent_type_id = type_resp.json()["id"]

        project_resp = await e2e_client.post(
            "/api/v1/projects",
            headers=auth,
            json={
                "name": "List Agents Project",
                "slug": f"list-agents-project-{uuid4().hex[:8]}",
            },
        )
        project_id = project_resp.json()["id"]

        # Spawn three agents.
        for i in range(3):
            await e2e_client.post(
                f"/api/v1/projects/{project_id}/agents",
                headers=auth,
                json={
                    "agent_type_id": agent_type_id,
                    "project_id": project_id,
                    "name": f"Agent {i + 1}",
                },
            )

        # The listing must report exactly the three spawned agents.
        list_resp = await e2e_client.get(
            f"/api/v1/projects/{project_id}/agents",
            headers=auth,
        )

        assert list_resp.status_code == 200
        data = list_resp.json()
        assert data["pagination"]["total"] == 3
        assert len(data["data"]) == 3
|
||||
|
||||
|
||||
class TestAgentLifecycle:
    """Test agent lifecycle operations (pause/resume/terminate)."""

    @staticmethod
    def _bearer(token):
        """Return an Authorization header dict for the given access token."""
        return {"Authorization": f"Bearer {token}"}

    async def _provision_agent(
        self,
        e2e_client,
        auth,
        type_name,
        type_slug,
        project_name,
        project_slug,
        agent_name,
    ):
        """Create an agent type and a project, spawn one agent.

        Returns (project_id, spawn response).
        """
        type_resp = await e2e_client.post(
            "/api/v1/agent-types",
            headers=auth,
            json={
                "name": type_name,
                "slug": type_slug,
                "personality_prompt": "Test agent.",
                "primary_model": "claude-3-sonnet",
            },
        )
        agent_type_id = type_resp.json()["id"]

        project_resp = await e2e_client.post(
            "/api/v1/projects",
            headers=auth,
            json={"name": project_name, "slug": project_slug},
        )
        project_id = project_resp.json()["id"]

        spawn_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/agents",
            headers=auth,
            json={
                "agent_type_id": agent_type_id,
                "project_id": project_id,
                "name": agent_name,
            },
        )
        return project_id, spawn_resp

    async def test_agent_pause_and_resume(self, e2e_client, e2e_superuser):
        """Test pausing and resuming an agent."""
        auth = self._bearer(e2e_superuser["tokens"]["access_token"])
        project_id, spawn_resp = await self._provision_agent(
            e2e_client,
            auth,
            "Pause Test Type",
            f"pause-test-type-{uuid4().hex[:8]}",
            "Pause Test Project",
            f"pause-test-project-{uuid4().hex[:8]}",
            "Pausable Agent",
        )
        agent_id = spawn_resp.json()["id"]
        assert spawn_resp.json()["status"] == "idle"

        # Pause the agent.
        pause_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/agents/{agent_id}/pause",
            headers=auth,
        )
        assert pause_resp.status_code == 200, f"Failed: {pause_resp.text}"
        assert pause_resp.json()["status"] == "paused"

        # Resume the agent; it returns to idle.
        resume_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/agents/{agent_id}/resume",
            headers=auth,
        )
        assert resume_resp.status_code == 200, f"Failed: {resume_resp.text}"
        assert resume_resp.json()["status"] == "idle"

    async def test_agent_terminate(self, e2e_client, e2e_superuser):
        """Test terminating an agent."""
        auth = self._bearer(e2e_superuser["tokens"]["access_token"])
        project_id, spawn_resp = await self._provision_agent(
            e2e_client,
            auth,
            "Terminate Type",
            f"terminate-type-{uuid4().hex[:8]}",
            "Terminate Project",
            f"terminate-project-{uuid4().hex[:8]}",
            "To Be Terminated",
        )
        agent_id = spawn_resp.json()["id"]

        # Terminate agent (returns MessageResponse, not agent status).
        terminate_resp = await e2e_client.delete(
            f"/api/v1/projects/{project_id}/agents/{agent_id}",
            headers=auth,
        )

        assert terminate_resp.status_code == 200
        assert "message" in terminate_resp.json()

        # A terminated agent cannot be resumed (invalid state transition).
        resume_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/agents/{agent_id}/resume",
            headers=auth,
        )
        assert resume_resp.status_code in [400, 422]
|
||||
|
||||
|
||||
class TestAgentAccessControl:
    """Test agent access control and authorization."""

    @staticmethod
    def _bearer(token):
        """Return an Authorization header dict for the given access token."""
        return {"Authorization": f"Bearer {token}"}

    async def test_user_cannot_access_other_project_agents(
        self, e2e_client, e2e_superuser
    ):
        """Test that users cannot access agents in projects they don't own."""
        admin = self._bearer(e2e_superuser["tokens"]["access_token"])

        # Superuser creates an agent type...
        type_resp = await e2e_client.post(
            "/api/v1/agent-types",
            headers=admin,
            json={
                "name": "Access Type",
                "slug": f"access-type-{uuid4().hex[:8]}",
                "personality_prompt": "Test agent.",
                "primary_model": "claude-3-sonnet",
            },
        )
        agent_type_id = type_resp.json()["id"]

        # ...a project...
        project_resp = await e2e_client.post(
            "/api/v1/projects",
            headers=admin,
            json={
                "name": "Protected Project",
                "slug": f"protected-project-{uuid4().hex[:8]}",
            },
        )
        project_id = project_resp.json()["id"]

        # ...and spawns an agent inside it.
        spawn_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/agents",
            headers=admin,
            json={
                "agent_type_id": agent_type_id,
                "project_id": project_id,
                "name": "Protected Agent",
            },
        )
        agent_id = spawn_resp.json()["id"]

        # Register an unrelated user.
        email = f"other-user-{uuid4().hex[:8]}@example.com"
        password = "OtherPass123!"
        await e2e_client.post(
            "/api/v1/auth/register",
            json={
                "email": email,
                "password": password,
                "first_name": "Other",
                "last_name": "User",
            },
        )
        login_resp = await e2e_client.post(
            "/api/v1/auth/login",
            json={"email": email, "password": password},
        )
        other_tokens = login_resp.json()

        # The unrelated user must not be able to read the agent.
        get_resp = await e2e_client.get(
            f"/api/v1/projects/{project_id}/agents/{agent_id}",
            headers=self._bearer(other_tokens["access_token"]),
        )

        # Either forbidden or not-found acceptably hides the resource.
        assert get_resp.status_code in [403, 404]

    async def test_cannot_spawn_with_inactive_agent_type(
        self, e2e_client, e2e_superuser
    ):
        """Test that agents cannot be spawned from inactive agent types."""
        admin = self._bearer(e2e_superuser["tokens"]["access_token"])

        # Create an (initially active) agent type.
        type_resp = await e2e_client.post(
            "/api/v1/agent-types",
            headers=admin,
            json={
                "name": "Inactive Type",
                "slug": f"inactive-type-{uuid4().hex[:8]}",
                "personality_prompt": "Test agent.",
                "primary_model": "claude-3-sonnet",
                "is_active": True,
            },
        )
        agent_type_id = type_resp.json()["id"]

        # Deactivate it before spawning.
        await e2e_client.patch(
            f"/api/v1/agent-types/{agent_type_id}",
            headers=admin,
            json={"is_active": False},
        )

        project_resp = await e2e_client.post(
            "/api/v1/projects",
            headers=admin,
            json={
                "name": "Inactive Spawn Project",
                "slug": f"inactive-spawn-project-{uuid4().hex[:8]}",
            },
        )
        project_id = project_resp.json()["id"]

        # Spawning from the now-inactive type must be rejected.
        spawn_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/agents",
            headers=admin,
            json={
                "agent_type_id": agent_type_id,
                "project_id": project_id,
                "name": "Should Fail",
            },
        )

        # 422 is correct for validation errors per REST conventions.
        assert spawn_resp.status_code == 422
|
||||
|
||||
|
||||
class TestAgentMetrics:
    """Test agent metrics endpoint."""

    @staticmethod
    def _bearer(token):
        """Return an Authorization header dict for the given access token."""
        return {"Authorization": f"Bearer {token}"}

    async def test_get_agent_metrics(self, e2e_client, e2e_superuser):
        """Test retrieving agent metrics."""
        auth = self._bearer(e2e_superuser["tokens"]["access_token"])

        # Setup: agent type, project, and one spawned agent.
        type_resp = await e2e_client.post(
            "/api/v1/agent-types",
            headers=auth,
            json={
                "name": "Metrics Type",
                "slug": f"metrics-type-{uuid4().hex[:8]}",
                "personality_prompt": "Test agent.",
                "primary_model": "claude-3-sonnet",
            },
        )
        agent_type_id = type_resp.json()["id"]

        project_resp = await e2e_client.post(
            "/api/v1/projects",
            headers=auth,
            json={
                "name": "Metrics Project",
                "slug": f"metrics-project-{uuid4().hex[:8]}",
            },
        )
        project_id = project_resp.json()["id"]

        spawn_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/agents",
            headers=auth,
            json={
                "agent_type_id": agent_type_id,
                "project_id": project_id,
                "name": "Metrics Agent",
            },
        )
        agent_id = spawn_resp.json()["id"]

        # Fetch metrics for the spawned agent.
        metrics_resp = await e2e_client.get(
            f"/api/v1/projects/{project_id}/agents/{agent_id}/metrics",
            headers=auth,
        )

        assert metrics_resp.status_code == 200
        metrics = metrics_resp.json()

        # Verify the AgentInstanceMetrics response shape.
        for key in (
            "total_instances",
            "active_instances",
            "idle_instances",
            "total_tasks_completed",
            "total_tokens_used",
            "total_cost_incurred",
        ):
            assert key in metrics
|
||||
460
backend/tests/e2e/test_mcp_workflows.py
Normal file
460
backend/tests/e2e/test_mcp_workflows.py
Normal file
@@ -0,0 +1,460 @@
|
||||
"""
|
||||
MCP and Context Engine E2E Workflow Tests.
|
||||
|
||||
Tests complete workflows involving MCP servers and the Context Engine
|
||||
against real PostgreSQL. These tests verify:
|
||||
- MCP server listing and tool discovery
|
||||
- Context engine operations
|
||||
- Admin-only MCP operations with proper authentication
|
||||
- Error handling for MCP operations
|
||||
|
||||
Usage:
|
||||
make test-e2e # Run all E2E tests
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = [
|
||||
pytest.mark.e2e,
|
||||
pytest.mark.postgres,
|
||||
pytest.mark.asyncio,
|
||||
]
|
||||
|
||||
|
||||
class TestMCPServerDiscovery:
    """Test MCP server listing and discovery workflows."""

    async def test_list_mcp_servers(self, e2e_client):
        """Test listing MCP servers returns expected configuration."""
        response = await e2e_client.get("/api/v1/mcp/servers")

        assert response.status_code == 200, f"Failed: {response.text}"
        data = response.json()

        assert "servers" in data
        assert "total" in data
        assert isinstance(data["servers"], list)

        # Both built-in servers must appear in the configuration.
        server_names = [entry["name"] for entry in data["servers"]]
        assert "llm-gateway" in server_names
        assert "knowledge-base" in server_names

    async def test_list_all_mcp_tools(self, e2e_client):
        """Test listing all tools from all MCP servers."""
        response = await e2e_client.get("/api/v1/mcp/tools")

        assert response.status_code == 200, f"Failed: {response.text}"
        data = response.json()

        assert "tools" in data
        assert "total" in data
        assert isinstance(data["tools"], list)

    async def test_mcp_health_check(self, e2e_client):
        """Test MCP health check returns server status."""
        response = await e2e_client.get("/api/v1/mcp/health")

        assert response.status_code == 200, f"Failed: {response.text}"
        data = response.json()

        # Health payload reports per-server status plus aggregate counts.
        for key in ("servers", "healthy_count", "unhealthy_count", "total"):
            assert key in data

    async def test_list_circuit_breakers(self, e2e_client):
        """Test listing circuit breaker status."""
        response = await e2e_client.get("/api/v1/mcp/circuit-breakers")

        assert response.status_code == 200, f"Failed: {response.text}"
        data = response.json()

        assert "circuit_breakers" in data
        assert isinstance(data["circuit_breakers"], list)
|
||||
|
||||
|
||||
class TestMCPServerTools:
    """Test MCP server tool listing."""

    async def _assert_tools_listing(self, e2e_client, server_name):
        """Fetch the tool listing for *server_name* and validate its shape.

        The upstream MCP server may not be connected, so 404/502 are
        tolerated; only a 200 response is inspected further.
        """
        response = await e2e_client.get(f"/api/v1/mcp/servers/{server_name}/tools")
        assert response.status_code in (200, 404, 502)

        if response.status_code == 200:
            payload = response.json()
            assert "tools" in payload
            assert "total" in payload

    async def test_list_llm_gateway_tools(self, e2e_client):
        """Test listing tools from LLM Gateway server."""
        await self._assert_tools_listing(e2e_client, "llm-gateway")

    async def test_list_knowledge_base_tools(self, e2e_client):
        """Test listing tools from Knowledge Base server."""
        await self._assert_tools_listing(e2e_client, "knowledge-base")

    async def test_invalid_server_returns_404(self, e2e_client):
        """An unknown server name yields a 404."""
        response = await e2e_client.get("/api/v1/mcp/servers/nonexistent-server/tools")
        assert response.status_code == 404
class TestContextEngineWorkflows:
    """Test Context Engine operations."""

    async def test_context_engine_health(self, e2e_client):
        """The context engine health endpoint reports a healthy status."""
        response = await e2e_client.get("/api/v1/context/health")
        assert response.status_code == 200, f"Failed: {response.text}"

        payload = response.json()
        assert payload["status"] == "healthy"
        # Connectivity and caching flags must always be reported.
        for key in ("mcp_connected", "cache_enabled"):
            assert key in payload

    async def test_get_token_budget_claude_sonnet(self, e2e_client):
        """The token budget for Claude 3 Sonnet is complete and consistent."""
        response = await e2e_client.get("/api/v1/context/budget/claude-3-sonnet")
        assert response.status_code == 200, f"Failed: {response.text}"

        payload = response.json()
        assert payload["model"] == "claude-3-sonnet"

        # Every budget component must be present alongside the total.
        component_fields = (
            "system_tokens",
            "knowledge_tokens",
            "conversation_tokens",
            "tool_tokens",
            "response_reserve",
        )
        assert "total_tokens" in payload
        for field in component_fields:
            assert field in payload

        # Sanity-check the allocation: the parts never exceed the whole.
        assert payload["total_tokens"] > 0
        allocated = sum(payload[field] for field in component_fields)
        assert allocated <= payload["total_tokens"]

    async def test_get_token_budget_with_custom_max(self, e2e_client):
        """A caller-supplied max_tokens caps the returned budget."""
        response = await e2e_client.get(
            "/api/v1/context/budget/claude-3-sonnet",
            params={"max_tokens": 50000},
        )
        assert response.status_code == 200, f"Failed: {response.text}"

        payload = response.json()
        assert payload["model"] == "claude-3-sonnet"
        # The custom ceiling must be honored (or further capped by the server).
        assert payload["total_tokens"] <= 50000

    async def test_count_tokens(self, e2e_client):
        """Counting tokens for a short message yields a positive count."""
        response = await e2e_client.post(
            "/api/v1/context/count-tokens",
            json={
                "content": "Hello, this is a test message for token counting.",
                "model": "claude-3-sonnet",
            },
        )
        assert response.status_code == 200, f"Failed: {response.text}"

        payload = response.json()
        assert "token_count" in payload
        assert payload["token_count"] > 0
        assert payload["model"] == "claude-3-sonnet"
class TestAdminMCPOperations:
    """Test admin-only MCP operations require authentication."""

    # Either 401 (unauthenticated) or 403 (forbidden) counts as rejected.
    _REJECTED = (401, 403)

    async def test_tool_call_requires_auth(self, e2e_client):
        """Tool execution is rejected without credentials."""
        response = await e2e_client.post(
            "/api/v1/mcp/call",
            json={
                "server": "llm-gateway",
                "tool": "count_tokens",
                "arguments": {"text": "test"},
            },
        )
        assert response.status_code in self._REJECTED

    async def test_circuit_reset_requires_auth(self, e2e_client):
        """Circuit breaker reset is rejected without credentials."""
        response = await e2e_client.post(
            "/api/v1/mcp/circuit-breakers/llm-gateway/reset"
        )
        assert response.status_code in self._REJECTED

    async def test_server_reconnect_requires_auth(self, e2e_client):
        """Server reconnect is rejected without credentials."""
        response = await e2e_client.post("/api/v1/mcp/servers/llm-gateway/reconnect")
        assert response.status_code in self._REJECTED

    async def test_context_stats_requires_auth(self, e2e_client):
        """Context stats are rejected without credentials."""
        response = await e2e_client.get("/api/v1/context/stats")
        assert response.status_code in self._REJECTED

    async def test_context_assemble_requires_auth(self, e2e_client):
        """Context assembly is rejected without credentials."""
        response = await e2e_client.post(
            "/api/v1/context/assemble",
            json={
                "project_id": "test-project",
                "agent_id": "test-agent",
                "query": "test query",
                "model": "claude-3-sonnet",
            },
        )
        assert response.status_code in self._REJECTED

    async def test_cache_invalidate_requires_auth(self, e2e_client):
        """Cache invalidation is rejected without credentials."""
        response = await e2e_client.post("/api/v1/context/cache/invalidate")
        assert response.status_code in self._REJECTED
class TestAdminMCPWithAuthentication:
    """Test admin MCP operations with superuser authentication.

    Improvements over the original: a shared `_bearer` helper builds the
    Authorization header, and the regular-user flow now asserts that
    registration and login succeeded before the tokens are used, so a
    broken signup fails with a clear message instead of a KeyError on
    the token payload.
    """

    @staticmethod
    def _bearer(tokens):
        """Build an Authorization header dict from a login token payload."""
        return {"Authorization": f"Bearer {tokens['access_token']}"}

    async def test_superuser_can_get_context_stats(self, e2e_client, e2e_superuser):
        """Test that superuser can get context engine stats."""
        response = await e2e_client.get(
            "/api/v1/context/stats",
            headers=self._bearer(e2e_superuser["tokens"]),
        )

        assert response.status_code == 200, f"Failed: {response.text}"
        data = response.json()

        # Stats must expose cache metrics and the engine settings.
        assert "cache" in data
        assert "settings" in data

    @pytest.mark.skip(
        reason="Requires MCP servers (llm-gateway, knowledge-base) to be running"
    )
    async def test_superuser_can_assemble_context(self, e2e_client, e2e_superuser):
        """Test that superuser can assemble context."""
        response = await e2e_client.post(
            "/api/v1/context/assemble",
            headers=self._bearer(e2e_superuser["tokens"]),
            json={
                # Random suffixes keep repeated runs isolated from each other.
                "project_id": f"test-project-{uuid4().hex[:8]}",
                "agent_id": f"test-agent-{uuid4().hex[:8]}",
                "query": "What is the status of the project?",
                "model": "claude-3-sonnet",
                "system_prompt": "You are a helpful assistant.",
                "compress": True,
                "use_cache": False,
            },
        )

        assert response.status_code == 200, f"Failed: {response.text}"
        data = response.json()

        # The assembled-context response must carry all result fields.
        for key in (
            "content",
            "total_tokens",
            "context_count",
            "budget_used_percent",
            "metadata",
        ):
            assert key in data

    async def test_superuser_can_invalidate_cache(self, e2e_client, e2e_superuser):
        """Test that superuser can invalidate cache."""
        response = await e2e_client.post(
            "/api/v1/context/cache/invalidate",
            headers=self._bearer(e2e_superuser["tokens"]),
            params={"project_id": "test-project"},
        )

        # 204: invalidation succeeded with no response body.
        assert response.status_code == 204

    async def test_regular_user_cannot_access_admin_operations(self, e2e_client):
        """Test that regular (non-superuser) cannot access admin operations."""
        email = f"regular-{uuid4().hex[:8]}@example.com"
        password = "RegularUser123!"

        # Register a regular (non-superuser) account and verify it worked.
        register_resp = await e2e_client.post(
            "/api/v1/auth/register",
            json={
                "email": email,
                "password": password,
                "first_name": "Regular",
                "last_name": "User",
            },
        )
        assert register_resp.status_code in (
            200,
            201,
        ), f"Register failed: {register_resp.text}"

        # Login and verify tokens were actually issued.
        login_resp = await e2e_client.post(
            "/api/v1/auth/login",
            json={"email": email, "password": password},
        )
        assert login_resp.status_code == 200, f"Login failed: {login_resp.text}"
        tokens = login_resp.json()

        # The admin-only endpoint must be forbidden for non-superusers.
        response = await e2e_client.get(
            "/api/v1/context/stats",
            headers=self._bearer(tokens),
        )
        assert response.status_code == 403
class TestMCPInputValidation:
    """Test input validation for MCP endpoints."""

    async def test_server_name_max_length(self, e2e_client):
        """Server names beyond the 64-character limit are rejected."""
        oversized_name = "a" * 100  # well past the 64-char limit
        response = await e2e_client.get(f"/api/v1/mcp/servers/{oversized_name}/tools")
        assert response.status_code == 422

    async def test_server_name_invalid_characters(self, e2e_client):
        """Server names containing disallowed punctuation are rejected."""
        bad_name = "server@name!invalid"
        response = await e2e_client.get(f"/api/v1/mcp/servers/{bad_name}/tools")
        assert response.status_code == 422

    async def test_token_count_empty_content(self, e2e_client):
        """Empty content either counts as zero tokens or is rejected."""
        response = await e2e_client.post(
            "/api/v1/context/count-tokens",
            json={"content": ""},
        )

        if response.status_code == 200:
            # Accepted: an empty string contains no tokens.
            assert response.json()["token_count"] == 0
        else:
            # Rejected: it must be a validation error, nothing else.
            assert response.status_code == 422
class TestMCPWorkflowIntegration:
    """Test complete MCP workflows end-to-end."""

    async def test_discovery_to_budget_workflow(self, e2e_client):
        """Test complete workflow: discover servers -> check budget -> ready for use."""
        # Step 1: at least one MCP server must be discoverable.
        servers_resp = await e2e_client.get("/api/v1/mcp/servers")
        assert servers_resp.status_code == 200
        discovered = servers_resp.json()["servers"]
        assert len(discovered) > 0

        # Step 2: the context engine reports itself healthy.
        health_resp = await e2e_client.get("/api/v1/context/health")
        assert health_resp.status_code == 200
        health = health_resp.json()
        assert health["status"] == "healthy"

        # Step 3: a token budget is available for the target model.
        budget_resp = await e2e_client.get("/api/v1/context/budget/claude-3-sonnet")
        assert budget_resp.status_code == 200
        budget = budget_resp.json()

        # Step 4: the system is ready for context assembly.
        assert budget["total_tokens"] > 0
        assert health["mcp_connected"] is True

    @pytest.mark.skip(
        reason="Requires MCP servers (llm-gateway, knowledge-base) to be running"
    )
    async def test_full_context_assembly_workflow(self, e2e_client, e2e_superuser):
        """Test complete context assembly workflow with superuser."""
        project_id = f"e2e-project-{uuid4().hex[:8]}"
        agent_id = f"e2e-agent-{uuid4().hex[:8]}"
        auth_headers = {
            "Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
        }

        # Step 1: the budget endpoint answers for the model and parses cleanly.
        budget_resp = await e2e_client.get("/api/v1/context/budget/claude-3-sonnet")
        assert budget_resp.status_code == 200
        _ = budget_resp.json()

        # Step 2: token counting works on sample content.
        count_resp = await e2e_client.post(
            "/api/v1/context/count-tokens",
            json={"content": "This is a test message for context assembly."},
        )
        assert count_resp.status_code == 200
        assert count_resp.json()["token_count"] > 0

        # Step 3: assemble a context for the agent.
        assemble_resp = await e2e_client.post(
            "/api/v1/context/assemble",
            headers=auth_headers,
            json={
                "project_id": project_id,
                "agent_id": agent_id,
                "query": "Summarize the current project status",
                "model": "claude-3-sonnet",
                "system_prompt": "You are a project management assistant.",
                "task_description": "Generate a status report",
                "conversation_history": [
                    {"role": "user", "content": "What's the project status?"},
                    {
                        "role": "assistant",
                        "content": "Let me check the current status.",
                    },
                ],
                "compress": True,
                "use_cache": False,
            },
        )
        assert assemble_resp.status_code == 200
        assembled = assemble_resp.json()

        # Step 4: the assembly consumed some budget without exceeding it.
        assert assembled["total_tokens"] > 0
        assert assembled["context_count"] > 0
        assert 0 < assembled["budget_used_percent"] <= 100

        # Step 5: stats remain readable after the operation.
        stats_resp = await e2e_client.get(
            "/api/v1/context/stats",
            headers=auth_headers,
        )
        assert stats_resp.status_code == 200
684
backend/tests/e2e/test_project_workflows.py
Normal file
684
backend/tests/e2e/test_project_workflows.py
Normal file
@@ -0,0 +1,684 @@
|
||||
"""
|
||||
Project and Agent E2E Workflow Tests.
|
||||
|
||||
Tests complete project management workflows with real PostgreSQL:
|
||||
- Project CRUD and lifecycle management
|
||||
- Agent spawning and lifecycle
|
||||
- Issue management within projects
|
||||
- Sprint planning and execution
|
||||
|
||||
Usage:
|
||||
make test-e2e # Run all E2E tests
|
||||
"""
|
||||
|
||||
from datetime import date, timedelta
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
# Every test in this module is an asyncio e2e test against real PostgreSQL.
pytestmark = [pytest.mark.e2e, pytest.mark.postgres, pytest.mark.asyncio]
|
||||
|
||||
|
||||
class TestProjectCRUDWorkflows:
    """Test complete project CRUD workflows.

    The register/login/create-project boilerplate that was duplicated in
    every test is factored into private helpers, and each setup call now
    asserts its HTTP status so a broken fixture fails with a clear
    message instead of a later KeyError on the token payload.
    """

    @staticmethod
    def _bearer(tokens):
        """Build an Authorization header dict from a login token payload."""
        return {"Authorization": f"Bearer {tokens['access_token']}"}

    async def _register_and_login(self, e2e_client, prefix):
        """Register a unique user and return the login token payload.

        A fresh random email is derived from *prefix* so concurrent or
        repeated runs against the same database never collide.
        """
        email = f"{prefix}-{uuid4().hex[:8]}@example.com"
        password = "SecurePass123!"

        register_resp = await e2e_client.post(
            "/api/v1/auth/register",
            json={
                "email": email,
                "password": password,
                "first_name": prefix.title(),
                "last_name": "Test",
            },
        )
        assert register_resp.status_code in (
            200,
            201,
        ), f"Register failed: {register_resp.text}"

        login_resp = await e2e_client.post(
            "/api/v1/auth/login",
            json={"email": email, "password": password},
        )
        assert login_resp.status_code == 200, f"Login failed: {login_resp.text}"
        return login_resp.json()

    async def _create_project(self, e2e_client, tokens, name, slug, **extra_fields):
        """Create a project owned by *tokens*' user and return its JSON body."""
        create_resp = await e2e_client.post(
            "/api/v1/projects",
            headers=self._bearer(tokens),
            json={"name": name, "slug": slug, **extra_fields},
        )
        assert create_resp.status_code == 201, f"Failed: {create_resp.text}"
        return create_resp.json()

    async def test_create_project_workflow(self, e2e_client):
        """Test creating a project as authenticated user."""
        tokens = await self._register_and_login(e2e_client, "project-owner")

        project_slug = f"test-project-{uuid4().hex[:8]}"
        project = await self._create_project(
            e2e_client,
            tokens,
            "E2E Test Project",
            project_slug,
            description="A project for E2E testing",
            autonomy_level="milestone",
        )

        assert project["name"] == "E2E Test Project"
        assert project["slug"] == project_slug
        # New projects start active with no agents or issues attached.
        assert project["status"] == "active"
        assert project["agent_count"] == 0
        assert project["issue_count"] == 0

    async def test_list_projects_only_shows_owned(self, e2e_client):
        """Test that users only see their own projects."""
        # Two independent users, each with one project of their own.
        users = []
        for i in range(2):
            tokens = await self._register_and_login(e2e_client, f"user-{i}")
            project_slug = f"user{i}-project-{uuid4().hex[:8]}"
            await self._create_project(
                e2e_client, tokens, f"User {i} Project", project_slug
            )
            users.append({"tokens": tokens, "slug": project_slug})

        # User 0's listing must contain only their own project.
        list_resp = await e2e_client.get(
            "/api/v1/projects",
            headers=self._bearer(users[0]["tokens"]),
        )
        assert list_resp.status_code == 200
        slugs = [p["slug"] for p in list_resp.json()["data"]]
        assert users[0]["slug"] in slugs
        assert users[1]["slug"] not in slugs

    async def test_project_lifecycle_pause_resume(self, e2e_client):
        """Test pausing and resuming a project."""
        tokens = await self._register_and_login(e2e_client, "lifecycle")
        project = await self._create_project(
            e2e_client,
            tokens,
            "Lifecycle Project",
            f"lifecycle-project-{uuid4().hex[:8]}",
        )
        project_id = project["id"]

        # Pause: active -> paused.
        pause_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/pause",
            headers=self._bearer(tokens),
        )
        assert pause_resp.status_code == 200
        assert pause_resp.json()["status"] == "paused"

        # Resume: paused -> active.
        resume_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/resume",
            headers=self._bearer(tokens),
        )
        assert resume_resp.status_code == 200
        assert resume_resp.json()["status"] == "active"

    async def test_project_archive(self, e2e_client):
        """Test archiving a project (soft delete)."""
        tokens = await self._register_and_login(e2e_client, "archive")
        project = await self._create_project(
            e2e_client,
            tokens,
            "Archive Project",
            f"archive-project-{uuid4().hex[:8]}",
        )
        project_id = project["id"]

        # Archiving is exposed as DELETE but only soft-deletes the project.
        archive_resp = await e2e_client.delete(
            f"/api/v1/projects/{project_id}",
            headers=self._bearer(tokens),
        )
        assert archive_resp.status_code == 200
        assert archive_resp.json()["success"] is True

        # The project remains readable, now in the archived state.
        get_resp = await e2e_client.get(
            f"/api/v1/projects/{project_id}",
            headers=self._bearer(tokens),
        )
        assert get_resp.status_code == 200
        assert get_resp.json()["status"] == "archived"
class TestIssueWorkflows:
    """Test issue management workflows within projects.

    The register/login/create-project setup that was duplicated in every
    test is factored into `_setup_project`, and issue creation goes
    through `_create_issue`; every setup request now asserts its status
    code so a broken fixture fails loudly instead of via a KeyError.
    """

    @staticmethod
    def _bearer(tokens):
        """Build an Authorization header dict from a login token payload."""
        return {"Authorization": f"Bearer {tokens['access_token']}"}

    async def _setup_project(self, e2e_client, prefix, project_name):
        """Register a fresh user, create *project_name*, return (tokens, project_id).

        Emails and slugs get random suffixes derived from *prefix* so
        repeated runs against the same database never collide.
        """
        email = f"{prefix}-{uuid4().hex[:8]}@example.com"
        password = "SecurePass123!"

        register_resp = await e2e_client.post(
            "/api/v1/auth/register",
            json={
                "email": email,
                "password": password,
                "first_name": prefix.title(),
                "last_name": "Tester",
            },
        )
        assert register_resp.status_code in (
            200,
            201,
        ), f"Register failed: {register_resp.text}"

        login_resp = await e2e_client.post(
            "/api/v1/auth/login",
            json={"email": email, "password": password},
        )
        assert login_resp.status_code == 200, f"Login failed: {login_resp.text}"
        tokens = login_resp.json()

        create_resp = await e2e_client.post(
            "/api/v1/projects",
            headers=self._bearer(tokens),
            json={
                "name": project_name,
                "slug": f"{prefix.replace('-test', '')}-project-{uuid4().hex[:8]}",
            },
        )
        assert create_resp.status_code == 201, f"Failed: {create_resp.text}"
        return tokens, create_resp.json()["id"]

    async def _create_issue(self, e2e_client, tokens, project_id, **fields):
        """Create an issue in *project_id* and return its JSON body."""
        issue_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/issues",
            headers=self._bearer(tokens),
            json={"project_id": project_id, **fields},
        )
        assert issue_resp.status_code == 201, f"Failed: {issue_resp.text}"
        return issue_resp.json()

    async def test_create_and_list_issues(self, e2e_client):
        """Test creating and listing issues in a project."""
        tokens, project_id = await self._setup_project(
            e2e_client, "issue-test", "Issue Test Project"
        )

        # Create one issue per priority level.
        for i, priority in enumerate(["low", "medium", "high"], start=1):
            await self._create_issue(
                e2e_client,
                tokens,
                project_id,
                title=f"Test Issue {i}",
                body=f"Description for issue {i}",
                priority=priority,
            )

        # All three issues must come back in the listing.
        list_resp = await e2e_client.get(
            f"/api/v1/projects/{project_id}/issues",
            headers=self._bearer(tokens),
        )
        assert list_resp.status_code == 200
        assert list_resp.json()["pagination"]["total"] == 3

    async def test_issue_status_transitions(self, e2e_client):
        """Test issue status workflow transitions."""
        tokens, project_id = await self._setup_project(
            e2e_client, "status-test", "Status Test Project"
        )

        issue = await self._create_issue(
            e2e_client,
            tokens,
            project_id,
            title="Status Workflow Issue",
            body="Testing status transitions",
        )
        issue_id = issue["id"]
        # Newly created issues start in the "open" state.
        assert issue["status"] == "open"

        # Walk the issue through the full workflow in order.
        for new_status in ["in_progress", "in_review", "closed"]:
            update_resp = await e2e_client.patch(
                f"/api/v1/projects/{project_id}/issues/{issue_id}",
                headers=self._bearer(tokens),
                json={"status": new_status},
            )
            assert update_resp.status_code == 200, f"Failed: {update_resp.text}"
            assert update_resp.json()["status"] == new_status

    async def test_issue_filtering(self, e2e_client):
        """Test issue filtering by status and priority."""
        tokens, project_id = await self._setup_project(
            e2e_client, "filter-test", "Filter Test Project"
        )

        # One issue at each priority level.
        for priority in ["low", "medium", "high"]:
            await self._create_issue(
                e2e_client,
                tokens,
                project_id,
                title=f"{priority.title()} Priority Issue",
                priority=priority,
            )

        # Only the single high-priority issue should match the filter.
        filter_resp = await e2e_client.get(
            f"/api/v1/projects/{project_id}/issues",
            headers=self._bearer(tokens),
            params={"priority": "high"},
        )
        assert filter_resp.status_code == 200
        data = filter_resp.json()
        assert data["pagination"]["total"] == 1
        assert data["data"][0]["priority"] == "high"
class TestSprintWorkflows:
    """Test sprint planning and execution workflows.

    The duplicated register/login/create-project and sprint-creation
    boilerplate is factored into helpers, and every setup request now
    asserts its status code so fixture failures surface immediately.
    """

    @staticmethod
    def _bearer(tokens):
        """Build an Authorization header dict from a login token payload."""
        return {"Authorization": f"Bearer {tokens['access_token']}"}

    async def _setup_project(self, e2e_client, prefix, project_name):
        """Register a fresh user, create *project_name*, return (tokens, project_id)."""
        email = f"{prefix}-{uuid4().hex[:8]}@example.com"
        password = "SecurePass123!"

        register_resp = await e2e_client.post(
            "/api/v1/auth/register",
            json={
                "email": email,
                "password": password,
                "first_name": prefix.title(),
                "last_name": "Tester",
            },
        )
        assert register_resp.status_code in (
            200,
            201,
        ), f"Register failed: {register_resp.text}"

        login_resp = await e2e_client.post(
            "/api/v1/auth/login",
            json={"email": email, "password": password},
        )
        assert login_resp.status_code == 200, f"Login failed: {login_resp.text}"
        tokens = login_resp.json()

        create_resp = await e2e_client.post(
            "/api/v1/projects",
            headers=self._bearer(tokens),
            json={
                "name": project_name,
                "slug": f"{prefix}-project-{uuid4().hex[:8]}",
            },
        )
        assert create_resp.status_code == 201, f"Failed: {create_resp.text}"
        return tokens, create_resp.json()["id"]

    async def _create_sprint(self, e2e_client, tokens, project_id, **fields):
        """Create a 14-day "Sprint 1" starting today and return its JSON body."""
        today = date.today()
        sprint_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/sprints",
            headers=self._bearer(tokens),
            json={
                "project_id": project_id,
                "name": "Sprint 1",
                "number": 1,
                "start_date": today.isoformat(),
                "end_date": (today + timedelta(days=14)).isoformat(),
                **fields,
            },
        )
        assert sprint_resp.status_code == 201, f"Failed: {sprint_resp.text}"
        return sprint_resp.json()

    async def test_sprint_lifecycle(self, e2e_client):
        """Test complete sprint lifecycle: plan -> start -> complete."""
        tokens, project_id = await self._setup_project(
            e2e_client, "sprint-test", "Sprint Test Project"
        )

        sprint = await self._create_sprint(
            e2e_client, tokens, project_id, goal="Complete initial features"
        )
        sprint_id = sprint["id"]
        # A newly created sprint is only planned, not yet running.
        assert sprint["status"] == "planned"

        # Start: planned -> active.
        start_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/sprints/{sprint_id}/start",
            headers=self._bearer(tokens),
        )
        assert start_resp.status_code == 200, f"Failed: {start_resp.text}"
        assert start_resp.json()["status"] == "active"

        # Complete: active -> completed.
        complete_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/sprints/{sprint_id}/complete",
            headers=self._bearer(tokens),
        )
        assert complete_resp.status_code == 200, f"Failed: {complete_resp.text}"
        assert complete_resp.json()["status"] == "completed"

    async def test_add_issues_to_sprint(self, e2e_client):
        """Test adding issues to a sprint."""
        tokens, project_id = await self._setup_project(
            e2e_client, "sprint-issues", "Sprint Issues Project"
        )

        sprint = await self._create_sprint(e2e_client, tokens, project_id)
        sprint_id = sprint["id"]

        # Create an estimated issue to pull into the sprint.
        issue_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/issues",
            headers=self._bearer(tokens),
            json={
                "project_id": project_id,
                "title": "Sprint Issue",
                "story_points": 5,
            },
        )
        assert issue_resp.status_code == 201, f"Failed: {issue_resp.text}"
        issue_id = issue_resp.json()["id"]

        # Attach the issue to the sprint.
        add_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/sprints/{sprint_id}/issues",
            headers=self._bearer(tokens),
            params={"issue_id": issue_id},
        )
        assert add_resp.status_code == 200, f"Failed: {add_resp.text}"

        # Re-read the issue and verify it is now linked to the sprint.
        issue_check = await e2e_client.get(
            f"/api/v1/projects/{project_id}/issues/{issue_id}",
            headers=self._bearer(tokens),
        )
        assert issue_check.status_code == 200
        assert issue_check.json()["sprint_id"] == sprint_id
class TestCrossEntityValidation:
|
||||
"""Test validation across related entities."""
|
||||
|
||||
async def test_cannot_access_other_users_project(self, e2e_client):
|
||||
"""Test that users cannot access projects they don't own."""
|
||||
# Create two users
|
||||
owner_email = f"owner-{uuid4().hex[:8]}@example.com"
|
||||
other_email = f"other-{uuid4().hex[:8]}@example.com"
|
||||
password = "SecurePass123!"
|
||||
|
||||
# Register owner
|
||||
await e2e_client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": owner_email,
|
||||
"password": password,
|
||||
"first_name": "Owner",
|
||||
"last_name": "User",
|
||||
},
|
||||
)
|
||||
owner_tokens = (
|
||||
await e2e_client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": owner_email, "password": password},
|
||||
)
|
||||
).json()
|
||||
|
||||
# Register other user
|
||||
await e2e_client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": other_email,
|
||||
"password": password,
|
||||
"first_name": "Other",
|
||||
"last_name": "User",
|
||||
},
|
||||
)
|
||||
other_tokens = (
|
||||
await e2e_client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": other_email, "password": password},
|
||||
)
|
||||
).json()
|
||||
|
||||
# Owner creates project
|
||||
project_slug = f"private-project-{uuid4().hex[:8]}"
|
||||
create_resp = await e2e_client.post(
|
||||
"/api/v1/projects",
|
||||
headers={"Authorization": f"Bearer {owner_tokens['access_token']}"},
|
||||
json={"name": "Private Project", "slug": project_slug},
|
||||
)
|
||||
project = create_resp.json()
|
||||
project_id = project["id"]
|
||||
|
||||
# Other user tries to access
|
||||
access_resp = await e2e_client.get(
|
||||
f"/api/v1/projects/{project_id}",
|
||||
headers={"Authorization": f"Bearer {other_tokens['access_token']}"},
|
||||
)
|
||||
assert access_resp.status_code == 403
|
||||
|
||||
async def test_duplicate_project_slug_rejected(self, e2e_client):
|
||||
"""Test that duplicate project slugs are rejected."""
|
||||
email = f"dup-test-{uuid4().hex[:8]}@example.com"
|
||||
password = "SecurePass123!"
|
||||
|
||||
await e2e_client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": email,
|
||||
"password": password,
|
||||
"first_name": "Dup",
|
||||
"last_name": "Tester",
|
||||
},
|
||||
)
|
||||
tokens = (
|
||||
await e2e_client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": email, "password": password},
|
||||
)
|
||||
).json()
|
||||
|
||||
slug = f"unique-slug-{uuid4().hex[:8]}"
|
||||
|
||||
# First creation should succeed
|
||||
resp1 = await e2e_client.post(
|
||||
"/api/v1/projects",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
json={"name": "First Project", "slug": slug},
|
||||
)
|
||||
assert resp1.status_code == 201
|
||||
|
||||
# Second creation with same slug should fail
|
||||
resp2 = await e2e_client.post(
|
||||
"/api/v1/projects",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
json={"name": "Second Project", "slug": slug},
|
||||
)
|
||||
assert resp2.status_code == 409 # Conflict
|
||||
|
||||
|
||||
class TestIssueStats:
|
||||
"""Test issue statistics endpoints."""
|
||||
|
||||
async def test_issue_stats_aggregation(self, e2e_client):
|
||||
"""Test that issue stats are correctly aggregated."""
|
||||
email = f"stats-test-{uuid4().hex[:8]}@example.com"
|
||||
password = "SecurePass123!"
|
||||
|
||||
await e2e_client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": email,
|
||||
"password": password,
|
||||
"first_name": "Stats",
|
||||
"last_name": "Tester",
|
||||
},
|
||||
)
|
||||
tokens = (
|
||||
await e2e_client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": email, "password": password},
|
||||
)
|
||||
).json()
|
||||
|
||||
project_slug = f"stats-project-{uuid4().hex[:8]}"
|
||||
create_resp = await e2e_client.post(
|
||||
"/api/v1/projects",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
json={"name": "Stats Project", "slug": project_slug},
|
||||
)
|
||||
project = create_resp.json()
|
||||
project_id = project["id"]
|
||||
|
||||
# Create issues with different priorities and story points
|
||||
await e2e_client.post(
|
||||
f"/api/v1/projects/{project_id}/issues",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
json={
|
||||
"project_id": project_id,
|
||||
"title": "High Priority",
|
||||
"priority": "high",
|
||||
"story_points": 8,
|
||||
},
|
||||
)
|
||||
await e2e_client.post(
|
||||
f"/api/v1/projects/{project_id}/issues",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
json={
|
||||
"project_id": project_id,
|
||||
"title": "Low Priority",
|
||||
"priority": "low",
|
||||
"story_points": 2,
|
||||
},
|
||||
)
|
||||
|
||||
# Get stats
|
||||
stats_resp = await e2e_client.get(
|
||||
f"/api/v1/projects/{project_id}/issues/stats",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
)
|
||||
assert stats_resp.status_code == 200
|
||||
stats = stats_resp.json()
|
||||
assert stats["total"] == 2
|
||||
assert stats["total_story_points"] == 10
|
||||
1
backend/tests/integration/__init__.py
Normal file
1
backend/tests/integration/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Integration tests that require the full stack to be running."""
|
||||
322
backend/tests/integration/test_mcp_integration.py
Normal file
322
backend/tests/integration/test_mcp_integration.py
Normal file
@@ -0,0 +1,322 @@
|
||||
"""
|
||||
Integration tests for MCP server connectivity.
|
||||
|
||||
These tests require the full stack to be running:
|
||||
- docker compose -f docker-compose.dev.yml up
|
||||
|
||||
Run with:
|
||||
pytest tests/integration/ -v --integration
|
||||
|
||||
Or skip with:
|
||||
pytest tests/ -v --ignore=tests/integration/
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
# Skip all tests in this module if not running integration tests
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.getenv("RUN_INTEGRATION_TESTS", "false").lower() != "true",
|
||||
reason="Integration tests require RUN_INTEGRATION_TESTS=true and running stack",
|
||||
)
|
||||
|
||||
|
||||
# Configuration from environment
|
||||
BACKEND_URL = os.getenv("BACKEND_URL", "http://localhost:8000")
|
||||
LLM_GATEWAY_URL = os.getenv("LLM_GATEWAY_URL", "http://localhost:8001")
|
||||
KNOWLEDGE_BASE_URL = os.getenv("KNOWLEDGE_BASE_URL", "http://localhost:8002")
|
||||
|
||||
|
||||
class TestMCPServerHealth:
|
||||
"""Test that MCP servers are healthy and reachable."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_gateway_health(self) -> None:
|
||||
"""Test LLM Gateway health endpoint."""
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(f"{LLM_GATEWAY_URL}/health", timeout=10.0)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data.get("status") == "healthy" or data.get("healthy") is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_knowledge_base_health(self) -> None:
|
||||
"""Test Knowledge Base health endpoint."""
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(f"{KNOWLEDGE_BASE_URL}/health", timeout=10.0)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data.get("status") == "healthy" or data.get("healthy") is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backend_health(self) -> None:
|
||||
"""Test Backend health endpoint."""
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(f"{BACKEND_URL}/health", timeout=10.0)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
class TestMCPClientManagerIntegration:
|
||||
"""Test MCPClientManager can connect to real MCP servers."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_servers_list(self) -> None:
|
||||
"""Test that backend can list MCP servers via API."""
|
||||
async with httpx.AsyncClient() as client:
|
||||
# This endpoint lists configured MCP servers
|
||||
response = await client.get(
|
||||
f"{BACKEND_URL}/api/v1/mcp/servers",
|
||||
timeout=10.0,
|
||||
)
|
||||
# Should return 200 or 401 (if auth required)
|
||||
assert response.status_code in [200, 401, 403]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_health_check_endpoint(self) -> None:
|
||||
"""Test backend's MCP health check endpoint."""
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{BACKEND_URL}/api/v1/mcp/health",
|
||||
timeout=30.0, # MCP health checks can take time
|
||||
)
|
||||
# Should return 200 or 401 (if auth required)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
# Check structure
|
||||
assert "servers" in data or "healthy" in data
|
||||
|
||||
|
||||
class TestLLMGatewayIntegration:
|
||||
"""Test LLM Gateway MCP server functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_models(self) -> None:
|
||||
"""Test that LLM Gateway can list available models."""
|
||||
async with httpx.AsyncClient() as client:
|
||||
# MCP servers use JSON-RPC 2.0 protocol at /mcp endpoint
|
||||
response = await client.post(
|
||||
f"{LLM_GATEWAY_URL}/mcp",
|
||||
json={
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "tools/list",
|
||||
"params": {},
|
||||
},
|
||||
timeout=10.0,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Should have tools listed
|
||||
assert "result" in data or "error" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_tokens(self) -> None:
|
||||
"""Test token counting functionality."""
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{LLM_GATEWAY_URL}/mcp",
|
||||
json={
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": "count_tokens",
|
||||
"arguments": {
|
||||
"project_id": "test-project",
|
||||
"agent_id": "test-agent",
|
||||
"text": "Hello, world!",
|
||||
},
|
||||
},
|
||||
},
|
||||
timeout=10.0,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Check for result or error
|
||||
if "result" in data:
|
||||
assert "content" in data["result"] or "token_count" in str(
|
||||
data["result"]
|
||||
)
|
||||
|
||||
|
||||
class TestKnowledgeBaseIntegration:
|
||||
"""Test Knowledge Base MCP server functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tools(self) -> None:
|
||||
"""Test that Knowledge Base can list available tools."""
|
||||
async with httpx.AsyncClient() as client:
|
||||
# Knowledge Base uses GET /mcp/tools for listing
|
||||
response = await client.get(
|
||||
f"{KNOWLEDGE_BASE_URL}/mcp/tools",
|
||||
timeout=10.0,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "tools" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_knowledge_empty(self) -> None:
|
||||
"""Test search on empty knowledge base."""
|
||||
async with httpx.AsyncClient() as client:
|
||||
# Knowledge Base uses direct tool name as method
|
||||
response = await client.post(
|
||||
f"{KNOWLEDGE_BASE_URL}/mcp",
|
||||
json={
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "search_knowledge",
|
||||
"params": {
|
||||
"project_id": "test-project",
|
||||
"agent_id": "test-agent",
|
||||
"query": "test query",
|
||||
"limit": 5,
|
||||
},
|
||||
},
|
||||
timeout=10.0,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Should return empty results or error for no collection
|
||||
assert "result" in data or "error" in data
|
||||
|
||||
|
||||
class TestEndToEndMCPFlow:
|
||||
"""End-to-end tests for MCP integration flow."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_mcp_discovery_flow(self) -> None:
|
||||
"""Test the full flow of discovering and listing MCP tools."""
|
||||
async with httpx.AsyncClient() as client:
|
||||
# 1. Check backend health
|
||||
health = await client.get(f"{BACKEND_URL}/health", timeout=10.0)
|
||||
assert health.status_code == 200
|
||||
|
||||
# 2. Check LLM Gateway health
|
||||
llm_health = await client.get(f"{LLM_GATEWAY_URL}/health", timeout=10.0)
|
||||
assert llm_health.status_code == 200
|
||||
|
||||
# 3. Check Knowledge Base health
|
||||
kb_health = await client.get(f"{KNOWLEDGE_BASE_URL}/health", timeout=10.0)
|
||||
assert kb_health.status_code == 200
|
||||
|
||||
# 4. List tools from LLM Gateway (uses JSON-RPC at /mcp)
|
||||
llm_tools = await client.post(
|
||||
f"{LLM_GATEWAY_URL}/mcp",
|
||||
json={"jsonrpc": "2.0", "id": 1, "method": "tools/list", "params": {}},
|
||||
timeout=10.0,
|
||||
)
|
||||
assert llm_tools.status_code == 200
|
||||
|
||||
# 5. List tools from Knowledge Base (uses GET /mcp/tools)
|
||||
kb_tools = await client.get(
|
||||
f"{KNOWLEDGE_BASE_URL}/mcp/tools",
|
||||
timeout=10.0,
|
||||
)
|
||||
assert kb_tools.status_code == 200
|
||||
|
||||
|
||||
class TestContextEngineIntegration:
|
||||
"""Test Context Engine integration with MCP servers."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_health_endpoint(self) -> None:
|
||||
"""Test context engine health endpoint."""
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{BACKEND_URL}/api/v1/context/health",
|
||||
timeout=10.0,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data.get("status") == "healthy"
|
||||
assert "mcp_connected" in data
|
||||
assert "cache_enabled" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_budget_endpoint(self) -> None:
|
||||
"""Test token budget endpoint."""
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{BACKEND_URL}/api/v1/context/budget/claude-3-sonnet",
|
||||
timeout=10.0,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "total_tokens" in data
|
||||
assert "system_tokens" in data
|
||||
assert data.get("model") == "claude-3-sonnet"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_assembly_requires_auth(self) -> None:
|
||||
"""Test that context assembly requires authentication."""
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{BACKEND_URL}/api/v1/context/assemble",
|
||||
json={
|
||||
"project_id": "test-project",
|
||||
"agent_id": "test-agent",
|
||||
"query": "test query",
|
||||
"model": "claude-3-sonnet",
|
||||
},
|
||||
timeout=10.0,
|
||||
)
|
||||
# Should require auth
|
||||
assert response.status_code in [401, 403]
|
||||
|
||||
|
||||
def run_quick_health_check() -> dict[str, Any]:
|
||||
"""
|
||||
Quick synchronous health check for all services.
|
||||
Can be run standalone to verify the stack is up.
|
||||
"""
|
||||
import httpx
|
||||
|
||||
results: dict[str, Any] = {
|
||||
"backend": False,
|
||||
"llm_gateway": False,
|
||||
"knowledge_base": False,
|
||||
}
|
||||
|
||||
try:
|
||||
with httpx.Client(timeout=5.0) as client:
|
||||
try:
|
||||
r = client.get(f"{BACKEND_URL}/health")
|
||||
results["backend"] = r.status_code == 200
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
r = client.get(f"{LLM_GATEWAY_URL}/health")
|
||||
results["llm_gateway"] = r.status_code == 200
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
r = client.get(f"{KNOWLEDGE_BASE_URL}/health")
|
||||
results["knowledge_base"] = r.status_code == 200
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Checking service health...")
|
||||
results = run_quick_health_check()
|
||||
for service, healthy in results.items():
|
||||
status = "OK" if healthy else "FAILED"
|
||||
print(f" {service}: {status}")
|
||||
|
||||
all_healthy = all(results.values())
|
||||
if all_healthy:
|
||||
print("\nAll services healthy! Run integration tests with:")
|
||||
print(" RUN_INTEGRATION_TESTS=true pytest tests/integration/ -v")
|
||||
else:
|
||||
print("\nSome services are not healthy. Start the stack with:")
|
||||
print(" make dev")
|
||||
@@ -72,7 +72,7 @@ class TestContextSettings:
|
||||
"""Test performance settings."""
|
||||
settings = ContextSettings()
|
||||
|
||||
assert settings.max_assembly_time_ms == 100
|
||||
assert settings.max_assembly_time_ms == 2000
|
||||
assert settings.parallel_scoring is True
|
||||
assert settings.max_parallel_scores == 10
|
||||
|
||||
|
||||
@@ -758,3 +758,136 @@ class TestBaseScorer:
|
||||
# Boundaries
|
||||
assert scorer.normalize_score(0.0) == 0.0
|
||||
assert scorer.normalize_score(1.0) == 1.0
|
||||
|
||||
|
||||
class TestCompositeScorerEdgeCases:
|
||||
"""Tests for CompositeScorer edge cases and lock management."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_with_zero_weights(self) -> None:
|
||||
"""Test scoring when all weights are zero."""
|
||||
scorer = CompositeScorer(
|
||||
relevance_weight=0.0,
|
||||
recency_weight=0.0,
|
||||
priority_weight=0.0,
|
||||
)
|
||||
|
||||
context = KnowledgeContext(
|
||||
content="Test content",
|
||||
source="docs",
|
||||
relevance_score=0.8,
|
||||
)
|
||||
|
||||
# Should return 0.0 when total weight is 0
|
||||
score = await scorer.score(context, "test query")
|
||||
assert score == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_batch_sequential(self) -> None:
|
||||
"""Test batch scoring in sequential mode (parallel=False)."""
|
||||
scorer = CompositeScorer()
|
||||
|
||||
contexts = [
|
||||
KnowledgeContext(
|
||||
content="Content 1",
|
||||
source="docs",
|
||||
relevance_score=0.8,
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="Content 2",
|
||||
source="docs",
|
||||
relevance_score=0.5,
|
||||
),
|
||||
]
|
||||
|
||||
# Use parallel=False to cover the sequential path
|
||||
scored = await scorer.score_batch(contexts, "query", parallel=False)
|
||||
|
||||
assert len(scored) == 2
|
||||
assert scored[0].relevance_score == 0.8
|
||||
assert scored[1].relevance_score == 0.5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lock_fast_path_reuse(self) -> None:
|
||||
"""Test that existing locks are reused via fast path."""
|
||||
scorer = CompositeScorer()
|
||||
|
||||
context = KnowledgeContext(
|
||||
content="Test",
|
||||
source="docs",
|
||||
relevance_score=0.5,
|
||||
)
|
||||
|
||||
# First access creates the lock
|
||||
lock1 = await scorer._get_context_lock(context.id)
|
||||
|
||||
# Second access should hit the fast path (lock exists in dict)
|
||||
lock2 = await scorer._get_context_lock(context.id)
|
||||
|
||||
assert lock2 is lock1 # Same lock object returned
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lock_cleanup_when_limit_reached(self) -> None:
|
||||
"""Test that old locks are cleaned up when limit is reached."""
|
||||
import time
|
||||
|
||||
# Create scorer with very low max_locks to trigger cleanup
|
||||
scorer = CompositeScorer()
|
||||
scorer._max_locks = 3
|
||||
scorer._lock_ttl = 0.1 # 100ms TTL
|
||||
|
||||
# Create locks for several context IDs
|
||||
context_ids = [f"ctx-{i}" for i in range(5)]
|
||||
|
||||
# Get locks for first 3 contexts (fill up to limit)
|
||||
for ctx_id in context_ids[:3]:
|
||||
await scorer._get_context_lock(ctx_id)
|
||||
|
||||
# Wait for TTL to expire
|
||||
time.sleep(0.15)
|
||||
|
||||
# Getting a lock for a new context should trigger cleanup
|
||||
await scorer._get_context_lock(context_ids[3])
|
||||
|
||||
# Some old locks should have been cleaned up
|
||||
# The exact number depends on cleanup logic
|
||||
assert len(scorer._context_locks) <= scorer._max_locks + 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lock_cleanup_preserves_held_locks(self) -> None:
|
||||
"""Test that cleanup doesn't remove locks that are currently held."""
|
||||
import time
|
||||
|
||||
scorer = CompositeScorer()
|
||||
scorer._max_locks = 2
|
||||
scorer._lock_ttl = 0.05 # 50ms TTL
|
||||
|
||||
# Get and hold lock1
|
||||
lock1 = await scorer._get_context_lock("ctx-1")
|
||||
async with lock1:
|
||||
# While holding lock1, add more locks
|
||||
await scorer._get_context_lock("ctx-2")
|
||||
time.sleep(0.1) # Let TTL expire
|
||||
# Adding another should trigger cleanup
|
||||
await scorer._get_context_lock("ctx-3")
|
||||
|
||||
# lock1 should still exist (it's held)
|
||||
assert any(lock is lock1 for lock, _ in scorer._context_locks.values())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_lock_acquisition_double_check(self) -> None:
|
||||
"""Test that concurrent lock acquisition uses double-check pattern."""
|
||||
import asyncio
|
||||
|
||||
scorer = CompositeScorer()
|
||||
|
||||
context_id = "test-context-id"
|
||||
|
||||
# Simulate concurrent lock acquisition
|
||||
async def get_lock():
|
||||
return await scorer._get_context_lock(context_id)
|
||||
|
||||
locks = await asyncio.gather(*[get_lock() for _ in range(10)])
|
||||
|
||||
# All should get the same lock (double-check pattern ensures this)
|
||||
assert all(lock is locks[0] for lock in locks)
|
||||
|
||||
989
backend/tests/services/safety/test_audit.py
Normal file
989
backend/tests/services/safety/test_audit.py
Normal file
@@ -0,0 +1,989 @@
|
||||
"""
|
||||
Tests for Audit Logger.
|
||||
|
||||
Tests cover:
|
||||
- AuditLogger initialization and lifecycle
|
||||
- Event logging and hash chain
|
||||
- Query and filtering
|
||||
- Retention policy enforcement
|
||||
- Handler management
|
||||
- Singleton pattern
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from app.services.safety.audit.logger import (
|
||||
AuditLogger,
|
||||
get_audit_logger,
|
||||
reset_audit_logger,
|
||||
shutdown_audit_logger,
|
||||
)
|
||||
from app.services.safety.models import (
|
||||
ActionMetadata,
|
||||
ActionRequest,
|
||||
ActionType,
|
||||
AuditEventType,
|
||||
AutonomyLevel,
|
||||
SafetyDecision,
|
||||
)
|
||||
|
||||
|
||||
class TestAuditLoggerInit:
|
||||
"""Tests for AuditLogger initialization."""
|
||||
|
||||
def test_init_default_values(self):
|
||||
"""Test initialization with default values."""
|
||||
logger = AuditLogger()
|
||||
|
||||
assert logger._flush_interval == 10.0
|
||||
assert logger._enable_hash_chain is True
|
||||
assert logger._last_hash is None
|
||||
assert logger._running is False
|
||||
|
||||
def test_init_custom_values(self):
|
||||
"""Test initialization with custom values."""
|
||||
logger = AuditLogger(
|
||||
max_buffer_size=500,
|
||||
flush_interval_seconds=5.0,
|
||||
enable_hash_chain=False,
|
||||
)
|
||||
|
||||
assert logger._flush_interval == 5.0
|
||||
assert logger._enable_hash_chain is False
|
||||
|
||||
|
||||
class TestAuditLoggerLifecycle:
|
||||
"""Tests for AuditLogger start/stop."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_creates_flush_task(self):
|
||||
"""Test that start creates the periodic flush task."""
|
||||
logger = AuditLogger(flush_interval_seconds=1.0)
|
||||
|
||||
await logger.start()
|
||||
|
||||
assert logger._running is True
|
||||
assert logger._flush_task is not None
|
||||
|
||||
await logger.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_idempotent(self):
|
||||
"""Test that multiple starts don't create multiple tasks."""
|
||||
logger = AuditLogger()
|
||||
|
||||
await logger.start()
|
||||
task1 = logger._flush_task
|
||||
|
||||
await logger.start() # Second start
|
||||
task2 = logger._flush_task
|
||||
|
||||
assert task1 is task2
|
||||
|
||||
await logger.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_cancels_task_and_flushes(self):
|
||||
"""Test that stop cancels the task and flushes events."""
|
||||
logger = AuditLogger()
|
||||
|
||||
await logger.start()
|
||||
|
||||
# Add an event
|
||||
await logger.log(AuditEventType.ACTION_REQUESTED, agent_id="agent-1")
|
||||
|
||||
await logger.stop()
|
||||
|
||||
assert logger._running is False
|
||||
# Event should be flushed
|
||||
assert len(logger._persisted) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_without_start(self):
|
||||
"""Test stopping without starting doesn't error."""
|
||||
logger = AuditLogger()
|
||||
await logger.stop() # Should not raise
|
||||
|
||||
|
||||
class TestAuditLoggerLog:
|
||||
"""Tests for the log method."""
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def logger(self):
|
||||
"""Create a logger instance."""
|
||||
return AuditLogger(enable_hash_chain=True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_creates_event(self, logger):
|
||||
"""Test logging creates an event."""
|
||||
event = await logger.log(
|
||||
AuditEventType.ACTION_REQUESTED,
|
||||
agent_id="agent-1",
|
||||
project_id="proj-1",
|
||||
)
|
||||
|
||||
assert event.event_type == AuditEventType.ACTION_REQUESTED
|
||||
assert event.agent_id == "agent-1"
|
||||
assert event.project_id == "proj-1"
|
||||
assert event.id is not None
|
||||
assert event.timestamp is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_adds_hash_chain(self, logger):
|
||||
"""Test logging adds hash chain."""
|
||||
event = await logger.log(
|
||||
AuditEventType.ACTION_REQUESTED,
|
||||
agent_id="agent-1",
|
||||
)
|
||||
|
||||
assert "_hash" in event.details
|
||||
assert "_prev_hash" in event.details
|
||||
assert event.details["_prev_hash"] is None # First event
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_chain_links_events(self, logger):
|
||||
"""Test hash chain links events."""
|
||||
event1 = await logger.log(AuditEventType.ACTION_REQUESTED)
|
||||
event2 = await logger.log(AuditEventType.ACTION_EXECUTED)
|
||||
|
||||
assert event2.details["_prev_hash"] == event1.details["_hash"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_without_hash_chain(self):
|
||||
"""Test logging without hash chain."""
|
||||
logger = AuditLogger(enable_hash_chain=False)
|
||||
|
||||
event = await logger.log(AuditEventType.ACTION_REQUESTED)
|
||||
|
||||
assert "_hash" not in event.details
|
||||
assert "_prev_hash" not in event.details
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_with_all_fields(self, logger):
|
||||
"""Test logging with all optional fields."""
|
||||
event = await logger.log(
|
||||
AuditEventType.ACTION_EXECUTED,
|
||||
agent_id="agent-1",
|
||||
action_id="action-1",
|
||||
project_id="proj-1",
|
||||
session_id="sess-1",
|
||||
user_id="user-1",
|
||||
decision=SafetyDecision.ALLOW,
|
||||
details={"custom": "data"},
|
||||
correlation_id="corr-1",
|
||||
)
|
||||
|
||||
assert event.agent_id == "agent-1"
|
||||
assert event.action_id == "action-1"
|
||||
assert event.project_id == "proj-1"
|
||||
assert event.session_id == "sess-1"
|
||||
assert event.user_id == "user-1"
|
||||
assert event.decision == SafetyDecision.ALLOW
|
||||
assert event.details["custom"] == "data"
|
||||
assert event.correlation_id == "corr-1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_buffers_event(self, logger):
|
||||
"""Test logging adds event to buffer."""
|
||||
await logger.log(AuditEventType.ACTION_REQUESTED)
|
||||
|
||||
assert len(logger._buffer) == 1
|
||||
|
||||
|
||||
class TestAuditLoggerConvenienceMethods:
|
||||
"""Tests for convenience logging methods."""
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def logger(self):
|
||||
"""Create a logger instance."""
|
||||
return AuditLogger(enable_hash_chain=False)
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
def action(self):
|
||||
"""Create a test action request."""
|
||||
metadata = ActionMetadata(
|
||||
agent_id="agent-1",
|
||||
session_id="sess-1",
|
||||
project_id="proj-1",
|
||||
autonomy_level=AutonomyLevel.MILESTONE,
|
||||
user_id="user-1",
|
||||
correlation_id="corr-1",
|
||||
)
|
||||
|
||||
return ActionRequest(
|
||||
action_type=ActionType.FILE_WRITE,
|
||||
tool_name="file_write",
|
||||
arguments={"path": "/test.txt"},
|
||||
resource="/test.txt",
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_action_request_allowed(self, logger, action):
|
||||
"""Test logging allowed action request."""
|
||||
event = await logger.log_action_request(
|
||||
action,
|
||||
SafetyDecision.ALLOW,
|
||||
reasons=["Within budget"],
|
||||
)
|
||||
|
||||
assert event.event_type == AuditEventType.ACTION_VALIDATED
|
||||
assert event.decision == SafetyDecision.ALLOW
|
||||
assert event.details["reasons"] == ["Within budget"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_action_request_denied(self, logger, action):
|
||||
"""Test logging denied action request."""
|
||||
event = await logger.log_action_request(
|
||||
action,
|
||||
SafetyDecision.DENY,
|
||||
reasons=["Rate limit exceeded"],
|
||||
)
|
||||
|
||||
assert event.event_type == AuditEventType.ACTION_DENIED
|
||||
assert event.decision == SafetyDecision.DENY
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_action_executed_success(self, logger, action):
|
||||
"""Test logging successful action execution."""
|
||||
event = await logger.log_action_executed(
|
||||
action,
|
||||
success=True,
|
||||
execution_time_ms=50.0,
|
||||
)
|
||||
|
||||
assert event.event_type == AuditEventType.ACTION_EXECUTED
|
||||
assert event.details["success"] is True
|
||||
assert event.details["execution_time_ms"] == 50.0
|
||||
assert event.details["error"] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_action_executed_failure(self, logger, action):
|
||||
"""Test logging failed action execution."""
|
||||
event = await logger.log_action_executed(
|
||||
action,
|
||||
success=False,
|
||||
execution_time_ms=100.0,
|
||||
error="File not found",
|
||||
)
|
||||
|
||||
assert event.event_type == AuditEventType.ACTION_FAILED
|
||||
assert event.details["success"] is False
|
||||
assert event.details["error"] == "File not found"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_approval_event(self, logger, action):
|
||||
"""Test logging approval event."""
|
||||
event = await logger.log_approval_event(
|
||||
AuditEventType.APPROVAL_GRANTED,
|
||||
approval_id="approval-1",
|
||||
action=action,
|
||||
decided_by="admin",
|
||||
reason="Approved by admin",
|
||||
)
|
||||
|
||||
assert event.event_type == AuditEventType.APPROVAL_GRANTED
|
||||
assert event.details["approval_id"] == "approval-1"
|
||||
assert event.details["decided_by"] == "admin"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_budget_event(self, logger):
|
||||
"""Test logging budget event."""
|
||||
event = await logger.log_budget_event(
|
||||
AuditEventType.BUDGET_WARNING,
|
||||
agent_id="agent-1",
|
||||
scope="daily",
|
||||
current_usage=8000.0,
|
||||
limit=10000.0,
|
||||
unit="tokens",
|
||||
)
|
||||
|
||||
assert event.event_type == AuditEventType.BUDGET_WARNING
|
||||
assert event.details["scope"] == "daily"
|
||||
assert event.details["usage_percent"] == 80.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_budget_event_zero_limit(self, logger):
|
||||
"""Test logging budget event with zero limit."""
|
||||
event = await logger.log_budget_event(
|
||||
AuditEventType.BUDGET_WARNING,
|
||||
agent_id="agent-1",
|
||||
scope="daily",
|
||||
current_usage=100.0,
|
||||
limit=0.0, # Zero limit
|
||||
)
|
||||
|
||||
assert event.details["usage_percent"] == 0
|
||||
|
||||
    @pytest.mark.asyncio
    async def test_log_emergency_stop(self, logger):
        """Test logging emergency stop."""
        # Stop type and the full list of affected agents must be recorded
        # verbatim in the event details.
        event = await logger.log_emergency_stop(
            stop_type="global",
            triggered_by="admin",
            reason="Security incident",
            affected_agents=["agent-1", "agent-2"],
        )

        assert event.event_type == AuditEventType.EMERGENCY_STOP
        assert event.details["stop_type"] == "global"
        assert event.details["affected_agents"] == ["agent-1", "agent-2"]
|
||||
|
||||
|
||||
class TestAuditLoggerFlush:
    """Tests for flush functionality."""

    @pytest.mark.asyncio
    async def test_flush_persists_events(self):
        """Test flush moves events to persisted storage."""
        audit = AuditLogger(enable_hash_chain=False)

        for event_type in (
            AuditEventType.ACTION_REQUESTED,
            AuditEventType.ACTION_EXECUTED,
        ):
            await audit.log(event_type)

        # Before flushing, both events sit in the in-memory buffer.
        assert len(audit._buffer) == 2
        assert len(audit._persisted) == 0

        flushed = await audit.flush()

        # Flush reports how many events it moved and drains the buffer
        # into persisted storage.
        assert flushed == 2
        assert len(audit._buffer) == 0
        assert len(audit._persisted) == 2

    @pytest.mark.asyncio
    async def test_flush_empty_buffer(self):
        """Test flush with empty buffer."""
        audit = AuditLogger()

        # Flushing with nothing buffered is a no-op that reports zero.
        assert await audit.flush() == 0
|
||||
|
||||
|
||||
class TestAuditLoggerQuery:
    """Tests for query functionality."""

    @pytest_asyncio.fixture
    async def logger_with_events(self):
        """Create a logger with some test events.

        Seeds four events: three for agent-1/proj-1 (one of which also
        carries user-1) and one ACTION_DENIED for agent-2/proj-2. The
        counts asserted in the tests below depend on this exact mix.
        """
        logger = AuditLogger(enable_hash_chain=False)

        # Add various events
        await logger.log(
            AuditEventType.ACTION_REQUESTED,
            agent_id="agent-1",
            project_id="proj-1",
        )
        await logger.log(
            AuditEventType.ACTION_EXECUTED,
            agent_id="agent-1",
            project_id="proj-1",
        )
        await logger.log(
            AuditEventType.ACTION_DENIED,
            agent_id="agent-2",
            project_id="proj-2",
        )
        await logger.log(
            AuditEventType.BUDGET_WARNING,
            agent_id="agent-1",
            project_id="proj-1",
            user_id="user-1",
        )

        return logger

    @pytest.mark.asyncio
    async def test_query_all(self, logger_with_events):
        """Test querying all events."""
        events = await logger_with_events.query()

        assert len(events) == 4

    @pytest.mark.asyncio
    async def test_query_by_event_type(self, logger_with_events):
        """Test filtering by event type."""
        events = await logger_with_events.query(
            event_types=[AuditEventType.ACTION_REQUESTED]
        )

        assert len(events) == 1
        assert events[0].event_type == AuditEventType.ACTION_REQUESTED

    @pytest.mark.asyncio
    async def test_query_by_agent_id(self, logger_with_events):
        """Test filtering by agent ID."""
        events = await logger_with_events.query(agent_id="agent-1")

        # Three of the four seeded events belong to agent-1.
        assert len(events) == 3
        assert all(e.agent_id == "agent-1" for e in events)

    @pytest.mark.asyncio
    async def test_query_by_project_id(self, logger_with_events):
        """Test filtering by project ID."""
        events = await logger_with_events.query(project_id="proj-2")

        assert len(events) == 1

    @pytest.mark.asyncio
    async def test_query_by_user_id(self, logger_with_events):
        """Test filtering by user ID."""
        events = await logger_with_events.query(user_id="user-1")

        # Only the BUDGET_WARNING event was logged with a user_id.
        assert len(events) == 1
        assert events[0].event_type == AuditEventType.BUDGET_WARNING

    @pytest.mark.asyncio
    async def test_query_with_limit(self, logger_with_events):
        """Test query with limit."""
        events = await logger_with_events.query(limit=2)

        assert len(events) == 2

    @pytest.mark.asyncio
    async def test_query_with_offset(self, logger_with_events):
        """Test query with offset."""
        all_events = await logger_with_events.query()
        offset_events = await logger_with_events.query(offset=2)

        # Offset skips the first two results but preserves ordering.
        assert len(offset_events) == 2
        assert offset_events[0] == all_events[2]

    @pytest.mark.asyncio
    async def test_query_by_time_range(self):
        """Test filtering by time range."""
        logger = AuditLogger(enable_hash_chain=False)

        now = datetime.utcnow()
        await logger.log(AuditEventType.ACTION_REQUESTED)

        # Query with start time
        # A +/- 1 second window around "now" should include the event.
        events = await logger.query(
            start_time=now - timedelta(seconds=1),
            end_time=now + timedelta(seconds=1),
        )

        assert len(events) == 1

    @pytest.mark.asyncio
    async def test_query_by_correlation_id(self):
        """Test filtering by correlation ID."""
        logger = AuditLogger(enable_hash_chain=False)

        await logger.log(
            AuditEventType.ACTION_REQUESTED,
            correlation_id="corr-123",
        )
        await logger.log(
            AuditEventType.ACTION_EXECUTED,
            correlation_id="corr-456",
        )

        events = await logger.query(correlation_id="corr-123")

        assert len(events) == 1
        assert events[0].correlation_id == "corr-123"

    @pytest.mark.asyncio
    async def test_query_combined_filters(self, logger_with_events):
        """Test combined filters."""
        # All filters are ANDed together: agent-1 AND proj-1 AND one of
        # the two listed action event types.
        events = await logger_with_events.query(
            agent_id="agent-1",
            project_id="proj-1",
            event_types=[
                AuditEventType.ACTION_REQUESTED,
                AuditEventType.ACTION_EXECUTED,
            ],
        )

        assert len(events) == 2

    @pytest.mark.asyncio
    async def test_get_action_history(self, logger_with_events):
        """Test get_action_history method."""
        events = await logger_with_events.get_action_history("agent-1")

        # Should only return action-related events
        # (the agent-1 BUDGET_WARNING event is excluded).
        assert len(events) == 2
        assert all(
            e.event_type
            in {AuditEventType.ACTION_REQUESTED, AuditEventType.ACTION_EXECUTED}
            for e in events
        )
|
||||
|
||||
|
||||
class TestAuditLoggerIntegrity:
    """Tests for hash chain integrity verification."""

    @pytest.mark.asyncio
    async def test_verify_integrity_valid(self):
        """Test integrity verification with valid chain."""
        logger = AuditLogger(enable_hash_chain=True)

        await logger.log(AuditEventType.ACTION_REQUESTED)
        await logger.log(AuditEventType.ACTION_EXECUTED)

        is_valid, issues = await logger.verify_integrity()

        # An untampered chain verifies cleanly with no reported issues.
        assert is_valid is True
        assert len(issues) == 0

    @pytest.mark.asyncio
    async def test_verify_integrity_disabled(self):
        """Test integrity verification when hash chain disabled."""
        logger = AuditLogger(enable_hash_chain=False)

        await logger.log(AuditEventType.ACTION_REQUESTED)

        is_valid, issues = await logger.verify_integrity()

        # With hashing disabled there is nothing to verify, so the
        # result is trivially valid.
        assert is_valid is True
        assert len(issues) == 0

    @pytest.mark.asyncio
    async def test_verify_integrity_broken_chain(self):
        """Test integrity verification with broken chain."""
        logger = AuditLogger(enable_hash_chain=True)

        event1 = await logger.log(AuditEventType.ACTION_REQUESTED)
        await logger.log(AuditEventType.ACTION_EXECUTED)

        # Tamper with first event's hash
        event1.details["_hash"] = "tampered_hash"

        is_valid, issues = await logger.verify_integrity()

        # Tampering must be detected and surfaced as at least one issue.
        assert is_valid is False
        assert len(issues) > 0
|
||||
|
||||
|
||||
class TestAuditLoggerHandlers:
    """Tests for event handler management."""

    @pytest.mark.asyncio
    async def test_add_sync_handler(self):
        """Test adding synchronous handler."""
        logger = AuditLogger(enable_hash_chain=False)
        events_received = []

        def handler(event):
            events_received.append(event)

        logger.add_handler(handler)
        await logger.log(AuditEventType.ACTION_REQUESTED)

        # Plain (non-async) handlers are invoked once per logged event.
        assert len(events_received) == 1

    @pytest.mark.asyncio
    async def test_add_async_handler(self):
        """Test adding async handler."""
        logger = AuditLogger(enable_hash_chain=False)
        events_received = []

        async def handler(event):
            events_received.append(event)

        logger.add_handler(handler)
        await logger.log(AuditEventType.ACTION_REQUESTED)

        # Coroutine handlers are awaited by the logger before log() returns.
        assert len(events_received) == 1

    @pytest.mark.asyncio
    async def test_remove_handler(self):
        """Test removing handler."""
        logger = AuditLogger(enable_hash_chain=False)
        events_received = []

        def handler(event):
            events_received.append(event)

        logger.add_handler(handler)
        await logger.log(AuditEventType.ACTION_REQUESTED)

        logger.remove_handler(handler)
        await logger.log(AuditEventType.ACTION_EXECUTED)

        # Only the event logged before removal reached the handler.
        assert len(events_received) == 1

    @pytest.mark.asyncio
    async def test_handler_error_caught(self):
        """Test that handler errors are caught."""
        logger = AuditLogger(enable_hash_chain=False)

        def failing_handler(event):
            raise ValueError("Handler error")

        logger.add_handler(failing_handler)

        # Should not raise
        # A misbehaving handler must not prevent the event from being logged.
        event = await logger.log(AuditEventType.ACTION_REQUESTED)
        assert event is not None
|
||||
|
||||
|
||||
class TestAuditLoggerSanitization:
    """Tests for sensitive data sanitization."""

    @pytest.mark.asyncio
    async def test_sanitize_sensitive_keys(self):
        """Test sanitization of sensitive keys."""
        # Patch the safety config so audit_include_sensitive is False,
        # forcing the logger's redaction path.
        with patch("app.services.safety.audit.logger.get_safety_config") as mock_config:
            mock_cfg = MagicMock()
            mock_cfg.audit_retention_days = 30
            mock_cfg.audit_include_sensitive = False
            mock_config.return_value = mock_cfg

            logger = AuditLogger(enable_hash_chain=False)

            event = await logger.log(
                AuditEventType.ACTION_EXECUTED,
                details={
                    "password": "secret123",
                    "api_key": "key123",
                    "token": "token123",
                    "normal_field": "visible",
                },
            )

            # Sensitive keys are replaced; non-sensitive values pass through.
            assert event.details["password"] == "[REDACTED]"
            assert event.details["api_key"] == "[REDACTED]"
            assert event.details["token"] == "[REDACTED]"
            assert event.details["normal_field"] == "visible"

    @pytest.mark.asyncio
    async def test_sanitize_nested_dict(self):
        """Test sanitization of nested dictionaries."""
        with patch("app.services.safety.audit.logger.get_safety_config") as mock_config:
            mock_cfg = MagicMock()
            mock_cfg.audit_retention_days = 30
            mock_cfg.audit_include_sensitive = False
            mock_config.return_value = mock_cfg

            logger = AuditLogger(enable_hash_chain=False)

            event = await logger.log(
                AuditEventType.ACTION_EXECUTED,
                details={
                    "config": {
                        "api_secret": "secret",
                        "name": "test",
                    }
                },
            )

            # Redaction must recurse into nested dicts, not just top level.
            assert event.details["config"]["api_secret"] == "[REDACTED]"
            assert event.details["config"]["name"] == "test"

    @pytest.mark.asyncio
    async def test_include_sensitive_when_enabled(self):
        """Test sensitive data included when enabled."""
        with patch("app.services.safety.audit.logger.get_safety_config") as mock_config:
            mock_cfg = MagicMock()
            mock_cfg.audit_retention_days = 30
            mock_cfg.audit_include_sensitive = True
            mock_config.return_value = mock_cfg

            logger = AuditLogger(enable_hash_chain=False)

            event = await logger.log(
                AuditEventType.ACTION_EXECUTED,
                details={"password": "secret123"},
            )

            # With audit_include_sensitive=True, no redaction occurs.
            assert event.details["password"] == "secret123"
|
||||
|
||||
|
||||
class TestAuditLoggerRetention:
    """Tests for retention policy enforcement."""

    @pytest.mark.asyncio
    async def test_retention_removes_old_events(self):
        """Test that retention removes old events."""
        # Configure a 7-day retention window via the patched safety config.
        with patch("app.services.safety.audit.logger.get_safety_config") as mock_config:
            mock_cfg = MagicMock()
            mock_cfg.audit_retention_days = 7
            mock_cfg.audit_include_sensitive = False
            mock_config.return_value = mock_cfg

            logger = AuditLogger(enable_hash_chain=False)

            # Add an old event directly to persisted
            # (bypasses log() so we can back-date the timestamp).
            from app.services.safety.models import AuditEvent

            old_event = AuditEvent(
                id="old-event",
                event_type=AuditEventType.ACTION_REQUESTED,
                timestamp=datetime.utcnow() - timedelta(days=10),
                details={},
            )
            logger._persisted.append(old_event)

            # Add a recent event
            await logger.log(AuditEventType.ACTION_EXECUTED)

            # Flush will trigger retention enforcement
            await logger.flush()

            # Old event should be removed
            # (10 days old > 7-day retention); the recent one survives.
            assert len(logger._persisted) == 1
            assert logger._persisted[0].id != "old-event"

    @pytest.mark.asyncio
    async def test_retention_keeps_recent_events(self):
        """Test that retention keeps recent events."""
        with patch("app.services.safety.audit.logger.get_safety_config") as mock_config:
            mock_cfg = MagicMock()
            mock_cfg.audit_retention_days = 7
            mock_cfg.audit_include_sensitive = False
            mock_config.return_value = mock_cfg

            logger = AuditLogger(enable_hash_chain=False)

            await logger.log(AuditEventType.ACTION_REQUESTED)
            await logger.log(AuditEventType.ACTION_EXECUTED)

            await logger.flush()

            # Both events are fresh, so retention must not prune either.
            assert len(logger._persisted) == 2
|
||||
|
||||
|
||||
class TestAuditLoggerSingleton:
    """Tests for singleton pattern."""

    @pytest.mark.asyncio
    async def test_get_audit_logger_creates_instance(self):
        """Test get_audit_logger creates singleton."""

        # Ensure a clean slate so this test is order-independent.
        reset_audit_logger()

        logger1 = await get_audit_logger()
        logger2 = await get_audit_logger()

        # Repeated calls return the same instance.
        assert logger1 is logger2

        # Tear down so the background task does not leak into other tests.
        await shutdown_audit_logger()

    @pytest.mark.asyncio
    async def test_shutdown_audit_logger(self):
        """Test shutdown_audit_logger stops and clears singleton."""
        import app.services.safety.audit.logger as audit_module

        reset_audit_logger()

        _logger = await get_audit_logger()
        await shutdown_audit_logger()

        # Shutdown must clear the module-level singleton reference.
        assert audit_module._audit_logger is None

    def test_reset_audit_logger(self):
        """Test reset_audit_logger clears singleton."""
        import app.services.safety.audit.logger as audit_module

        audit_module._audit_logger = AuditLogger()
        reset_audit_logger()

        assert audit_module._audit_logger is None
|
||||
|
||||
|
||||
class TestAuditLoggerPeriodicFlush:
    """Tests for periodic flush background task.

    NOTE(review): these tests rely on real ``asyncio.sleep`` timing
    (0.1s interval, 0.15s wait) and may be flaky on slow CI machines;
    consider a larger margin or an event-based synchronization hook.
    """

    @pytest.mark.asyncio
    async def test_periodic_flush_runs(self):
        """Test periodic flush runs and flushes events."""
        logger = AuditLogger(flush_interval_seconds=0.1, enable_hash_chain=False)

        await logger.start()

        # Log an event
        await logger.log(AuditEventType.ACTION_REQUESTED)
        assert len(logger._buffer) == 1

        # Wait for periodic flush
        await asyncio.sleep(0.15)

        # Event should be flushed
        assert len(logger._buffer) == 0
        assert len(logger._persisted) == 1

        await logger.stop()

    @pytest.mark.asyncio
    async def test_periodic_flush_handles_errors(self):
        """Test periodic flush handles errors gracefully."""
        logger = AuditLogger(flush_interval_seconds=0.1)

        await logger.start()

        # Mock flush to raise an error
        original_flush = logger.flush

        async def failing_flush():
            raise Exception("Flush error")

        logger.flush = failing_flush

        # Wait for flush attempt
        await asyncio.sleep(0.15)

        # Should still be running
        # (a flush failure must not kill the background task).
        assert logger._running is True

        # Restore the real flush before stop() so shutdown can drain cleanly.
        logger.flush = original_flush
        await logger.stop()
|
||||
|
||||
|
||||
class TestAuditLoggerLogging:
    """Tests for standard logger output.

    Patches the module-level ``logger`` used by the audit service to
    verify that each event class maps to the expected log level.
    """

    @pytest.mark.asyncio
    async def test_log_warning_for_denied(self):
        """Test warning level for denied events."""
        with patch("app.services.safety.audit.logger.logger") as mock_logger:
            audit_logger = AuditLogger(enable_hash_chain=False)

            await audit_logger.log(
                AuditEventType.ACTION_DENIED,
                agent_id="agent-1",
            )

            mock_logger.warning.assert_called()

    @pytest.mark.asyncio
    async def test_log_error_for_failed(self):
        """Test error level for failed events."""
        with patch("app.services.safety.audit.logger.logger") as mock_logger:
            audit_logger = AuditLogger(enable_hash_chain=False)

            await audit_logger.log(
                AuditEventType.ACTION_FAILED,
                agent_id="agent-1",
            )

            mock_logger.error.assert_called()

    @pytest.mark.asyncio
    async def test_log_info_for_normal(self):
        """Test info level for normal events."""
        with patch("app.services.safety.audit.logger.logger") as mock_logger:
            audit_logger = AuditLogger(enable_hash_chain=False)

            await audit_logger.log(
                AuditEventType.ACTION_EXECUTED,
                agent_id="agent-1",
            )

            mock_logger.info.assert_called()
|
||||
|
||||
|
||||
class TestAuditLoggerEdgeCases:
    """Tests for edge cases."""

    @pytest.mark.asyncio
    async def test_log_with_none_details(self):
        """Test logging with None details."""
        logger = AuditLogger(enable_hash_chain=False)

        event = await logger.log(
            AuditEventType.ACTION_REQUESTED,
            details=None,
        )

        # None details are normalized to an empty dict on the event.
        assert event.details == {}

    @pytest.mark.asyncio
    async def test_query_with_action_id(self):
        """Test querying by action ID."""
        logger = AuditLogger(enable_hash_chain=False)

        await logger.log(
            AuditEventType.ACTION_REQUESTED,
            action_id="action-1",
        )
        await logger.log(
            AuditEventType.ACTION_EXECUTED,
            action_id="action-2",
        )

        events = await logger.query(action_id="action-1")

        assert len(events) == 1
        assert events[0].action_id == "action-1"

    @pytest.mark.asyncio
    async def test_query_with_session_id(self):
        """Test querying by session ID."""
        logger = AuditLogger(enable_hash_chain=False)

        await logger.log(
            AuditEventType.ACTION_REQUESTED,
            session_id="sess-1",
        )
        await logger.log(
            AuditEventType.ACTION_EXECUTED,
            session_id="sess-2",
        )

        events = await logger.query(session_id="sess-1")

        assert len(events) == 1

    @pytest.mark.asyncio
    async def test_query_includes_buffer_and_persisted(self):
        """Test query includes both buffer and persisted events."""
        logger = AuditLogger(enable_hash_chain=False)

        # Add event to buffer
        await logger.log(AuditEventType.ACTION_REQUESTED)

        # Flush to persisted
        await logger.flush()

        # Add another to buffer
        await logger.log(AuditEventType.ACTION_EXECUTED)

        # Query should return both
        # (one now in _persisted, one still in _buffer).
        events = await logger.query()

        assert len(events) == 2

    @pytest.mark.asyncio
    async def test_remove_nonexistent_handler(self):
        """Test removing handler that doesn't exist."""
        logger = AuditLogger()

        def handler(event):
            pass

        # Should not raise
        logger.remove_handler(handler)

    @pytest.mark.asyncio
    async def test_query_time_filter_excludes_events(self):
        """Test time filters exclude events correctly."""
        logger = AuditLogger(enable_hash_chain=False)

        await logger.log(AuditEventType.ACTION_REQUESTED)

        # Query with future start time
        # NOTE(review): datetime.utcnow() is deprecated since Python 3.12;
        # the code under test presumably also uses naive UTC — confirm
        # before migrating to datetime.now(timezone.utc).
        future = datetime.utcnow() + timedelta(hours=1)
        events = await logger.query(start_time=future)

        assert len(events) == 0

    @pytest.mark.asyncio
    async def test_query_end_time_filter(self):
        """Test end time filter."""
        logger = AuditLogger(enable_hash_chain=False)

        await logger.log(AuditEventType.ACTION_REQUESTED)

        # Query with past end time
        past = datetime.utcnow() - timedelta(hours=1)
        events = await logger.query(end_time=past)

        assert len(events) == 0
|
||||
1136
backend/tests/services/safety/test_hitl.py
Normal file
1136
backend/tests/services/safety/test_hitl.py
Normal file
File diff suppressed because it is too large
Load Diff
874
backend/tests/services/safety/test_mcp_integration.py
Normal file
874
backend/tests/services/safety/test_mcp_integration.py
Normal file
@@ -0,0 +1,874 @@
|
||||
"""
|
||||
Tests for MCP Safety Integration.
|
||||
|
||||
Tests cover:
|
||||
- MCPToolCall and MCPToolResult data structures
|
||||
- MCPSafetyWrapper: tool registration, execution, safety checks
|
||||
- Tool classification and action type mapping
|
||||
- SafeToolExecutor context manager
|
||||
- Factory function create_mcp_wrapper
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from app.services.safety.exceptions import EmergencyStopError
|
||||
from app.services.safety.mcp.integration import (
|
||||
MCPSafetyWrapper,
|
||||
MCPToolCall,
|
||||
MCPToolResult,
|
||||
SafeToolExecutor,
|
||||
create_mcp_wrapper,
|
||||
)
|
||||
from app.services.safety.models import (
|
||||
ActionType,
|
||||
AutonomyLevel,
|
||||
SafetyDecision,
|
||||
)
|
||||
|
||||
|
||||
class TestMCPToolCall:
    """Tests for MCPToolCall dataclass."""

    def test_tool_call_creation(self):
        """Test creating a tool call."""
        tool_args = {"path": "/tmp/test.txt"}  # noqa: S108
        ctx = {"session_id": "sess-1"}

        call = MCPToolCall(
            tool_name="file_read",
            arguments=tool_args,
            server_name="file-server",
            project_id="proj-1",
            context=ctx,
        )

        # Every explicitly supplied field is stored unchanged.
        assert call.tool_name == "file_read"
        assert call.arguments == {"path": "/tmp/test.txt"}  # noqa: S108
        assert call.context == {"session_id": "sess-1"}
        assert call.server_name == "file-server"
        assert call.project_id == "proj-1"

    def test_tool_call_defaults(self):
        """Test tool call default values."""
        call = MCPToolCall(tool_name="test", arguments={})

        # Optional identifiers default to None; context to an empty dict.
        assert call.server_name is None
        assert call.project_id is None
        assert call.context == {}
|
||||
|
||||
|
||||
class TestMCPToolResult:
    """Tests for MCPToolResult dataclass."""

    def test_tool_result_success(self):
        """Test creating a successful result."""
        result = MCPToolResult(
            success=True,
            result={"data": "test"},
            safety_decision=SafetyDecision.ALLOW,
            execution_time_ms=50.0,
        )

        assert result.success is True
        assert result.result == {"data": "test"}
        # error stays None on the success path.
        assert result.error is None
        assert result.safety_decision == SafetyDecision.ALLOW
        assert result.execution_time_ms == 50.0

    def test_tool_result_failure(self):
        """Test creating a failed result."""
        result = MCPToolResult(
            success=False,
            error="Permission denied",
            safety_decision=SafetyDecision.DENY,
        )

        assert result.success is False
        assert result.error == "Permission denied"
        # result payload stays None on the failure path.
        assert result.result is None

    def test_tool_result_with_ids(self):
        """Test result with approval and checkpoint IDs."""
        result = MCPToolResult(
            success=True,
            approval_id="approval-123",
            checkpoint_id="checkpoint-456",
        )

        assert result.approval_id == "approval-123"
        assert result.checkpoint_id == "checkpoint-456"

    def test_tool_result_defaults(self):
        """Test result default values."""
        result = MCPToolResult(success=True)

        # Only success is required; everything else has a benign default,
        # including an ALLOW safety decision and empty metadata.
        assert result.result is None
        assert result.error is None
        assert result.safety_decision == SafetyDecision.ALLOW
        assert result.execution_time_ms == 0.0
        assert result.approval_id is None
        assert result.checkpoint_id is None
        assert result.metadata == {}
|
||||
|
||||
|
||||
class TestMCPSafetyWrapperClassification:
    """Tests for tool classification.

    ``_classify_tool`` maps a tool name to an ActionType, presumably by
    keyword matching on the name (e.g. "file"/"read", "db", "git") —
    these tests pin the mapping for representative names per category.
    """

    def test_classify_file_read(self):
        """Test classifying file read tools."""
        wrapper = MCPSafetyWrapper()

        assert wrapper._classify_tool("file_read") == ActionType.FILE_READ
        assert wrapper._classify_tool("get_file") == ActionType.FILE_READ
        assert wrapper._classify_tool("list_files") == ActionType.FILE_READ
        assert wrapper._classify_tool("search_file") == ActionType.FILE_READ

    def test_classify_file_write(self):
        """Test classifying file write tools."""
        wrapper = MCPSafetyWrapper()

        assert wrapper._classify_tool("file_write") == ActionType.FILE_WRITE
        assert wrapper._classify_tool("create_file") == ActionType.FILE_WRITE
        assert wrapper._classify_tool("update_file") == ActionType.FILE_WRITE

    def test_classify_file_delete(self):
        """Test classifying file delete tools."""
        wrapper = MCPSafetyWrapper()

        assert wrapper._classify_tool("file_delete") == ActionType.FILE_DELETE
        assert wrapper._classify_tool("remove_file") == ActionType.FILE_DELETE

    def test_classify_database_read(self):
        """Test classifying database read tools."""
        wrapper = MCPSafetyWrapper()

        assert wrapper._classify_tool("database_query") == ActionType.DATABASE_QUERY
        assert wrapper._classify_tool("db_read") == ActionType.DATABASE_QUERY
        assert wrapper._classify_tool("query_database") == ActionType.DATABASE_QUERY

    def test_classify_database_mutate(self):
        """Test classifying database mutate tools."""
        wrapper = MCPSafetyWrapper()

        assert wrapper._classify_tool("database_write") == ActionType.DATABASE_MUTATE
        assert wrapper._classify_tool("db_update") == ActionType.DATABASE_MUTATE
        assert wrapper._classify_tool("database_delete") == ActionType.DATABASE_MUTATE

    def test_classify_shell_command(self):
        """Test classifying shell command tools."""
        wrapper = MCPSafetyWrapper()

        assert wrapper._classify_tool("shell_execute") == ActionType.SHELL_COMMAND
        assert wrapper._classify_tool("exec_command") == ActionType.SHELL_COMMAND
        assert wrapper._classify_tool("bash_run") == ActionType.SHELL_COMMAND

    def test_classify_git_operation(self):
        """Test classifying git tools."""
        wrapper = MCPSafetyWrapper()

        assert wrapper._classify_tool("git_commit") == ActionType.GIT_OPERATION
        assert wrapper._classify_tool("git_push") == ActionType.GIT_OPERATION
        assert wrapper._classify_tool("git_status") == ActionType.GIT_OPERATION

    def test_classify_network_request(self):
        """Test classifying network tools."""
        wrapper = MCPSafetyWrapper()

        assert wrapper._classify_tool("http_get") == ActionType.NETWORK_REQUEST
        assert wrapper._classify_tool("fetch_url") == ActionType.NETWORK_REQUEST
        assert wrapper._classify_tool("api_request") == ActionType.NETWORK_REQUEST

    def test_classify_llm_call(self):
        """Test classifying LLM tools."""
        wrapper = MCPSafetyWrapper()

        assert wrapper._classify_tool("llm_generate") == ActionType.LLM_CALL
        assert wrapper._classify_tool("ai_complete") == ActionType.LLM_CALL
        assert wrapper._classify_tool("claude_chat") == ActionType.LLM_CALL

    def test_classify_default(self):
        """Test default classification for unknown tools."""
        wrapper = MCPSafetyWrapper()

        # Names matching no known category fall back to generic TOOL_CALL.
        assert wrapper._classify_tool("unknown_tool") == ActionType.TOOL_CALL
        assert wrapper._classify_tool("custom_action") == ActionType.TOOL_CALL
|
||||
|
||||
|
||||
class TestMCPSafetyWrapperToolHandlers:
    """Tests for tool handler registration."""

    def test_register_tool_handler(self):
        """Test registering a tool handler."""
        wrapper = MCPSafetyWrapper()

        def handler(path: str) -> str:
            return f"Read: {path}"

        wrapper.register_tool_handler("file_read", handler)

        # The exact callable is stored under the tool name.
        assert "file_read" in wrapper._tool_handlers
        assert wrapper._tool_handlers["file_read"] is handler

    def test_register_multiple_handlers(self):
        """Test registering multiple handlers."""
        wrapper = MCPSafetyWrapper()

        for tool_name in ("tool1", "tool2", "tool3"):
            wrapper.register_tool_handler(tool_name, lambda: None)

        assert len(wrapper._tool_handlers) == 3

    def test_overwrite_handler(self):
        """Test overwriting a handler."""
        wrapper = MCPSafetyWrapper()

        def first_handler():
            return "first"

        def second_handler():
            return "second"

        wrapper.register_tool_handler("tool", first_handler)
        wrapper.register_tool_handler("tool", second_handler)

        # Re-registering under the same name replaces the previous handler.
        assert wrapper._tool_handlers["tool"] is second_handler
|
||||
|
||||
|
||||
class TestMCPSafetyWrapperExecution:
    """Tests for tool execution through MCPSafetyWrapper.execute.

    Covers the allow / deny / approval-required / emergency-stop paths,
    safety bypass, missing handlers, handler exceptions, and sync handlers.
    """

    @staticmethod
    def _allow_result(checkpoint_id=None):
        """Build a guardian validation result whose decision is ALLOW.

        Four ALLOW-path tests previously duplicated this MagicMock
        construction verbatim; keeping it here makes the expected attribute
        shape (decision/reasons/approval_id/checkpoint_id) explicit in one place.
        """
        return MagicMock(
            decision=SafetyDecision.ALLOW,
            reasons=[],
            approval_id=None,
            checkpoint_id=checkpoint_id,
        )

    @pytest_asyncio.fixture
    async def mock_guardian(self):
        """Create a mock SafetyGuardian with an async validate()."""
        guardian = AsyncMock()
        guardian.validate = AsyncMock()
        return guardian

    @pytest_asyncio.fixture
    async def mock_emergency(self):
        """Create a mock EmergencyControls with an async check_allowed()."""
        emergency = AsyncMock()
        emergency.check_allowed = AsyncMock()
        return emergency

    @pytest.mark.asyncio
    async def test_execute_allowed(self, mock_guardian, mock_emergency):
        """An allowed call runs its handler and reports an ALLOW decision."""
        mock_guardian.validate.return_value = self._allow_result()

        wrapper = MCPSafetyWrapper(
            guardian=mock_guardian,
            emergency_controls=mock_emergency,
        )

        async def handler(path: str) -> dict:
            return {"content": f"Data from {path}"}

        wrapper.register_tool_handler("file_read", handler)

        call = MCPToolCall(
            tool_name="file_read",
            arguments={"path": "/test.txt"},
            project_id="proj-1",
        )

        result = await wrapper.execute(call, "agent-1")

        assert result.success is True
        assert result.result == {"content": "Data from /test.txt"}
        assert result.safety_decision == SafetyDecision.ALLOW

    @pytest.mark.asyncio
    async def test_execute_denied(self, mock_guardian, mock_emergency):
        """A denied call fails and surfaces every denial reason in the error."""
        mock_guardian.validate.return_value = MagicMock(
            decision=SafetyDecision.DENY,
            reasons=["Permission denied", "Rate limit exceeded"],
        )

        wrapper = MCPSafetyWrapper(
            guardian=mock_guardian,
            emergency_controls=mock_emergency,
        )

        call = MCPToolCall(
            tool_name="file_write",
            arguments={"path": "/etc/passwd"},
        )

        result = await wrapper.execute(call, "agent-1")

        assert result.success is False
        assert "Permission denied" in result.error
        assert "Rate limit exceeded" in result.error
        assert result.safety_decision == SafetyDecision.DENY

    @pytest.mark.asyncio
    async def test_execute_requires_approval(self, mock_guardian, mock_emergency):
        """An approval-required decision fails the call and carries the approval id."""
        mock_guardian.validate.return_value = MagicMock(
            decision=SafetyDecision.REQUIRE_APPROVAL,
            reasons=["Destructive operation requires approval"],
            approval_id="approval-123",
        )

        wrapper = MCPSafetyWrapper(
            guardian=mock_guardian,
            emergency_controls=mock_emergency,
        )

        call = MCPToolCall(
            tool_name="file_delete",
            arguments={"path": "/important.txt"},
        )

        result = await wrapper.execute(call, "agent-1")

        assert result.success is False
        assert result.safety_decision == SafetyDecision.REQUIRE_APPROVAL
        assert result.approval_id == "approval-123"
        assert "requires human approval" in result.error

    @pytest.mark.asyncio
    async def test_execute_emergency_stop(self, mock_guardian, mock_emergency):
        """An active emergency stop blocks execution with a DENY decision."""
        mock_emergency.check_allowed.side_effect = EmergencyStopError(
            "Emergency stop active"
        )

        wrapper = MCPSafetyWrapper(
            guardian=mock_guardian,
            emergency_controls=mock_emergency,
        )

        call = MCPToolCall(
            tool_name="file_write",
            arguments={"path": "/test.txt"},
            project_id="proj-1",
        )

        result = await wrapper.execute(call, "agent-1")

        assert result.success is False
        assert result.safety_decision == SafetyDecision.DENY
        assert result.metadata.get("emergency_stop") is True

    @pytest.mark.asyncio
    async def test_execute_bypass_safety(self, mock_guardian, mock_emergency):
        """With bypass_safety=True the handler runs and the guardian is skipped."""
        wrapper = MCPSafetyWrapper(
            guardian=mock_guardian,
            emergency_controls=mock_emergency,
        )

        async def handler(data: str) -> str:
            return f"Processed: {data}"

        wrapper.register_tool_handler("custom_tool", handler)

        call = MCPToolCall(
            tool_name="custom_tool",
            arguments={"data": "test"},
        )

        result = await wrapper.execute(call, "agent-1", bypass_safety=True)

        assert result.success is True
        assert result.result == "Processed: test"
        # Guardian should not be called when bypassing
        mock_guardian.validate.assert_not_called()

    @pytest.mark.asyncio
    async def test_execute_no_handler(self, mock_guardian, mock_emergency):
        """A tool with no registered handler fails even when the call is allowed."""
        mock_guardian.validate.return_value = self._allow_result()

        wrapper = MCPSafetyWrapper(
            guardian=mock_guardian,
            emergency_controls=mock_emergency,
        )

        call = MCPToolCall(
            tool_name="unregistered_tool",
            arguments={},
        )

        result = await wrapper.execute(call, "agent-1")

        assert result.success is False
        assert "No handler registered" in result.error

    @pytest.mark.asyncio
    async def test_execute_handler_exception(self, mock_guardian, mock_emergency):
        """A handler exception is captured as a failure with its message in the error."""
        mock_guardian.validate.return_value = self._allow_result()

        wrapper = MCPSafetyWrapper(
            guardian=mock_guardian,
            emergency_controls=mock_emergency,
        )

        async def failing_handler() -> None:
            raise ValueError("Handler failed!")

        wrapper.register_tool_handler("failing_tool", failing_handler)

        call = MCPToolCall(
            tool_name="failing_tool",
            arguments={},
        )

        result = await wrapper.execute(call, "agent-1")

        assert result.success is False
        assert "Handler failed!" in result.error
        # Decision is still ALLOW because the safety check passed
        assert result.safety_decision == SafetyDecision.ALLOW

    @pytest.mark.asyncio
    async def test_execute_sync_handler(self, mock_guardian, mock_emergency):
        """Synchronous handlers are supported alongside async ones."""
        mock_guardian.validate.return_value = self._allow_result()

        wrapper = MCPSafetyWrapper(
            guardian=mock_guardian,
            emergency_controls=mock_emergency,
        )

        def sync_handler(value: int) -> int:
            return value * 2

        wrapper.register_tool_handler("sync_tool", sync_handler)

        call = MCPToolCall(
            tool_name="sync_tool",
            arguments={"value": 21},
        )

        result = await wrapper.execute(call, "agent-1")

        assert result.success is True
        assert result.result == 42
|
||||
|
||||
|
||||
class TestBuildActionRequest:
    """Tests for the _build_action_request helper."""

    def test_build_action_request_basic(self):
        """A simple file-read call maps onto an action request field by field."""
        wrapper = MCPSafetyWrapper()

        tool_call = MCPToolCall(
            tool_name="file_read",
            arguments={"path": "/test.txt"},
            project_id="proj-1",
        )

        request = wrapper._build_action_request(
            tool_call, "agent-1", AutonomyLevel.MILESTONE
        )

        assert request.action_type == ActionType.FILE_READ
        assert request.tool_name == "file_read"
        assert request.arguments == {"path": "/test.txt"}
        assert request.resource == "/test.txt"
        assert request.metadata.agent_id == "agent-1"
        assert request.metadata.project_id == "proj-1"
        assert request.metadata.autonomy_level == AutonomyLevel.MILESTONE

    def test_build_action_request_with_context(self):
        """Session context and an explicit resource argument are picked up."""
        wrapper = MCPSafetyWrapper()

        tool_call = MCPToolCall(
            tool_name="database_query",
            arguments={"resource": "users", "query": "SELECT *"},
            context={"session_id": "sess-123"},
            project_id="proj-2",
        )

        request = wrapper._build_action_request(
            tool_call, "agent-2", AutonomyLevel.AUTONOMOUS
        )

        assert request.resource == "users"
        assert request.metadata.session_id == "sess-123"
        assert request.metadata.autonomy_level == AutonomyLevel.AUTONOMOUS

    def test_build_action_request_no_resource(self):
        """Calls whose arguments carry no resource produce a None resource."""
        wrapper = MCPSafetyWrapper()

        tool_call = MCPToolCall(
            tool_name="llm_generate",
            arguments={"prompt": "Hello"},
        )

        request = wrapper._build_action_request(
            tool_call, "agent-1", AutonomyLevel.FULL_CONTROL
        )

        assert request.resource is None
|
||||
|
||||
|
||||
class TestElapsedTime:
    """Tests for _elapsed_ms helper."""

    def test_elapsed_ms(self):
        """Test calculating elapsed time.

        Backdates a start timestamp by 100ms and checks that _elapsed_ms
        reports roughly that much elapsed wall-clock time.
        """
        wrapper = MCPSafetyWrapper()

        # NOTE(review): datetime.utcnow() is deprecated since Python 3.12.
        # Presumably _elapsed_ms also works with naive UTC timestamps
        # internally, so switching this line alone to a timezone-aware
        # datetime could break the subtraction — confirm against the
        # implementation before migrating.
        start = datetime.utcnow() - timedelta(milliseconds=100)
        elapsed = wrapper._elapsed_ms(start)

        # Should be at least 100ms, but allow some tolerance
        # (>= 99 guards against sub-millisecond rounding; the < 200 upper
        # bound assumes the call itself finishes quickly and could flake
        # on a heavily loaded CI machine).
        assert elapsed >= 99
        assert elapsed < 200
|
||||
|
||||
|
||||
class TestSafeToolExecutor:
    """Tests for the SafeToolExecutor async context manager."""

    @staticmethod
    def _allow_wrapper():
        """Return (wrapper, mock_guardian) wired so validation always allows.

        All three tests previously duplicated this mock wiring verbatim;
        factoring it out keeps the expected validation-result shape in one place.
        """
        mock_guardian = AsyncMock()
        mock_guardian.validate.return_value = MagicMock(
            decision=SafetyDecision.ALLOW,
            reasons=[],
            approval_id=None,
            checkpoint_id=None,
        )
        wrapper = MCPSafetyWrapper(
            guardian=mock_guardian,
            emergency_controls=AsyncMock(),
        )
        return wrapper, mock_guardian

    @pytest.mark.asyncio
    async def test_executor_execute(self):
        """Executing within the context manager returns the handler's result."""
        wrapper, _ = self._allow_wrapper()

        async def handler() -> str:
            return "success"

        wrapper.register_tool_handler("test_tool", handler)

        call = MCPToolCall(tool_name="test_tool", arguments={})

        async with SafeToolExecutor(wrapper, call, "agent-1") as executor:
            result = await executor.execute()

        assert result.success is True
        assert result.result == "success"

    @pytest.mark.asyncio
    async def test_executor_result_property(self):
        """The result property is None before execution and populated afterwards."""
        wrapper, _ = self._allow_wrapper()

        wrapper.register_tool_handler("tool", lambda: "data")

        call = MCPToolCall(tool_name="tool", arguments={})
        executor = SafeToolExecutor(wrapper, call, "agent-1")

        # Before execution
        assert executor.result is None

        async with executor:
            await executor.execute()

        # After execution
        assert executor.result is not None
        assert executor.result.success is True

    @pytest.mark.asyncio
    async def test_executor_with_autonomy_level(self):
        """A custom autonomy level is forwarded to the guardian's action request."""
        wrapper, mock_guardian = self._allow_wrapper()

        wrapper.register_tool_handler("tool", lambda: None)

        call = MCPToolCall(tool_name="tool", arguments={})

        async with SafeToolExecutor(
            wrapper, call, "agent-1", AutonomyLevel.AUTONOMOUS
        ) as executor:
            await executor.execute()

        # Check that guardian was called with the correct autonomy level
        mock_guardian.validate.assert_called_once()
        action = mock_guardian.validate.call_args[0][0]
        assert action.metadata.autonomy_level == AutonomyLevel.AUTONOMOUS
|
||||
|
||||
|
||||
class TestCreateMCPWrapper:
    """Tests for the create_mcp_wrapper factory function."""

    @pytest.mark.asyncio
    async def test_create_wrapper_with_guardian(self):
        """An explicitly supplied guardian is used as-is by the factory."""
        supplied_guardian = AsyncMock()

        with patch(
            "app.services.safety.mcp.integration.get_emergency_controls"
        ) as mock_get_emergency:
            mock_get_emergency.return_value = AsyncMock()

            wrapper = await create_mcp_wrapper(guardian=supplied_guardian)

        assert wrapper._guardian is supplied_guardian

    @pytest.mark.asyncio
    async def test_create_wrapper_default_guardian(self):
        """Without an explicit guardian, the factory resolves the default one."""
        with (
            patch(
                "app.services.safety.mcp.integration.get_safety_guardian"
            ) as mock_get_guardian,
            patch(
                "app.services.safety.mcp.integration.get_emergency_controls"
            ) as mock_get_emergency,
        ):
            default_guardian = AsyncMock()
            mock_get_guardian.return_value = default_guardian
            mock_get_emergency.return_value = AsyncMock()

            wrapper = await create_mcp_wrapper()

            assert wrapper._guardian is default_guardian
            mock_get_guardian.assert_called_once()
|
||||
|
||||
|
||||
class TestLazyGetters:
    """Tests for lazy getter methods."""

    @pytest.mark.asyncio
    async def test_get_guardian_lazy(self):
        """Without an injected guardian, _get_guardian resolves one lazily."""
        wrapper = MCPSafetyWrapper()

        with patch(
            "app.services.safety.mcp.integration.get_safety_guardian"
        ) as mock_get:
            resolved = AsyncMock()
            mock_get.return_value = resolved

            guardian = await wrapper._get_guardian()

            assert guardian is resolved
            mock_get.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_guardian_cached(self):
        """An injected guardian is returned without triggering a lazy lookup."""
        injected = AsyncMock()
        wrapper = MCPSafetyWrapper(guardian=injected)

        assert await wrapper._get_guardian() is injected

    @pytest.mark.asyncio
    async def test_get_emergency_controls_lazy(self):
        """Without injected controls, _get_emergency_controls resolves lazily."""
        wrapper = MCPSafetyWrapper()

        with patch(
            "app.services.safety.mcp.integration.get_emergency_controls"
        ) as mock_get:
            resolved = AsyncMock()
            mock_get.return_value = resolved

            emergency = await wrapper._get_emergency_controls()

            assert emergency is resolved
            mock_get.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_emergency_controls_cached(self):
        """Injected emergency controls are returned without a lazy lookup."""
        injected = AsyncMock()
        wrapper = MCPSafetyWrapper(emergency_controls=injected)

        assert await wrapper._get_emergency_controls() is injected
|
||||
|
||||
|
||||
class TestEdgeCases:
    """Tests for edge cases and error handling."""

    @staticmethod
    def _allow_wrapper(checkpoint_id=None):
        """Return (wrapper, mock_guardian, mock_emergency) wired to ALLOW.

        Three tests previously duplicated this setup verbatim; the optional
        checkpoint_id lets the checkpoint-propagation test reuse it.
        """
        mock_guardian = AsyncMock()
        mock_guardian.validate.return_value = MagicMock(
            decision=SafetyDecision.ALLOW,
            reasons=[],
            approval_id=None,
            checkpoint_id=checkpoint_id,
        )
        mock_emergency = AsyncMock()
        wrapper = MCPSafetyWrapper(
            guardian=mock_guardian,
            emergency_controls=mock_emergency,
        )
        return wrapper, mock_guardian, mock_emergency

    @pytest.mark.asyncio
    async def test_execute_with_safety_error(self):
        """A SafetyError raised by the guardian maps to a DENY failure."""
        from app.services.safety.exceptions import SafetyError

        mock_guardian = AsyncMock()
        mock_guardian.validate.side_effect = SafetyError("Internal safety error")

        wrapper = MCPSafetyWrapper(
            guardian=mock_guardian,
            emergency_controls=AsyncMock(),
        )

        call = MCPToolCall(tool_name="test", arguments={})

        result = await wrapper.execute(call, "agent-1")

        assert result.success is False
        assert "Internal safety error" in result.error
        assert result.safety_decision == SafetyDecision.DENY

    @pytest.mark.asyncio
    async def test_execute_with_checkpoint_id(self):
        """The guardian's checkpoint_id is propagated onto the result."""
        wrapper, _, _ = self._allow_wrapper(checkpoint_id="checkpoint-abc")

        wrapper.register_tool_handler("tool", lambda: "result")

        call = MCPToolCall(tool_name="tool", arguments={})

        result = await wrapper.execute(call, "agent-1")

        assert result.success is True
        assert result.checkpoint_id == "checkpoint-abc"

    def test_destructive_tools_constant(self):
        """Test DESTRUCTIVE_TOOLS class constant includes known mutating tools."""
        assert "file_write" in MCPSafetyWrapper.DESTRUCTIVE_TOOLS
        assert "file_delete" in MCPSafetyWrapper.DESTRUCTIVE_TOOLS
        assert "shell_execute" in MCPSafetyWrapper.DESTRUCTIVE_TOOLS
        assert "git_push" in MCPSafetyWrapper.DESTRUCTIVE_TOOLS

    def test_read_only_tools_constant(self):
        """Test READ_ONLY_TOOLS class constant includes known read-only tools."""
        assert "file_read" in MCPSafetyWrapper.READ_ONLY_TOOLS
        assert "database_query" in MCPSafetyWrapper.READ_ONLY_TOOLS
        assert "git_status" in MCPSafetyWrapper.READ_ONLY_TOOLS
        assert "search" in MCPSafetyWrapper.READ_ONLY_TOOLS

    @pytest.mark.asyncio
    async def test_scope_with_project_id(self):
        """With a project_id, the emergency check is scoped to the project."""
        wrapper, _, mock_emergency = self._allow_wrapper()

        wrapper.register_tool_handler("tool", lambda: None)

        call = MCPToolCall(
            tool_name="tool",
            arguments={},
            project_id="proj-123",
        )

        await wrapper.execute(call, "agent-1")

        # Verify emergency check was called with project scope
        mock_emergency.check_allowed.assert_called_once()
        call_kwargs = mock_emergency.check_allowed.call_args
        assert "project:proj-123" in str(call_kwargs)

    @pytest.mark.asyncio
    async def test_scope_without_project_id(self):
        """Without a project_id, the emergency scope falls back to the agent."""
        wrapper, _, mock_emergency = self._allow_wrapper()

        wrapper.register_tool_handler("tool", lambda: None)

        call = MCPToolCall(
            tool_name="tool",
            arguments={},
            # No project_id
        )

        await wrapper.execute(call, "agent-555")

        # Verify emergency check was called with agent scope
        mock_emergency.check_allowed.assert_called_once()
        call_kwargs = mock_emergency.check_allowed.call_args
        assert "agent:agent-555" in str(call_kwargs)
|
||||
747
backend/tests/services/safety/test_metrics.py
Normal file
747
backend/tests/services/safety/test_metrics.py
Normal file
@@ -0,0 +1,747 @@
|
||||
"""
|
||||
Tests for Safety Metrics Collector.
|
||||
|
||||
Tests cover:
|
||||
- MetricType, MetricValue, HistogramBucket data structures
|
||||
- SafetyMetrics counters, gauges, histograms
|
||||
- Prometheus format export
|
||||
- Summary and reset operations
|
||||
- Singleton pattern and convenience functions
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from app.services.safety.metrics.collector import (
|
||||
HistogramBucket,
|
||||
MetricType,
|
||||
MetricValue,
|
||||
SafetyMetrics,
|
||||
get_safety_metrics,
|
||||
record_mcp_call,
|
||||
record_validation,
|
||||
)
|
||||
|
||||
|
||||
class TestMetricType:
    """Tests for the MetricType enum."""

    def test_metric_types_exist(self):
        """Each expected metric kind is defined and compares equal to its value."""
        expected = {
            MetricType.COUNTER: "counter",
            MetricType.GAUGE: "gauge",
            MetricType.HISTOGRAM: "histogram",
        }
        for member, raw in expected.items():
            assert member == raw

    def test_metric_type_is_string(self):
        """The underlying enum values are plain strings."""
        for member in (MetricType.COUNTER, MetricType.GAUGE, MetricType.HISTOGRAM):
            assert isinstance(member.value, str)
|
||||
|
||||
|
||||
class TestMetricValue:
    """Tests for the MetricValue dataclass."""

    def test_metric_value_creation(self):
        """All constructor arguments are stored and a timestamp is stamped."""
        metric = MetricValue(
            name="test_metric",
            metric_type=MetricType.COUNTER,
            value=42.0,
            labels={"env": "test"},
        )

        assert metric.name == "test_metric"
        assert metric.metric_type == MetricType.COUNTER
        assert metric.value == 42.0
        assert metric.labels == {"env": "test"}
        assert metric.timestamp is not None

    def test_metric_value_defaults(self):
        """Omitted labels default to an empty dict; the timestamp is auto-set."""
        metric = MetricValue(
            name="test",
            metric_type=MetricType.GAUGE,
            value=0.0,
        )

        assert metric.labels == {}
        assert metric.timestamp is not None
|
||||
|
||||
|
||||
class TestHistogramBucket:
    """Tests for the HistogramBucket dataclass."""

    def test_histogram_bucket_creation(self):
        """Both the upper bound and the count are stored as given."""
        b = HistogramBucket(le=0.5, count=10)

        assert b.le == 0.5
        assert b.count == 10

    def test_histogram_bucket_defaults(self):
        """The observation count defaults to zero."""
        b = HistogramBucket(le=1.0)

        assert b.le == 1.0
        assert b.count == 0

    def test_histogram_bucket_infinity(self):
        """An infinite upper bound (the +Inf bucket) is representable."""
        b = HistogramBucket(le=float("inf"))

        assert b.le == float("inf")
|
||||
|
||||
|
||||
class TestSafetyMetricsCounters:
    """Tests for SafetyMetrics counter methods."""

    @pytest_asyncio.fixture
    async def collector(self):
        """Provide a fresh SafetyMetrics instance for each test."""
        return SafetyMetrics()

    @pytest.mark.asyncio
    async def test_inc_validations(self, collector):
        """Validation increments roll up into total and denied counts."""
        await collector.inc_validations("allow")
        await collector.inc_validations("allow")
        await collector.inc_validations("deny", agent_id="agent-1")

        summary = await collector.get_summary()
        assert summary["total_validations"] == 3
        assert summary["denied_validations"] == 1

    @pytest.mark.asyncio
    async def test_inc_approvals_requested(self, collector):
        """Approval requests count regardless of priority (or its default)."""
        await collector.inc_approvals_requested("normal")
        await collector.inc_approvals_requested("urgent")
        await collector.inc_approvals_requested()  # default

        summary = await collector.get_summary()
        assert summary["approval_requests"] == 3

    @pytest.mark.asyncio
    async def test_inc_approvals_granted(self, collector):
        """Each grant bumps the approvals-granted total."""
        await collector.inc_approvals_granted()
        await collector.inc_approvals_granted()

        summary = await collector.get_summary()
        assert summary["approvals_granted"] == 2

    @pytest.mark.asyncio
    async def test_inc_approvals_denied(self, collector):
        """Denials count regardless of the denial reason (or its default)."""
        await collector.inc_approvals_denied("timeout")
        await collector.inc_approvals_denied("policy")
        await collector.inc_approvals_denied()  # default manual

        summary = await collector.get_summary()
        assert summary["approvals_denied"] == 3

    @pytest.mark.asyncio
    async def test_inc_rate_limit_exceeded(self, collector):
        """Rate-limit hits aggregate across limit types."""
        await collector.inc_rate_limit_exceeded("requests_per_minute")
        await collector.inc_rate_limit_exceeded("tokens_per_hour")

        summary = await collector.get_summary()
        assert summary["rate_limit_hits"] == 2

    @pytest.mark.asyncio
    async def test_inc_budget_exceeded(self, collector):
        """Budget-exceeded events aggregate across budget types."""
        await collector.inc_budget_exceeded("daily_cost")
        await collector.inc_budget_exceeded("monthly_tokens")

        summary = await collector.get_summary()
        assert summary["budget_exceeded"] == 2

    @pytest.mark.asyncio
    async def test_inc_loops_detected(self, collector):
        """Loop detections aggregate across detection kinds."""
        await collector.inc_loops_detected("repetition")
        await collector.inc_loops_detected("pattern")

        summary = await collector.get_summary()
        assert summary["loops_detected"] == 2

    @pytest.mark.asyncio
    async def test_inc_emergency_events(self, collector):
        """Emergency events aggregate across event kinds and scopes."""
        await collector.inc_emergency_events("pause", "project-1")
        await collector.inc_emergency_events("stop", "agent-2")

        summary = await collector.get_summary()
        assert summary["emergency_events"] == 2

    @pytest.mark.asyncio
    async def test_inc_content_filtered(self, collector):
        """Content-filter events aggregate across categories and actions."""
        await collector.inc_content_filtered("profanity", "blocked")
        await collector.inc_content_filtered("pii", "redacted")

        summary = await collector.get_summary()
        assert summary["content_filtered"] == 2

    @pytest.mark.asyncio
    async def test_inc_checkpoints_created(self, collector):
        """Each call bumps the checkpoints-created total."""
        await collector.inc_checkpoints_created()
        await collector.inc_checkpoints_created()
        await collector.inc_checkpoints_created()

        summary = await collector.get_summary()
        assert summary["checkpoints_created"] == 3

    @pytest.mark.asyncio
    async def test_inc_rollbacks_executed(self, collector):
        """Rollbacks count whether or not they succeeded."""
        await collector.inc_rollbacks_executed(success=True)
        await collector.inc_rollbacks_executed(success=False)

        summary = await collector.get_summary()
        assert summary["rollbacks_executed"] == 2

    @pytest.mark.asyncio
    async def test_inc_mcp_calls(self, collector):
        """MCP calls count whether or not they succeeded."""
        await collector.inc_mcp_calls("search_knowledge", success=True)
        await collector.inc_mcp_calls("run_code", success=False)

        summary = await collector.get_summary()
        assert summary["mcp_calls"] == 2
|
||||
|
||||
|
||||
class TestSafetyMetricsGauges:
    """Tests for SafetyMetrics gauge methods."""

    @pytest_asyncio.fixture
    async def collector(self):
        """Provide a fresh SafetyMetrics instance for each test."""
        return SafetyMetrics()

    @pytest.mark.asyncio
    async def test_set_budget_remaining(self, collector):
        """The budget gauge records the value plus scope/budget_type labels."""
        await collector.set_budget_remaining("project-1", "daily_cost", 50.0)

        exported = await collector.get_all_metrics()
        budget_gauges = [m for m in exported if m.name == "safety_budget_remaining"]
        assert len(budget_gauges) == 1
        gauge = budget_gauges[0]
        assert gauge.value == 50.0
        assert gauge.labels["scope"] == "project-1"
        assert gauge.labels["budget_type"] == "daily_cost"

    @pytest.mark.asyncio
    async def test_set_rate_limit_remaining(self, collector):
        """The rate-limit gauge stores the remaining quota as a float."""
        await collector.set_rate_limit_remaining("agent-1", "requests_per_minute", 45)

        exported = await collector.get_all_metrics()
        rate_gauges = [
            m for m in exported if m.name == "safety_rate_limit_remaining"
        ]
        assert len(rate_gauges) == 1
        assert rate_gauges[0].value == 45.0

    @pytest.mark.asyncio
    async def test_set_pending_approvals(self, collector):
        """The pending-approvals gauge is reflected in the summary."""
        await collector.set_pending_approvals(5)

        summary = await collector.get_summary()
        assert summary["pending_approvals"] == 5

    @pytest.mark.asyncio
    async def test_set_active_checkpoints(self, collector):
        """The active-checkpoints gauge is reflected in the summary."""
        await collector.set_active_checkpoints(3)

        summary = await collector.get_summary()
        assert summary["active_checkpoints"] == 3

    @pytest.mark.asyncio
    async def test_set_emergency_state(self, collector):
        """States map to 0 (normal), 1 (paused), 2 (stopped), -1 (unrecognized)."""
        states = {
            "project-1": "normal",
            "project-2": "paused",
            "project-3": "stopped",
            "project-4": "unknown",
        }
        for scope, state in states.items():
            await collector.set_emergency_state(scope, state)

        exported = await collector.get_all_metrics()
        state_gauges = [m for m in exported if m.name == "safety_emergency_state"]
        assert len(state_gauges) == 4

        # Check state values
        values_by_scope = {m.labels["scope"]: m.value for m in state_gauges}
        assert values_by_scope["project-1"] == 0.0  # normal
        assert values_by_scope["project-2"] == 1.0  # paused
        assert values_by_scope["project-3"] == 2.0  # stopped
        assert values_by_scope["project-4"] == -1.0  # unknown
|
||||
|
||||
|
||||
class TestSafetyMetricsHistograms:
    """Tests for SafetyMetrics histogram methods."""

    @pytest_asyncio.fixture
    async def collector(self):
        """Provide a fresh SafetyMetrics instance for each test."""
        return SafetyMetrics()

    @pytest.mark.asyncio
    async def test_observe_validation_latency(self, collector):
        """Observations accumulate into the _count and _sum series."""
        for latency in (0.05, 0.15, 0.5):
            await collector.observe_validation_latency(latency)

        exported = await collector.get_all_metrics()

        count_metric = next(
            (m for m in exported if m.name == "validation_latency_seconds_count"),
            None,
        )
        assert count_metric is not None
        assert count_metric.value == 3.0

        sum_metric = next(
            (m for m in exported if m.name == "validation_latency_seconds_sum"),
            None,
        )
        assert sum_metric is not None
        assert abs(sum_metric.value - 0.7) < 0.001

    @pytest.mark.asyncio
    async def test_observe_approval_latency(self, collector):
        """Approval-latency observations increment the _count series."""
        for latency in (1.5, 3.0):
            await collector.observe_approval_latency(latency)

        exported = await collector.get_all_metrics()

        count_metric = next(
            (m for m in exported if m.name == "approval_latency_seconds_count"),
            None,
        )
        assert count_metric is not None
        assert count_metric.value == 2.0

    @pytest.mark.asyncio
    async def test_observe_mcp_execution_latency(self, collector):
        """A single MCP-latency observation yields a count of one."""
        await collector.observe_mcp_execution_latency(0.02)

        exported = await collector.get_all_metrics()

        count_metric = next(
            (m for m in exported if m.name == "mcp_execution_latency_seconds_count"),
            None,
        )
        assert count_metric is not None
        assert count_metric.value == 1.0

    @pytest.mark.asyncio
    async def test_histogram_bucket_updates(self, collector):
        """Bucketed counts appear in the Prometheus text export."""
        # Spread observations across several bucket boundaries.
        await collector.observe_validation_latency(0.005)  # <= 0.01
        await collector.observe_validation_latency(0.03)  # <= 0.05
        await collector.observe_validation_latency(0.07)  # <= 0.1
        await collector.observe_validation_latency(15.0)  # <= inf

        prometheus = await collector.get_prometheus_format()

        # Check that bucket counts are in output
        assert "validation_latency_seconds_bucket" in prometheus
        assert "le=" in prometheus
|
||||
|
||||
|
||||
class TestSafetyMetricsExport:
    """Exercise the export surface: get_all_metrics, Prometheus text, summary."""

    @pytest_asyncio.fixture
    async def metrics(self):
        """Build a metrics instance pre-populated with counters, gauges and one histogram sample."""
        instance = SafetyMetrics()
        instance_counters = (
            ("allow", None),
            ("deny", "agent-1"),
        )
        for decision, agent in instance_counters:
            if agent is None:
                await instance.inc_validations(decision)
            else:
                await instance.inc_validations(decision, agent_id=agent)
        await instance.set_pending_approvals(3)
        await instance.set_budget_remaining("proj-1", "daily", 100.0)
        await instance.observe_validation_latency(0.1)
        return instance

    @pytest.mark.asyncio
    async def test_get_all_metrics(self, metrics):
        """All recorded series come back as MetricValue objects of mixed types."""
        collected = await metrics.get_all_metrics()

        assert len(collected) > 0
        assert all(isinstance(item, MetricValue) for item in collected)

        # Both counter and gauge families must be represented.
        seen_types = {item.metric_type for item in collected}
        assert MetricType.COUNTER in seen_types
        assert MetricType.GAUGE in seen_types

    @pytest.mark.asyncio
    async def test_get_prometheus_format(self, metrics):
        """Prometheus text output carries TYPE lines and the known series names."""
        rendered = await metrics.get_prometheus_format()

        assert isinstance(rendered, str)
        for fragment in (
            "# TYPE",
            "counter",
            "gauge",
            "safety_validations_total",
            "safety_pending_approvals",
        ):
            assert fragment in rendered

    @pytest.mark.asyncio
    async def test_prometheus_format_with_labels(self, metrics):
        """Counter labels (decision=...) appear in the rendered output."""
        rendered = await metrics.get_prometheus_format()
        assert "decision=allow" in rendered or "decision=deny" in rendered

    @pytest.mark.asyncio
    async def test_prometheus_format_histogram_buckets(self, metrics):
        """Histogram series render with buckets, le labels and the +Inf bound."""
        rendered = await metrics.get_prometheus_format()
        for fragment in ("histogram", "_bucket", "le=", "+Inf"):
            assert fragment in rendered

    @pytest.mark.asyncio
    async def test_get_summary(self, metrics):
        """Summary exposes the expected keys and reflects the fixture data."""
        summary = await metrics.get_summary()

        for key in (
            "total_validations",
            "denied_validations",
            "approval_requests",
            "pending_approvals",
            "active_checkpoints",
        ):
            assert key in summary

        assert summary["total_validations"] == 2
        assert summary["denied_validations"] == 1
        assert summary["pending_approvals"] == 3

    @pytest.mark.asyncio
    async def test_summary_empty_counters(self):
        """A fresh instance reports zero for every summary counter."""
        fresh = SafetyMetrics()
        summary = await fresh.get_summary()

        assert summary["total_validations"] == 0
        assert summary["denied_validations"] == 0
        assert summary["pending_approvals"] == 0
|
||||
|
||||
|
||||
class TestSafetyMetricsReset:
    """Behaviour of SafetyMetrics.reset()."""

    @pytest.mark.asyncio
    async def test_reset_clears_counters(self):
        """reset() zeroes counters, gauges and histogram-derived summary values."""
        instance = SafetyMetrics()

        await instance.inc_validations("allow")
        await instance.inc_approvals_granted()
        await instance.set_pending_approvals(5)
        await instance.observe_validation_latency(0.1)

        await instance.reset()

        summary = await instance.get_summary()
        assert summary["total_validations"] == 0
        assert summary["approvals_granted"] == 0
        assert summary["pending_approvals"] == 0

    @pytest.mark.asyncio
    async def test_reset_reinitializes_histogram_buckets(self):
        """After reset() the histogram series still render in Prometheus output."""
        instance = SafetyMetrics()

        await instance.observe_validation_latency(0.1)
        await instance.reset()

        rendered = await instance.get_prometheus_format()
        assert "validation_latency_seconds" in rendered
|
||||
|
||||
|
||||
class TestParseLabels:
    """Unit tests for the private _parse_labels helper."""

    def test_parse_empty_labels(self):
        """An empty string parses to an empty dict."""
        assert SafetyMetrics()._parse_labels("") == {}

    def test_parse_single_label(self):
        """A lone key=value pair parses to a one-entry dict."""
        assert SafetyMetrics()._parse_labels("key=value") == {"key": "value"}

    def test_parse_multiple_labels(self):
        """Comma-separated pairs all land in the dict."""
        parsed = SafetyMetrics()._parse_labels("a=1,b=2,c=3")
        assert parsed == {"a": "1", "b": "2", "c": "3"}

    def test_parse_labels_with_spaces(self):
        """Whitespace around keys, values and separators is stripped."""
        parsed = SafetyMetrics()._parse_labels(" key = value , foo = bar ")
        assert parsed == {"key": "value", "foo": "bar"}

    def test_parse_labels_with_equals_in_value(self):
        """Only the first '=' splits; later ones remain part of the value."""
        assert SafetyMetrics()._parse_labels("query=a=b") == {"query": "a=b"}

    def test_parse_invalid_label(self):
        """A token without '=' is silently dropped."""
        assert SafetyMetrics()._parse_labels("no_equals") == {}
|
||||
|
||||
|
||||
class TestHistogramBucketInit:
    """Histogram bucket setup performed when SafetyMetrics is constructed."""

    def test_histogram_buckets_initialized(self):
        """Each latency histogram gets a bucket list at construction time."""
        instance = SafetyMetrics()
        for name in (
            "validation_latency_seconds",
            "approval_latency_seconds",
            "mcp_execution_latency_seconds",
        ):
            assert name in instance._histogram_buckets

    def test_histogram_buckets_have_correct_values(self):
        """Bucket boundaries start at 0.01/0.05 and end at +Inf, all zeroed."""
        instance = SafetyMetrics()
        buckets = instance._histogram_buckets["validation_latency_seconds"]

        assert buckets[0].le == 0.01
        assert buckets[1].le == 0.05
        assert buckets[-1].le == float("inf")

        # No observations yet, so every bucket count must be zero.
        assert all(bucket.count == 0 for bucket in buckets)
|
||||
|
||||
|
||||
class TestSingletonAndConvenience:
    """Module-level singleton accessor and convenience recording helpers."""

    @pytest.mark.asyncio
    async def test_get_safety_metrics_returns_same_instance(self):
        """Repeated calls to get_safety_metrics() hand back one shared object."""
        import app.services.safety.metrics.collector as collector_module

        collector_module._metrics = None  # force the singleton to be rebuilt

        first = await get_safety_metrics()
        second = await get_safety_metrics()

        assert first is second

    @pytest.mark.asyncio
    async def test_record_validation_convenience(self):
        """record_validation() feeds the shared collector's validation counters."""
        import app.services.safety.metrics.collector as collector_module

        collector_module._metrics = None  # start from a clean singleton

        await record_validation("allow")
        await record_validation("deny", agent_id="test-agent")

        shared = await get_safety_metrics()
        summary = await shared.get_summary()

        assert summary["total_validations"] == 2
        assert summary["denied_validations"] == 1

    @pytest.mark.asyncio
    async def test_record_mcp_call_convenience(self):
        """record_mcp_call() feeds the shared collector's MCP counters."""
        import app.services.safety.metrics.collector as collector_module

        collector_module._metrics = None  # start from a clean singleton

        await record_mcp_call("search_knowledge", success=True, latency_ms=50)
        await record_mcp_call("run_code", success=False, latency_ms=100)

        shared = await get_safety_metrics()
        summary = await shared.get_summary()

        assert summary["mcp_calls"] == 2
|
||||
|
||||
|
||||
class TestConcurrency:
    """Concurrent update safety for counters, gauges and histograms."""

    @pytest.mark.asyncio
    async def test_concurrent_counter_increments(self):
        """10 tasks x 100 increments must land exactly 1000 on the counter."""
        import asyncio

        instance = SafetyMetrics()

        async def worker():
            for _ in range(100):
                await instance.inc_validations("allow")

        await asyncio.gather(*(worker() for _ in range(10)))

        summary = await instance.get_summary()
        assert summary["total_validations"] == 1000

    @pytest.mark.asyncio
    async def test_concurrent_gauge_updates(self):
        """Racing gauge writers leave the gauge at one of the written values."""
        import asyncio

        instance = SafetyMetrics()

        async def writer(value):
            await instance.set_pending_approvals(value)

        await asyncio.gather(*(writer(i) for i in range(100)))

        # Last writer wins; any of the written values (0..99) is acceptable.
        summary = await instance.get_summary()
        assert 0 <= summary["pending_approvals"] < 100

    @pytest.mark.asyncio
    async def test_concurrent_histogram_observations(self):
        """10 tasks x 100 observations must produce a histogram count of 1000."""
        import asyncio

        instance = SafetyMetrics()

        async def observer():
            for i in range(100):
                await instance.observe_validation_latency(i / 1000)

        await asyncio.gather(*(observer() for _ in range(10)))

        collected = await instance.get_all_metrics()
        count = next(
            (item for item in collected
             if item.name == "validation_latency_seconds_count"),
            None,
        )
        assert count is not None
        assert count.value == 1000.0
|
||||
|
||||
|
||||
class TestEdgeCases:
    """Edge cases: extreme values, odd labels, empty state, repeated resets."""

    @pytest.mark.asyncio
    async def test_very_large_counter_value(self):
        """Counters stay exact after ten thousand increments."""
        instance = SafetyMetrics()

        for _ in range(10000):
            await instance.inc_validations("allow")

        summary = await instance.get_summary()
        assert summary["total_validations"] == 10000

    @pytest.mark.asyncio
    async def test_zero_and_negative_gauge_values(self):
        """Gauges accept zero and negative budget values."""
        instance = SafetyMetrics()

        await instance.set_budget_remaining("project", "cost", 0.0)
        await instance.set_budget_remaining("project2", "cost", -10.0)

        collected = await instance.get_all_metrics()
        budget_gauges = [
            item for item in collected if item.name == "safety_budget_remaining"
        ]

        by_scope = {item.labels.get("scope"): item.value for item in budget_gauges}
        assert by_scope["project"] == 0.0
        assert by_scope["project2"] == -10.0

    @pytest.mark.asyncio
    async def test_very_small_histogram_values(self):
        """Sub-millisecond latencies are accumulated without loss."""
        instance = SafetyMetrics()

        await instance.observe_validation_latency(0.0001)  # 0.1ms

        collected = await instance.get_all_metrics()
        total = next(
            (item for item in collected
             if item.name == "validation_latency_seconds_sum"),
            None,
        )
        assert total is not None
        assert abs(total.value - 0.0001) < 0.00001

    @pytest.mark.asyncio
    async def test_special_characters_in_labels(self):
        """Label values containing slashes are still recorded."""
        instance = SafetyMetrics()

        await instance.inc_validations("allow", agent_id="agent/with/slashes")

        collected = await instance.get_all_metrics()
        counters = [
            item for item in collected if item.name == "safety_validations_total"
        ]
        assert len(counters) > 0

    @pytest.mark.asyncio
    async def test_empty_histogram_export(self):
        """Histograms render (with buckets) even before any observation."""
        instance = SafetyMetrics()

        rendered = await instance.get_prometheus_format()

        assert "validation_latency_seconds" in rendered
        assert "le=" in rendered

    @pytest.mark.asyncio
    async def test_prometheus_format_empty_label_value(self):
        """Counters keyed by an empty label string still export."""
        instance = SafetyMetrics()

        await instance.inc_approvals_granted()  # Uses empty string as label

        rendered = await instance.get_prometheus_format()
        assert "safety_approvals_granted_total" in rendered

    @pytest.mark.asyncio
    async def test_multiple_resets(self):
        """Calling reset() repeatedly is idempotent and raises nothing."""
        instance = SafetyMetrics()

        await instance.inc_validations("allow")
        for _ in range(3):
            await instance.reset()

        summary = await instance.get_summary()
        assert summary["total_validations"] == 0
|
||||
933
backend/tests/services/safety/test_permissions.py
Normal file
933
backend/tests/services/safety/test_permissions.py
Normal file
@@ -0,0 +1,933 @@
|
||||
"""Tests for Permission Manager.
|
||||
|
||||
Tests cover:
|
||||
- PermissionGrant: creation, expiry, matching, hierarchy
|
||||
- PermissionManager: grant, revoke, check, require, list, defaults
|
||||
- Edge cases: wildcards, expiration, default deny/allow
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from app.services.safety.exceptions import PermissionDeniedError
|
||||
from app.services.safety.models import (
|
||||
ActionMetadata,
|
||||
ActionRequest,
|
||||
ActionType,
|
||||
PermissionLevel,
|
||||
ResourceType,
|
||||
)
|
||||
from app.services.safety.permissions.manager import PermissionGrant, PermissionManager
|
||||
|
||||
# ============================================================================
|
||||
# Fixtures
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
def action_metadata() -> ActionMetadata:
    """Standard action metadata shared by the permission tests."""
    return ActionMetadata(
        agent_id="test-agent",
        project_id="test-project",
        session_id="test-session",
    )
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def permission_manager() -> PermissionManager:
    """Manager configured to deny whenever no grant matches."""
    return PermissionManager(default_deny=True)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def permissive_manager() -> PermissionManager:
    """Manager that falls back to per-resource-type default levels."""
    return PermissionManager(default_deny=False)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# PermissionGrant Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestPermissionGrant:
    """Tests for the PermissionGrant value object."""

    @staticmethod
    def _make(**overrides) -> PermissionGrant:
        """Build a grant with sensible defaults; override any field per test."""
        params = {
            "agent_id": "agent-1",
            "resource_pattern": "*",
            "resource_type": ResourceType.FILE,
            "level": PermissionLevel.READ,
        }
        params.update(overrides)
        return PermissionGrant(**params)

    def test_grant_creation(self) -> None:
        """All constructor arguments are stored; optional fields default sanely."""
        grant = PermissionGrant(
            agent_id="agent-1",
            resource_pattern="/data/*",
            resource_type=ResourceType.FILE,
            level=PermissionLevel.READ,
            granted_by="admin",
            reason="Read access to data directory",
        )

        assert grant.id is not None
        assert grant.agent_id == "agent-1"
        assert grant.resource_pattern == "/data/*"
        assert grant.resource_type == ResourceType.FILE
        assert grant.level == PermissionLevel.READ
        assert grant.granted_by == "admin"
        assert grant.reason == "Read access to data directory"
        assert grant.expires_at is None
        assert grant.created_at is not None

    def test_grant_with_expiration(self) -> None:
        """A future expires_at is stored and not yet considered expired."""
        # NOTE(review): datetime.utcnow() is deprecated since 3.12; matches the
        # (naive) timestamps the manager under test uses — confirm before migrating.
        future = datetime.utcnow() + timedelta(hours=1)
        grant = self._make(
            resource_type=ResourceType.API,
            level=PermissionLevel.EXECUTE,
            expires_at=future,
        )

        assert grant.expires_at == future
        assert grant.is_expired() is False

    def test_is_expired_no_expiration(self) -> None:
        """A grant without expires_at never expires."""
        assert self._make().is_expired() is False

    def test_is_expired_future(self) -> None:
        """An expiry one hour ahead is not expired."""
        grant = self._make(expires_at=datetime.utcnow() + timedelta(hours=1))
        assert grant.is_expired() is False

    def test_is_expired_past(self) -> None:
        """An expiry one hour behind is expired."""
        grant = self._make(expires_at=datetime.utcnow() - timedelta(hours=1))
        assert grant.is_expired() is True

    def test_matches_exact(self) -> None:
        """A literal pattern matches only the identical resource path."""
        grant = self._make(resource_pattern="/data/file.txt")

        assert grant.matches("/data/file.txt", ResourceType.FILE) is True
        assert grant.matches("/data/other.txt", ResourceType.FILE) is False

    def test_matches_wildcard(self) -> None:
        """'*' matches any suffix — including path separators (fnmatch)."""
        grant = self._make(resource_pattern="/data/*")

        assert grant.matches("/data/file.txt", ResourceType.FILE) is True
        # fnmatch's * matches everything including /
        assert grant.matches("/data/subdir/file.txt", ResourceType.FILE) is True
        assert grant.matches("/other/file.txt", ResourceType.FILE) is False

    def test_matches_recursive_wildcard(self) -> None:
        """'**' behaves like '*' under fnmatch, so subdirectories match too."""
        grant = self._make(resource_pattern="/data/**")

        assert grant.matches("/data/file.txt", ResourceType.FILE) is True
        assert grant.matches("/data/subdir/file.txt", ResourceType.FILE) is True

    def test_matches_wrong_resource_type(self) -> None:
        """A matching pattern is not enough; resource type must agree."""
        grant = self._make(resource_pattern="/data/*")
        assert grant.matches("/data/table", ResourceType.DATABASE) is False

    def test_allows_hierarchy(self) -> None:
        """ADMIN sits at the top of the hierarchy and allows every level."""
        admin = self._make(level=PermissionLevel.ADMIN)

        for level in (
            PermissionLevel.NONE,
            PermissionLevel.READ,
            PermissionLevel.WRITE,
            PermissionLevel.EXECUTE,
            PermissionLevel.DELETE,
            PermissionLevel.ADMIN,
        ):
            assert admin.allows(level) is True

    def test_allows_read_only(self) -> None:
        """READ permits only NONE and READ; everything above is refused."""
        reader = self._make(level=PermissionLevel.READ)

        assert reader.allows(PermissionLevel.NONE) is True
        assert reader.allows(PermissionLevel.READ) is True
        for denied in (
            PermissionLevel.WRITE,
            PermissionLevel.EXECUTE,
            PermissionLevel.DELETE,
            PermissionLevel.ADMIN,
        ):
            assert reader.allows(denied) is False

    def test_allows_write_includes_read(self) -> None:
        """WRITE implies READ but not EXECUTE."""
        writer = self._make(level=PermissionLevel.WRITE)

        assert writer.allows(PermissionLevel.READ) is True
        assert writer.allows(PermissionLevel.WRITE) is True
        assert writer.allows(PermissionLevel.EXECUTE) is False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# PermissionManager Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestPermissionManager:
|
||||
"""Tests for the PermissionManager class."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grant_creates_permission(
|
||||
self,
|
||||
permission_manager: PermissionManager,
|
||||
) -> None:
|
||||
"""Test granting a permission."""
|
||||
grant = await permission_manager.grant(
|
||||
agent_id="agent-1",
|
||||
resource_pattern="/data/*",
|
||||
resource_type=ResourceType.FILE,
|
||||
level=PermissionLevel.READ,
|
||||
granted_by="admin",
|
||||
reason="Read access",
|
||||
)
|
||||
|
||||
assert grant.id is not None
|
||||
assert grant.agent_id == "agent-1"
|
||||
assert grant.resource_pattern == "/data/*"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grant_with_duration(
|
||||
self,
|
||||
permission_manager: PermissionManager,
|
||||
) -> None:
|
||||
"""Test granting a temporary permission."""
|
||||
grant = await permission_manager.grant(
|
||||
agent_id="agent-1",
|
||||
resource_pattern="*",
|
||||
resource_type=ResourceType.API,
|
||||
level=PermissionLevel.EXECUTE,
|
||||
duration_seconds=3600, # 1 hour
|
||||
)
|
||||
|
||||
assert grant.expires_at is not None
|
||||
assert grant.is_expired() is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_revoke_by_id(
|
||||
self,
|
||||
permission_manager: PermissionManager,
|
||||
) -> None:
|
||||
"""Test revoking a grant by ID."""
|
||||
grant = await permission_manager.grant(
|
||||
agent_id="agent-1",
|
||||
resource_pattern="*",
|
||||
resource_type=ResourceType.FILE,
|
||||
level=PermissionLevel.READ,
|
||||
)
|
||||
|
||||
success = await permission_manager.revoke(grant.id)
|
||||
assert success is True
|
||||
|
||||
# Verify grant is removed
|
||||
grants = await permission_manager.list_grants(agent_id="agent-1")
|
||||
assert len(grants) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_revoke_nonexistent(
|
||||
self,
|
||||
permission_manager: PermissionManager,
|
||||
) -> None:
|
||||
"""Test revoking a non-existent grant."""
|
||||
success = await permission_manager.revoke("nonexistent-id")
|
||||
assert success is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_revoke_all_for_agent(
|
||||
self,
|
||||
permission_manager: PermissionManager,
|
||||
) -> None:
|
||||
"""Test revoking all permissions for an agent."""
|
||||
# Grant multiple permissions
|
||||
await permission_manager.grant(
|
||||
agent_id="agent-1",
|
||||
resource_pattern="/data/*",
|
||||
resource_type=ResourceType.FILE,
|
||||
level=PermissionLevel.READ,
|
||||
)
|
||||
await permission_manager.grant(
|
||||
agent_id="agent-1",
|
||||
resource_pattern="/api/*",
|
||||
resource_type=ResourceType.API,
|
||||
level=PermissionLevel.EXECUTE,
|
||||
)
|
||||
await permission_manager.grant(
|
||||
agent_id="agent-2",
|
||||
resource_pattern="*",
|
||||
resource_type=ResourceType.FILE,
|
||||
level=PermissionLevel.READ,
|
||||
)
|
||||
|
||||
revoked = await permission_manager.revoke_all("agent-1")
|
||||
assert revoked == 2
|
||||
|
||||
# Verify agent-1 grants are gone
|
||||
grants = await permission_manager.list_grants(agent_id="agent-1")
|
||||
assert len(grants) == 0
|
||||
|
||||
# Verify agent-2 grant remains
|
||||
grants = await permission_manager.list_grants(agent_id="agent-2")
|
||||
assert len(grants) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_revoke_all_no_grants(
|
||||
self,
|
||||
permission_manager: PermissionManager,
|
||||
) -> None:
|
||||
"""Test revoking all when no grants exist."""
|
||||
revoked = await permission_manager.revoke_all("nonexistent-agent")
|
||||
assert revoked == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_granted(
|
||||
self,
|
||||
permission_manager: PermissionManager,
|
||||
) -> None:
|
||||
"""Test checking a granted permission."""
|
||||
await permission_manager.grant(
|
||||
agent_id="agent-1",
|
||||
resource_pattern="/data/*",
|
||||
resource_type=ResourceType.FILE,
|
||||
level=PermissionLevel.READ,
|
||||
)
|
||||
|
||||
allowed = await permission_manager.check(
|
||||
agent_id="agent-1",
|
||||
resource="/data/file.txt",
|
||||
resource_type=ResourceType.FILE,
|
||||
required_level=PermissionLevel.READ,
|
||||
)
|
||||
|
||||
assert allowed is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_denied_default_deny(
|
||||
self,
|
||||
permission_manager: PermissionManager,
|
||||
) -> None:
|
||||
"""Test checking denied with default_deny=True."""
|
||||
# No grants, should be denied
|
||||
allowed = await permission_manager.check(
|
||||
agent_id="agent-1",
|
||||
resource="/data/file.txt",
|
||||
resource_type=ResourceType.FILE,
|
||||
required_level=PermissionLevel.READ,
|
||||
)
|
||||
|
||||
assert allowed is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_uses_default_permissions(
|
||||
self,
|
||||
permissive_manager: PermissionManager,
|
||||
) -> None:
|
||||
"""Test that default permissions apply when default_deny=False."""
|
||||
# No explicit grants, but FILE default is READ
|
||||
allowed = await permissive_manager.check(
|
||||
agent_id="agent-1",
|
||||
resource="/data/file.txt",
|
||||
resource_type=ResourceType.FILE,
|
||||
required_level=PermissionLevel.READ,
|
||||
)
|
||||
|
||||
assert allowed is True
|
||||
|
||||
# But WRITE should fail
|
||||
allowed = await permissive_manager.check(
|
||||
agent_id="agent-1",
|
||||
resource="/data/file.txt",
|
||||
resource_type=ResourceType.FILE,
|
||||
required_level=PermissionLevel.WRITE,
|
||||
)
|
||||
|
||||
assert allowed is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_shell_denied_by_default(
|
||||
self,
|
||||
permissive_manager: PermissionManager,
|
||||
) -> None:
|
||||
"""Test SHELL is denied by default (NONE level)."""
|
||||
allowed = await permissive_manager.check(
|
||||
agent_id="agent-1",
|
||||
resource="rm -rf /",
|
||||
resource_type=ResourceType.SHELL,
|
||||
required_level=PermissionLevel.EXECUTE,
|
||||
)
|
||||
|
||||
assert allowed is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_expired_grant_ignored(
|
||||
self,
|
||||
permission_manager: PermissionManager,
|
||||
) -> None:
|
||||
"""Test that expired grants are ignored in checks."""
|
||||
# Create an already-expired grant
|
||||
grant = await permission_manager.grant(
|
||||
agent_id="agent-1",
|
||||
resource_pattern="/data/*",
|
||||
resource_type=ResourceType.FILE,
|
||||
level=PermissionLevel.READ,
|
||||
duration_seconds=1, # Very short
|
||||
)
|
||||
|
||||
# Manually expire it
|
||||
grant.expires_at = datetime.utcnow() - timedelta(seconds=10)
|
||||
|
||||
allowed = await permission_manager.check(
|
||||
agent_id="agent-1",
|
||||
resource="/data/file.txt",
|
||||
resource_type=ResourceType.FILE,
|
||||
required_level=PermissionLevel.READ,
|
||||
)
|
||||
|
||||
assert allowed is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_insufficient_level(
|
||||
self,
|
||||
permission_manager: PermissionManager,
|
||||
) -> None:
|
||||
"""Test check fails when grant level is insufficient."""
|
||||
await permission_manager.grant(
|
||||
agent_id="agent-1",
|
||||
resource_pattern="/data/*",
|
||||
resource_type=ResourceType.FILE,
|
||||
level=PermissionLevel.READ,
|
||||
)
|
||||
|
||||
# Try to get WRITE access with only READ grant
|
||||
allowed = await permission_manager.check(
|
||||
agent_id="agent-1",
|
||||
resource="/data/file.txt",
|
||||
resource_type=ResourceType.FILE,
|
||||
required_level=PermissionLevel.WRITE,
|
||||
)
|
||||
|
||||
assert allowed is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_action_file_read(
|
||||
self,
|
||||
permission_manager: PermissionManager,
|
||||
action_metadata: ActionMetadata,
|
||||
) -> None:
|
||||
"""Test check_action for file read."""
|
||||
await permission_manager.grant(
|
||||
agent_id="test-agent",
|
||||
resource_pattern="/data/*",
|
||||
resource_type=ResourceType.FILE,
|
||||
level=PermissionLevel.READ,
|
||||
)
|
||||
|
||||
action = ActionRequest(
|
||||
action_type=ActionType.FILE_READ,
|
||||
resource="/data/file.txt",
|
||||
metadata=action_metadata,
|
||||
)
|
||||
|
||||
allowed = await permission_manager.check_action(action)
|
||||
assert allowed is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_action_file_write(
|
||||
self,
|
||||
permission_manager: PermissionManager,
|
||||
action_metadata: ActionMetadata,
|
||||
) -> None:
|
||||
"""Test check_action for file write."""
|
||||
await permission_manager.grant(
|
||||
agent_id="test-agent",
|
||||
resource_pattern="/data/*",
|
||||
resource_type=ResourceType.FILE,
|
||||
level=PermissionLevel.WRITE,
|
||||
)
|
||||
|
||||
action = ActionRequest(
|
||||
action_type=ActionType.FILE_WRITE,
|
||||
resource="/data/file.txt",
|
||||
metadata=action_metadata,
|
||||
)
|
||||
|
||||
allowed = await permission_manager.check_action(action)
|
||||
assert allowed is True
|
||||
|
||||
    @pytest.mark.asyncio
    async def test_check_action_uses_tool_name_as_resource(
        self,
        permission_manager: PermissionManager,
        action_metadata: ActionMetadata,
    ) -> None:
        """Test check_action uses tool_name when resource is None."""
        await permission_manager.grant(
            agent_id="test-agent",
            resource_pattern="search_*",
            resource_type=ResourceType.CUSTOM,
            level=PermissionLevel.EXECUTE,
        )

        # resource is None, so the manager should fall back to the
        # tool_name ("search_documents"), which matches "search_*".
        action = ActionRequest(
            action_type=ActionType.TOOL_CALL,
            tool_name="search_documents",
            resource=None,
            metadata=action_metadata,
        )

        allowed = await permission_manager.check_action(action)
        assert allowed is True
|
||||
|
||||
    @pytest.mark.asyncio
    async def test_require_permission_granted(
        self,
        permission_manager: PermissionManager,
    ) -> None:
        """Test require_permission doesn't raise when granted.

        The implicit assertion is that no PermissionDeniedError
        propagates out of the call below.
        """
        await permission_manager.grant(
            agent_id="agent-1",
            resource_pattern="/data/*",
            resource_type=ResourceType.FILE,
            level=PermissionLevel.READ,
        )

        # Should not raise
        await permission_manager.require_permission(
            agent_id="agent-1",
            resource="/data/file.txt",
            resource_type=ResourceType.FILE,
            required_level=PermissionLevel.READ,
        )
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_require_permission_denied(
|
||||
self,
|
||||
permission_manager: PermissionManager,
|
||||
) -> None:
|
||||
"""Test require_permission raises when denied."""
|
||||
with pytest.raises(PermissionDeniedError) as exc_info:
|
||||
await permission_manager.require_permission(
|
||||
agent_id="agent-1",
|
||||
resource="/secret/file.txt",
|
||||
resource_type=ResourceType.FILE,
|
||||
required_level=PermissionLevel.READ,
|
||||
)
|
||||
|
||||
assert "/secret/file.txt" in str(exc_info.value)
|
||||
assert exc_info.value.agent_id == "agent-1"
|
||||
assert exc_info.value.required_permission == "read"
|
||||
|
||||
    @pytest.mark.asyncio
    async def test_list_grants_all(
        self,
        permission_manager: PermissionManager,
    ) -> None:
        """Test listing all grants."""
        await permission_manager.grant(
            agent_id="agent-1",
            resource_pattern="/data/*",
            resource_type=ResourceType.FILE,
            level=PermissionLevel.READ,
        )
        await permission_manager.grant(
            agent_id="agent-2",
            resource_pattern="/api/*",
            resource_type=ResourceType.API,
            level=PermissionLevel.EXECUTE,
        )

        # With no filters, both grants are returned.
        grants = await permission_manager.list_grants()
        assert len(grants) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_grants_by_agent(
|
||||
self,
|
||||
permission_manager: PermissionManager,
|
||||
) -> None:
|
||||
"""Test listing grants filtered by agent."""
|
||||
await permission_manager.grant(
|
||||
agent_id="agent-1",
|
||||
resource_pattern="/data/*",
|
||||
resource_type=ResourceType.FILE,
|
||||
level=PermissionLevel.READ,
|
||||
)
|
||||
await permission_manager.grant(
|
||||
agent_id="agent-2",
|
||||
resource_pattern="/api/*",
|
||||
resource_type=ResourceType.API,
|
||||
level=PermissionLevel.EXECUTE,
|
||||
)
|
||||
|
||||
grants = await permission_manager.list_grants(agent_id="agent-1")
|
||||
assert len(grants) == 1
|
||||
assert grants[0].agent_id == "agent-1"
|
||||
|
||||
    @pytest.mark.asyncio
    async def test_list_grants_by_resource_type(
        self,
        permission_manager: PermissionManager,
    ) -> None:
        """Test listing grants filtered by resource type."""
        await permission_manager.grant(
            agent_id="agent-1",
            resource_pattern="/data/*",
            resource_type=ResourceType.FILE,
            level=PermissionLevel.READ,
        )
        await permission_manager.grant(
            agent_id="agent-1",
            resource_pattern="/api/*",
            resource_type=ResourceType.API,
            level=PermissionLevel.EXECUTE,
        )

        # Same agent, two resource types: the filter keeps only FILE.
        grants = await permission_manager.list_grants(resource_type=ResourceType.FILE)
        assert len(grants) == 1
        assert grants[0].resource_type == ResourceType.FILE
|
||||
|
||||
    @pytest.mark.asyncio
    async def test_list_grants_excludes_expired(
        self,
        permission_manager: PermissionManager,
    ) -> None:
        """Test that list_grants excludes expired grants."""
        # Create expired grant
        grant = await permission_manager.grant(
            agent_id="agent-1",
            resource_pattern="/old/*",
            resource_type=ResourceType.FILE,
            level=PermissionLevel.READ,
            duration_seconds=1,
        )
        # Force expiry by backdating.
        # NOTE(review): datetime.utcnow() is deprecated since Python 3.12;
        # presumably the manager compares against naive UTC datetimes too,
        # so migrating only this line to aware datetimes would break the
        # comparison — confirm before changing.
        grant.expires_at = datetime.utcnow() - timedelta(seconds=10)

        # Create valid grant
        await permission_manager.grant(
            agent_id="agent-1",
            resource_pattern="/new/*",
            resource_type=ResourceType.FILE,
            level=PermissionLevel.READ,
        )

        # Only the non-expired grant should be listed.
        grants = await permission_manager.list_grants()
        assert len(grants) == 1
        assert grants[0].resource_pattern == "/new/*"
|
||||
|
||||
def test_set_default_permission(
|
||||
self,
|
||||
) -> None:
|
||||
"""Test setting default permission level."""
|
||||
manager = PermissionManager(default_deny=False)
|
||||
|
||||
# Default for SHELL is NONE
|
||||
assert manager._default_permissions[ResourceType.SHELL] == PermissionLevel.NONE
|
||||
|
||||
# Change it
|
||||
manager.set_default_permission(ResourceType.SHELL, PermissionLevel.EXECUTE)
|
||||
assert (
|
||||
manager._default_permissions[ResourceType.SHELL] == PermissionLevel.EXECUTE
|
||||
)
|
||||
|
||||
    @pytest.mark.asyncio
    async def test_set_default_permission_affects_checks(
        self,
        permissive_manager: PermissionManager,
    ) -> None:
        """Test that changing default permissions affects checks.

        NOTE(review): the permissive_manager fixture is defined earlier in
        this file (out of view here); presumably a manager with
        default_deny=False — confirm against the fixture.
        """
        # Initially SHELL is NONE
        allowed = await permissive_manager.check(
            agent_id="agent-1",
            resource="ls",
            resource_type=ResourceType.SHELL,
            required_level=PermissionLevel.EXECUTE,
        )
        assert allowed is False

        # Change default
        permissive_manager.set_default_permission(
            ResourceType.SHELL, PermissionLevel.EXECUTE
        )

        # Now should be allowed
        allowed = await permissive_manager.check(
            agent_id="agent-1",
            resource="ls",
            resource_type=ResourceType.SHELL,
            required_level=PermissionLevel.EXECUTE,
        )
        assert allowed is True
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Edge Cases
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestPermissionEdgeCases:
    """Edge cases that could reveal hidden bugs."""

    @pytest.mark.asyncio
    async def test_multiple_matching_grants(
        self,
        permission_manager: PermissionManager,
    ) -> None:
        """Test when multiple grants match - first sufficient one wins."""
        # Grant READ on all files
        await permission_manager.grant(
            agent_id="agent-1",
            resource_pattern="/data/*",
            resource_type=ResourceType.FILE,
            level=PermissionLevel.READ,
        )

        # Also grant WRITE on specific path
        await permission_manager.grant(
            agent_id="agent-1",
            resource_pattern="/data/writable/*",
            resource_type=ResourceType.FILE,
            level=PermissionLevel.WRITE,
        )

        # Write on writable path should work
        # (the broader READ-only grant must not mask the WRITE grant).
        allowed = await permission_manager.check(
            agent_id="agent-1",
            resource="/data/writable/file.txt",
            resource_type=ResourceType.FILE,
            required_level=PermissionLevel.WRITE,
        )
        assert allowed is True

    @pytest.mark.asyncio
    async def test_wildcard_all_pattern(
        self,
        permission_manager: PermissionManager,
    ) -> None:
        """Test * pattern matches everything."""
        await permission_manager.grant(
            agent_id="agent-1",
            resource_pattern="*",
            resource_type=ResourceType.FILE,
            level=PermissionLevel.ADMIN,
        )

        allowed = await permission_manager.check(
            agent_id="agent-1",
            resource="/any/path/anywhere/file.txt",
            resource_type=ResourceType.FILE,
            required_level=PermissionLevel.DELETE,
        )

        # fnmatch's * matches everything including /
        assert allowed is True

    @pytest.mark.asyncio
    async def test_question_mark_wildcard(
        self,
        permission_manager: PermissionManager,
    ) -> None:
        """Test ? wildcard matches single character."""
        await permission_manager.grant(
            agent_id="agent-1",
            resource_pattern="file?.txt",
            resource_type=ResourceType.FILE,
            level=PermissionLevel.READ,
        )

        # "file1.txt": exactly one character in the ? position -> match.
        assert (
            await permission_manager.check(
                agent_id="agent-1",
                resource="file1.txt",
                resource_type=ResourceType.FILE,
                required_level=PermissionLevel.READ,
            )
            is True
        )

        assert (
            await permission_manager.check(
                agent_id="agent-1",
                resource="file10.txt",  # Two characters, won't match
                resource_type=ResourceType.FILE,
                required_level=PermissionLevel.READ,
            )
            is False
        )

    @pytest.mark.asyncio
    async def test_concurrent_grant_revoke(
        self,
        permission_manager: PermissionManager,
    ) -> None:
        """Test concurrent grant and revoke operations.

        NOTE(review): despite the name, the two coroutines below are
        awaited sequentially, not raced — this exercises grant-then-revoke
        bookkeeping, not true concurrency.
        """

        async def grant_many():
            grants = []
            for i in range(10):
                g = await permission_manager.grant(
                    agent_id="agent-1",
                    resource_pattern=f"/path{i}/*",
                    resource_type=ResourceType.FILE,
                    level=PermissionLevel.READ,
                )
                grants.append(g)
            return grants

        async def revoke_many(grants):
            for g in grants:
                await permission_manager.revoke(g.id)

        grants = await grant_many()
        await revoke_many(grants)

        # All should be revoked
        remaining = await permission_manager.list_grants(agent_id="agent-1")
        assert len(remaining) == 0

    @pytest.mark.asyncio
    async def test_check_action_with_no_resource_or_tool(
        self,
        permission_manager: PermissionManager,
        action_metadata: ActionMetadata,
    ) -> None:
        """Test check_action when both resource and tool_name are None."""
        await permission_manager.grant(
            agent_id="test-agent",
            resource_pattern="*",
            resource_type=ResourceType.LLM,
            level=PermissionLevel.EXECUTE,
        )

        action = ActionRequest(
            action_type=ActionType.LLM_CALL,
            resource=None,
            tool_name=None,
            metadata=action_metadata,
        )

        # Should use "*" as fallback
        allowed = await permission_manager.check_action(action)
        assert allowed is True

    @pytest.mark.asyncio
    async def test_cleanup_expired_called_on_check(
        self,
        permission_manager: PermissionManager,
    ) -> None:
        """Test that expired grants are cleaned up during check."""
        # Create expired grant
        grant = await permission_manager.grant(
            agent_id="agent-1",
            resource_pattern="/old/*",
            resource_type=ResourceType.FILE,
            level=PermissionLevel.READ,
            duration_seconds=1,
        )
        # Backdate expiry (datetime.utcnow() is deprecated in 3.12+;
        # kept naive-UTC to match the manager's internal comparisons).
        grant.expires_at = datetime.utcnow() - timedelta(seconds=10)

        # Create valid grant
        await permission_manager.grant(
            agent_id="agent-1",
            resource_pattern="/new/*",
            resource_type=ResourceType.FILE,
            level=PermissionLevel.READ,
        )

        # Run a check - this should trigger cleanup
        await permission_manager.check(
            agent_id="agent-1",
            resource="/new/file.txt",
            resource_type=ResourceType.FILE,
            required_level=PermissionLevel.READ,
        )

        # Now verify expired grant was cleaned up
        async with permission_manager._lock:
            assert len(permission_manager._grants) == 1
            assert permission_manager._grants[0].resource_pattern == "/new/*"

    @pytest.mark.asyncio
    async def test_check_wrong_agent_id(
        self,
        permission_manager: PermissionManager,
    ) -> None:
        """Test check fails for different agent."""
        await permission_manager.grant(
            agent_id="agent-1",
            resource_pattern="/data/*",
            resource_type=ResourceType.FILE,
            level=PermissionLevel.READ,
        )

        # Different agent should not have access
        allowed = await permission_manager.check(
            agent_id="agent-2",
            resource="/data/file.txt",
            resource_type=ResourceType.FILE,
            required_level=PermissionLevel.READ,
        )

        assert allowed is False
|
||||
823
backend/tests/services/safety/test_rollback.py
Normal file
823
backend/tests/services/safety/test_rollback.py
Normal file
@@ -0,0 +1,823 @@
|
||||
"""Tests for Rollback Manager.
|
||||
|
||||
Tests cover:
|
||||
- FileCheckpoint: state storage
|
||||
- RollbackManager: checkpoint, rollback, cleanup
|
||||
- TransactionContext: auto-rollback, commit, manual rollback
|
||||
- Edge cases: non-existent files, partial failures, expiration
|
||||
"""
|
||||
|
||||
import tempfile
from collections.abc import AsyncIterator, Iterator
from datetime import datetime, timedelta
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest
import pytest_asyncio

from app.services.safety.exceptions import RollbackError
from app.services.safety.models import (
    ActionMetadata,
    ActionRequest,
    ActionType,
    CheckpointType,
)
from app.services.safety.rollback.manager import (
    FileCheckpoint,
    RollbackManager,
    TransactionContext,
)
|
||||
|
||||
# ============================================================================
|
||||
# Fixtures
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
def action_metadata() -> ActionMetadata:
    """Create standard action metadata for tests.

    All rollback tests share the same agent/project/session identifiers
    so checkpoints created in different tests look uniform.
    """
    return ActionMetadata(
        agent_id="test-agent",
        project_id="test-project",
        session_id="test-session",
    )
|
||||
|
||||
|
||||
@pytest.fixture
def action_request(action_metadata: ActionMetadata) -> ActionRequest:
    """Create a standard action request for tests.

    A destructive FILE_WRITE against a fixed /tmp path — the typical
    action that warrants a checkpoint before execution.
    """
    return ActionRequest(
        id="action-123",
        action_type=ActionType.FILE_WRITE,
        tool_name="file_write",
        resource="/tmp/test_file.txt",  # noqa: S108
        metadata=action_metadata,
        is_destructive=True,
    )
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def rollback_manager() -> AsyncIterator[RollbackManager]:
    """Yield a RollbackManager backed by a throwaway checkpoint directory.

    Fix: the fixture is an async generator (it yields), so the correct
    return annotation is ``AsyncIterator[RollbackManager]``, not
    ``RollbackManager`` — pytest-asyncio injects the yielded value.

    get_safety_config is patched so any internal config lookups see the
    same temporary directory and retention window the manager was built
    with. The directory is removed when the fixture finalizes.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        with patch("app.services.safety.rollback.manager.get_safety_config") as mock:
            mock.return_value = MagicMock(
                checkpoint_dir=tmpdir,
                checkpoint_retention_hours=24,
            )
            manager = RollbackManager(checkpoint_dir=tmpdir, retention_hours=24)
            yield manager
|
||||
|
||||
|
||||
@pytest.fixture
def temp_dir() -> Iterator[Path]:
    """Yield a temporary directory for file operations.

    Fix: the fixture is a generator (it yields), so the correct return
    annotation is ``Iterator[Path]``, not ``Path`` — pytest injects the
    yielded Path into tests. The directory is deleted on teardown.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        yield Path(tmpdir)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# FileCheckpoint Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestFileCheckpoint:
    """Tests for the FileCheckpoint class."""

    def test_file_checkpoint_creation(self) -> None:
        """Test creating a file checkpoint."""
        snapshot = FileCheckpoint(
            checkpoint_id="cp-123",
            file_path="/path/to/file.txt",
            original_content=b"original content",
            existed=True,
        )

        # Constructor arguments are stored verbatim...
        assert snapshot.checkpoint_id == "cp-123"
        assert snapshot.file_path == "/path/to/file.txt"
        assert snapshot.original_content == b"original content"
        assert snapshot.existed is True
        # ...and a creation timestamp is stamped automatically.
        assert snapshot.created_at is not None

    def test_file_checkpoint_nonexistent_file(self) -> None:
        """Test checkpoint for non-existent file."""
        snapshot = FileCheckpoint(
            checkpoint_id="cp-123",
            file_path="/path/to/new_file.txt",
            original_content=None,
            existed=False,
        )

        # A missing file is recorded as "did not exist, no content".
        assert snapshot.existed is False
        assert snapshot.original_content is None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# RollbackManager Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestRollbackManager:
    """Tests for the RollbackManager class."""

    @pytest.mark.asyncio
    async def test_create_checkpoint(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Test creating a checkpoint."""
        checkpoint = await rollback_manager.create_checkpoint(
            action=action_request,
            checkpoint_type=CheckpointType.FILE,
            description="Test checkpoint",
        )

        # New checkpoints get an id, link back to the action, and start
        # out valid with an expiry set.
        assert checkpoint.id is not None
        assert checkpoint.action_id == action_request.id
        assert checkpoint.checkpoint_type == CheckpointType.FILE
        assert checkpoint.description == "Test checkpoint"
        assert checkpoint.expires_at is not None
        assert checkpoint.is_valid is True

    @pytest.mark.asyncio
    async def test_create_checkpoint_default_description(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Test checkpoint with default description."""
        checkpoint = await rollback_manager.create_checkpoint(action=action_request)

        # Default description mentions the action's tool name.
        assert "file_write" in checkpoint.description

    @pytest.mark.asyncio
    async def test_checkpoint_file_exists(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test checkpointing an existing file."""
        # Create a file
        test_file = temp_dir / "test.txt"
        test_file.write_text("original content")

        checkpoint = await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))

        # Verify checkpoint was stored
        # (inspecting private state under the manager's own lock).
        async with rollback_manager._lock:
            file_checkpoints = rollback_manager._file_checkpoints.get(checkpoint.id, [])
            assert len(file_checkpoints) == 1
            assert file_checkpoints[0].existed is True
            assert file_checkpoints[0].original_content == b"original content"

    @pytest.mark.asyncio
    async def test_checkpoint_file_not_exists(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test checkpointing a non-existent file."""
        test_file = temp_dir / "new_file.txt"
        assert not test_file.exists()

        checkpoint = await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))

        # Verify checkpoint was stored
        # A missing file is snapshotted as existed=False / no content.
        async with rollback_manager._lock:
            file_checkpoints = rollback_manager._file_checkpoints.get(checkpoint.id, [])
            assert len(file_checkpoints) == 1
            assert file_checkpoints[0].existed is False
            assert file_checkpoints[0].original_content is None

    @pytest.mark.asyncio
    async def test_checkpoint_files_multiple(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test checkpointing multiple files."""
        # Create files
        file1 = temp_dir / "file1.txt"
        file2 = temp_dir / "file2.txt"
        file1.write_text("content 1")
        file2.write_text("content 2")

        checkpoint = await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.checkpoint_files(
            checkpoint.id,
            [str(file1), str(file2)],
        )

        # Both files end up under the same checkpoint id.
        async with rollback_manager._lock:
            file_checkpoints = rollback_manager._file_checkpoints.get(checkpoint.id, [])
            assert len(file_checkpoints) == 2

    @pytest.mark.asyncio
    async def test_rollback_restore_modified_file(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test rollback restores modified file content."""
        test_file = temp_dir / "test.txt"
        test_file.write_text("original content")

        # Create checkpoint
        checkpoint = await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))

        # Modify file
        test_file.write_text("modified content")
        assert test_file.read_text() == "modified content"

        # Rollback
        result = await rollback_manager.rollback(checkpoint.id)

        # Rollback reports success and the original bytes are back on disk.
        assert result.success is True
        assert len(result.actions_rolled_back) == 1
        assert test_file.read_text() == "original content"

    @pytest.mark.asyncio
    async def test_rollback_delete_new_file(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test rollback deletes file that didn't exist before."""
        test_file = temp_dir / "new_file.txt"
        assert not test_file.exists()

        # Create checkpoint before file exists
        checkpoint = await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))

        # Create the file
        test_file.write_text("new content")
        assert test_file.exists()

        # Rollback
        # Restoring "did not exist" means deleting the file.
        result = await rollback_manager.rollback(checkpoint.id)

        assert result.success is True
        assert not test_file.exists()

    @pytest.mark.asyncio
    async def test_rollback_not_found(
        self,
        rollback_manager: RollbackManager,
    ) -> None:
        """Test rollback with non-existent checkpoint."""
        with pytest.raises(RollbackError) as exc_info:
            await rollback_manager.rollback("nonexistent-id")

        assert "not found" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_rollback_invalid_checkpoint(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test rollback with invalidated checkpoint."""
        test_file = temp_dir / "test.txt"
        test_file.write_text("original")

        checkpoint = await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))

        # Rollback once (invalidates checkpoint)
        await rollback_manager.rollback(checkpoint.id)

        # Try to rollback again
        # A checkpoint is single-use: a second rollback must fail.
        with pytest.raises(RollbackError) as exc_info:
            await rollback_manager.rollback(checkpoint.id)

        assert "no longer valid" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_discard_checkpoint(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Test discarding a checkpoint."""
        checkpoint = await rollback_manager.create_checkpoint(action=action_request)

        result = await rollback_manager.discard_checkpoint(checkpoint.id)
        assert result is True

        # Verify it's gone
        cp = await rollback_manager.get_checkpoint(checkpoint.id)
        assert cp is None

    @pytest.mark.asyncio
    async def test_discard_checkpoint_nonexistent(
        self,
        rollback_manager: RollbackManager,
    ) -> None:
        """Test discarding a non-existent checkpoint."""
        # Discarding an unknown id is a no-op that returns False.
        result = await rollback_manager.discard_checkpoint("nonexistent-id")
        assert result is False

    @pytest.mark.asyncio
    async def test_get_checkpoint(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Test getting a checkpoint by ID."""
        checkpoint = await rollback_manager.create_checkpoint(action=action_request)

        retrieved = await rollback_manager.get_checkpoint(checkpoint.id)
        assert retrieved is not None
        assert retrieved.id == checkpoint.id

    @pytest.mark.asyncio
    async def test_get_checkpoint_nonexistent(
        self,
        rollback_manager: RollbackManager,
    ) -> None:
        """Test getting a non-existent checkpoint."""
        # Lookup misses return None rather than raising.
        retrieved = await rollback_manager.get_checkpoint("nonexistent-id")
        assert retrieved is None

    @pytest.mark.asyncio
    async def test_list_checkpoints(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Test listing checkpoints."""
        await rollback_manager.create_checkpoint(action=action_request)
        await rollback_manager.create_checkpoint(action=action_request)

        checkpoints = await rollback_manager.list_checkpoints()
        assert len(checkpoints) == 2

    @pytest.mark.asyncio
    async def test_list_checkpoints_by_action(
        self,
        rollback_manager: RollbackManager,
        action_metadata: ActionMetadata,
    ) -> None:
        """Test listing checkpoints filtered by action."""
        action1 = ActionRequest(
            id="action-1",
            action_type=ActionType.FILE_WRITE,
            metadata=action_metadata,
        )
        action2 = ActionRequest(
            id="action-2",
            action_type=ActionType.FILE_WRITE,
            metadata=action_metadata,
        )

        await rollback_manager.create_checkpoint(action=action1)
        await rollback_manager.create_checkpoint(action=action2)

        # Filtering by action id keeps only that action's checkpoint.
        checkpoints = await rollback_manager.list_checkpoints(action_id="action-1")
        assert len(checkpoints) == 1
        assert checkpoints[0].action_id == "action-1"

    @pytest.mark.asyncio
    async def test_list_checkpoints_excludes_expired(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Test list_checkpoints excludes expired by default."""
        checkpoint = await rollback_manager.create_checkpoint(action=action_request)

        # Manually expire it
        # (datetime.utcnow() is deprecated in 3.12+; kept naive-UTC to
        # match the manager's internal comparisons).
        async with rollback_manager._lock:
            rollback_manager._checkpoints[checkpoint.id].expires_at = (
                datetime.utcnow() - timedelta(hours=1)
            )

        checkpoints = await rollback_manager.list_checkpoints()
        assert len(checkpoints) == 0

        # With include_expired=True
        checkpoints = await rollback_manager.list_checkpoints(include_expired=True)
        assert len(checkpoints) == 1

    @pytest.mark.asyncio
    async def test_cleanup_expired(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test cleaning up expired checkpoints."""
        # Create checkpoints
        checkpoint = await rollback_manager.create_checkpoint(action=action_request)
        test_file = temp_dir / "test.txt"
        test_file.write_text("content")
        await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))

        # Expire it
        async with rollback_manager._lock:
            rollback_manager._checkpoints[checkpoint.id].expires_at = (
                datetime.utcnow() - timedelta(hours=1)
            )

        # Cleanup
        count = await rollback_manager.cleanup_expired()
        assert count == 1

        # Verify it's gone
        # Both the checkpoint record and its file snapshots are removed.
        async with rollback_manager._lock:
            assert checkpoint.id not in rollback_manager._checkpoints
            assert checkpoint.id not in rollback_manager._file_checkpoints
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TransactionContext Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestTransactionContext:
|
||||
"""Tests for the TransactionContext class."""
|
||||
|
||||
    @pytest.mark.asyncio
    async def test_context_creates_checkpoint(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
    ) -> None:
        """Test that entering context creates a checkpoint."""
        async with TransactionContext(rollback_manager, action_request) as tx:
            # __aenter__ must have allocated a checkpoint id.
            assert tx.checkpoint_id is not None

            # Verify checkpoint exists
            cp = await rollback_manager.get_checkpoint(tx.checkpoint_id)
            assert cp is not None
|
||||
|
||||
    @pytest.mark.asyncio
    async def test_context_checkpoint_file(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test checkpointing files through context."""
        test_file = temp_dir / "test.txt"
        test_file.write_text("original")

        async with TransactionContext(rollback_manager, action_request) as tx:
            await tx.checkpoint_file(str(test_file))

            # Modify file
            test_file.write_text("modified")

            # Manual rollback
            result = await tx.rollback()
            assert result is not None
            assert result.success is True

        # File content is back to the snapshot taken above.
        assert test_file.read_text() == "original"
|
||||
|
||||
    @pytest.mark.asyncio
    async def test_context_checkpoint_files(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test checkpointing multiple files through context."""
        file1 = temp_dir / "file1.txt"
        file2 = temp_dir / "file2.txt"
        file1.write_text("content 1")
        file2.write_text("content 2")

        async with TransactionContext(rollback_manager, action_request) as tx:
            await tx.checkpoint_files([str(file1), str(file2)])

            # Both snapshots are stored under the context's checkpoint id.
            cp_id = tx.checkpoint_id
            async with rollback_manager._lock:
                file_cps = rollback_manager._file_checkpoints.get(cp_id, [])
                assert len(file_cps) == 2

            tx.commit()
|
||||
|
||||
    @pytest.mark.asyncio
    async def test_context_auto_rollback_on_exception(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test auto-rollback when exception occurs."""
        test_file = temp_dir / "test.txt"
        test_file.write_text("original")

        # The exception escapes the context, so __aexit__ should roll back.
        with pytest.raises(ValueError):
            async with TransactionContext(rollback_manager, action_request) as tx:
                await tx.checkpoint_file(str(test_file))
                test_file.write_text("modified")
                raise ValueError("Simulated error")

        # Should have been rolled back
        assert test_file.read_text() == "original"
|
||||
|
||||
    @pytest.mark.asyncio
    async def test_context_commit_prevents_rollback(
        self,
        rollback_manager: RollbackManager,
        action_request: ActionRequest,
        temp_dir: Path,
    ) -> None:
        """Test that commit prevents auto-rollback."""
        test_file = temp_dir / "test.txt"
        test_file.write_text("original")

        # commit() happens before the exception, so even a failing exit
        # must not undo the write.
        with pytest.raises(ValueError):
            async with TransactionContext(rollback_manager, action_request) as tx:
                await tx.checkpoint_file(str(test_file))
                test_file.write_text("modified")
                tx.commit()
                raise ValueError("Simulated error after commit")

        # Should NOT have been rolled back
        assert test_file.read_text() == "modified"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_discards_checkpoint_on_commit(
|
||||
self,
|
||||
rollback_manager: RollbackManager,
|
||||
action_request: ActionRequest,
|
||||
) -> None:
|
||||
"""Test that checkpoint is discarded after successful commit."""
|
||||
checkpoint_id = None
|
||||
|
||||
async with TransactionContext(rollback_manager, action_request) as tx:
|
||||
checkpoint_id = tx.checkpoint_id
|
||||
tx.commit()
|
||||
|
||||
# Checkpoint should be discarded
|
||||
cp = await rollback_manager.get_checkpoint(checkpoint_id)
|
||||
assert cp is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_no_auto_rollback_when_disabled(
|
||||
self,
|
||||
rollback_manager: RollbackManager,
|
||||
action_request: ActionRequest,
|
||||
temp_dir: Path,
|
||||
) -> None:
|
||||
"""Test that auto_rollback=False disables auto-rollback."""
|
||||
test_file = temp_dir / "test.txt"
|
||||
test_file.write_text("original")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
async with TransactionContext(
|
||||
rollback_manager,
|
||||
action_request,
|
||||
auto_rollback=False,
|
||||
) as tx:
|
||||
await tx.checkpoint_file(str(test_file))
|
||||
test_file.write_text("modified")
|
||||
raise ValueError("Simulated error")
|
||||
|
||||
# Should NOT have been rolled back
|
||||
assert test_file.read_text() == "modified"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_manual_rollback(
|
||||
self,
|
||||
rollback_manager: RollbackManager,
|
||||
action_request: ActionRequest,
|
||||
temp_dir: Path,
|
||||
) -> None:
|
||||
"""Test manual rollback within context."""
|
||||
test_file = temp_dir / "test.txt"
|
||||
test_file.write_text("original")
|
||||
|
||||
async with TransactionContext(rollback_manager, action_request) as tx:
|
||||
await tx.checkpoint_file(str(test_file))
|
||||
test_file.write_text("modified")
|
||||
|
||||
# Manual rollback
|
||||
result = await tx.rollback()
|
||||
assert result is not None
|
||||
assert result.success is True
|
||||
|
||||
assert test_file.read_text() == "original"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_rollback_without_checkpoint(
|
||||
self,
|
||||
rollback_manager: RollbackManager,
|
||||
action_request: ActionRequest,
|
||||
) -> None:
|
||||
"""Test rollback when checkpoint is None."""
|
||||
tx = TransactionContext(rollback_manager, action_request)
|
||||
# Don't enter context, so _checkpoint is None
|
||||
result = await tx.rollback()
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_checkpoint_file_without_checkpoint(
|
||||
self,
|
||||
rollback_manager: RollbackManager,
|
||||
action_request: ActionRequest,
|
||||
temp_dir: Path,
|
||||
) -> None:
|
||||
"""Test checkpoint_file when checkpoint is None (no-op)."""
|
||||
tx = TransactionContext(rollback_manager, action_request)
|
||||
test_file = temp_dir / "test.txt"
|
||||
test_file.write_text("content")
|
||||
|
||||
# Should not raise - just a no-op
|
||||
await tx.checkpoint_file(str(test_file))
|
||||
await tx.checkpoint_files([str(test_file)])
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Edge Cases
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestRollbackEdgeCases:
|
||||
"""Edge cases that could reveal hidden bugs."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_checkpoint_file_for_unknown_checkpoint(
|
||||
self,
|
||||
rollback_manager: RollbackManager,
|
||||
temp_dir: Path,
|
||||
) -> None:
|
||||
"""Test checkpointing file for non-existent checkpoint."""
|
||||
test_file = temp_dir / "test.txt"
|
||||
test_file.write_text("content")
|
||||
|
||||
# Should create the list if it doesn't exist
|
||||
await rollback_manager.checkpoint_file("unknown-checkpoint", str(test_file))
|
||||
|
||||
async with rollback_manager._lock:
|
||||
assert "unknown-checkpoint" in rollback_manager._file_checkpoints
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rollback_with_partial_failure(
|
||||
self,
|
||||
rollback_manager: RollbackManager,
|
||||
action_request: ActionRequest,
|
||||
temp_dir: Path,
|
||||
) -> None:
|
||||
"""Test rollback when some files fail to restore."""
|
||||
file1 = temp_dir / "file1.txt"
|
||||
file1.write_text("original 1")
|
||||
|
||||
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
|
||||
await rollback_manager.checkpoint_file(checkpoint.id, str(file1))
|
||||
|
||||
# Add a file checkpoint with a path that will fail
|
||||
async with rollback_manager._lock:
|
||||
# Create a checkpoint for a file in a non-writable location
|
||||
bad_fc = FileCheckpoint(
|
||||
checkpoint_id=checkpoint.id,
|
||||
file_path="/nonexistent/path/file.txt",
|
||||
original_content=b"content",
|
||||
existed=True,
|
||||
)
|
||||
rollback_manager._file_checkpoints[checkpoint.id].append(bad_fc)
|
||||
|
||||
# Rollback - partial failure expected
|
||||
result = await rollback_manager.rollback(checkpoint.id)
|
||||
|
||||
assert result.success is False
|
||||
assert len(result.actions_rolled_back) == 1
|
||||
assert len(result.failed_actions) == 1
|
||||
assert "Failed to rollback" in result.error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rollback_file_creates_parent_dirs(
|
||||
self,
|
||||
rollback_manager: RollbackManager,
|
||||
action_request: ActionRequest,
|
||||
temp_dir: Path,
|
||||
) -> None:
|
||||
"""Test that rollback creates parent directories if needed."""
|
||||
nested_file = temp_dir / "subdir" / "nested" / "file.txt"
|
||||
nested_file.parent.mkdir(parents=True)
|
||||
nested_file.write_text("original")
|
||||
|
||||
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
|
||||
await rollback_manager.checkpoint_file(checkpoint.id, str(nested_file))
|
||||
|
||||
# Delete the entire directory structure
|
||||
nested_file.unlink()
|
||||
(temp_dir / "subdir" / "nested").rmdir()
|
||||
(temp_dir / "subdir").rmdir()
|
||||
|
||||
# Rollback should recreate
|
||||
result = await rollback_manager.rollback(checkpoint.id)
|
||||
|
||||
assert result.success is True
|
||||
assert nested_file.exists()
|
||||
assert nested_file.read_text() == "original"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rollback_file_already_correct(
|
||||
self,
|
||||
rollback_manager: RollbackManager,
|
||||
action_request: ActionRequest,
|
||||
temp_dir: Path,
|
||||
) -> None:
|
||||
"""Test rollback when file already has correct content."""
|
||||
test_file = temp_dir / "test.txt"
|
||||
test_file.write_text("original")
|
||||
|
||||
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
|
||||
await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))
|
||||
|
||||
# Don't modify file - rollback should still succeed
|
||||
result = await rollback_manager.rollback(checkpoint.id)
|
||||
|
||||
assert result.success is True
|
||||
assert test_file.read_text() == "original"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_checkpoint_with_none_expires_at(
|
||||
self,
|
||||
rollback_manager: RollbackManager,
|
||||
action_request: ActionRequest,
|
||||
) -> None:
|
||||
"""Test list_checkpoints handles None expires_at."""
|
||||
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
|
||||
|
||||
# Set expires_at to None
|
||||
async with rollback_manager._lock:
|
||||
rollback_manager._checkpoints[checkpoint.id].expires_at = None
|
||||
|
||||
# Should still be listed
|
||||
checkpoints = await rollback_manager.list_checkpoints()
|
||||
assert len(checkpoints) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_rollback_failure_logged(
|
||||
self,
|
||||
rollback_manager: RollbackManager,
|
||||
action_request: ActionRequest,
|
||||
temp_dir: Path,
|
||||
) -> None:
|
||||
"""Test that auto-rollback failure is logged, not raised."""
|
||||
test_file = temp_dir / "test.txt"
|
||||
test_file.write_text("original")
|
||||
|
||||
with patch.object(
|
||||
rollback_manager, "rollback", side_effect=Exception("Rollback failed!")
|
||||
):
|
||||
with patch("app.services.safety.rollback.manager.logger") as mock_logger:
|
||||
with pytest.raises(ValueError):
|
||||
async with TransactionContext(
|
||||
rollback_manager, action_request
|
||||
) as tx:
|
||||
await tx.checkpoint_file(str(test_file))
|
||||
test_file.write_text("modified")
|
||||
raise ValueError("Original error")
|
||||
|
||||
# Rollback error should be logged
|
||||
mock_logger.error.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_checkpoints_same_action(
|
||||
self,
|
||||
rollback_manager: RollbackManager,
|
||||
action_request: ActionRequest,
|
||||
) -> None:
|
||||
"""Test creating multiple checkpoints for the same action."""
|
||||
cp1 = await rollback_manager.create_checkpoint(action=action_request)
|
||||
cp2 = await rollback_manager.create_checkpoint(action=action_request)
|
||||
|
||||
assert cp1.id != cp2.id
|
||||
|
||||
checkpoints = await rollback_manager.list_checkpoints(
|
||||
action_id=action_request.id
|
||||
)
|
||||
assert len(checkpoints) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_with_no_expired(
|
||||
self,
|
||||
rollback_manager: RollbackManager,
|
||||
action_request: ActionRequest,
|
||||
) -> None:
|
||||
"""Test cleanup when no checkpoints are expired."""
|
||||
await rollback_manager.create_checkpoint(action=action_request)
|
||||
|
||||
count = await rollback_manager.cleanup_expired()
|
||||
assert count == 0
|
||||
|
||||
# Checkpoint should still exist
|
||||
checkpoints = await rollback_manager.list_checkpoints()
|
||||
assert len(checkpoints) == 1
|
||||
@@ -363,6 +363,365 @@ class TestValidationBatch:
|
||||
assert results[1].decision == SafetyDecision.DENY
|
||||
|
||||
|
||||
class TestValidationCache:
|
||||
"""Tests for ValidationCache class."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_get_miss(self) -> None:
|
||||
"""Test cache miss."""
|
||||
from app.services.safety.validation.validator import ValidationCache
|
||||
|
||||
cache = ValidationCache(max_size=10, ttl_seconds=60)
|
||||
result = await cache.get("nonexistent")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_get_hit(self) -> None:
|
||||
"""Test cache hit."""
|
||||
from app.services.safety.models import ValidationResult
|
||||
from app.services.safety.validation.validator import ValidationCache
|
||||
|
||||
cache = ValidationCache(max_size=10, ttl_seconds=60)
|
||||
vr = ValidationResult(
|
||||
action_id="action-1",
|
||||
decision=SafetyDecision.ALLOW,
|
||||
applied_rules=[],
|
||||
reasons=["test"],
|
||||
)
|
||||
await cache.set("key1", vr)
|
||||
|
||||
result = await cache.get("key1")
|
||||
assert result is not None
|
||||
assert result.action_id == "action-1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_ttl_expiry(self) -> None:
|
||||
"""Test cache TTL expiry."""
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.services.safety.models import ValidationResult
|
||||
from app.services.safety.validation.validator import ValidationCache
|
||||
|
||||
cache = ValidationCache(max_size=10, ttl_seconds=1)
|
||||
vr = ValidationResult(
|
||||
action_id="action-1",
|
||||
decision=SafetyDecision.ALLOW,
|
||||
applied_rules=[],
|
||||
reasons=["test"],
|
||||
)
|
||||
await cache.set("key1", vr)
|
||||
|
||||
# Advance time past TTL
|
||||
with patch("time.time", return_value=time.time() + 2):
|
||||
result = await cache.get("key1")
|
||||
assert result is None # Should be expired
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_eviction_on_full(self) -> None:
|
||||
"""Test cache eviction when full."""
|
||||
from app.services.safety.models import ValidationResult
|
||||
from app.services.safety.validation.validator import ValidationCache
|
||||
|
||||
cache = ValidationCache(max_size=2, ttl_seconds=60)
|
||||
|
||||
vr1 = ValidationResult(action_id="a1", decision=SafetyDecision.ALLOW)
|
||||
vr2 = ValidationResult(action_id="a2", decision=SafetyDecision.ALLOW)
|
||||
vr3 = ValidationResult(action_id="a3", decision=SafetyDecision.ALLOW)
|
||||
|
||||
await cache.set("key1", vr1)
|
||||
await cache.set("key2", vr2)
|
||||
await cache.set("key3", vr3) # Should evict key1
|
||||
|
||||
# key1 should be evicted
|
||||
assert await cache.get("key1") is None
|
||||
assert await cache.get("key2") is not None
|
||||
assert await cache.get("key3") is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_update_existing_key(self) -> None:
|
||||
"""Test updating existing key in cache."""
|
||||
from app.services.safety.models import ValidationResult
|
||||
from app.services.safety.validation.validator import ValidationCache
|
||||
|
||||
cache = ValidationCache(max_size=10, ttl_seconds=60)
|
||||
|
||||
vr1 = ValidationResult(action_id="a1", decision=SafetyDecision.ALLOW)
|
||||
vr2 = ValidationResult(action_id="a1-updated", decision=SafetyDecision.DENY)
|
||||
|
||||
await cache.set("key1", vr1)
|
||||
await cache.set("key1", vr2) # Should update, not add
|
||||
|
||||
result = await cache.get("key1")
|
||||
assert result is not None
|
||||
assert result.action_id == "a1" # Still old value since we move_to_end
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_clear(self) -> None:
|
||||
"""Test clearing cache."""
|
||||
from app.services.safety.models import ValidationResult
|
||||
from app.services.safety.validation.validator import ValidationCache
|
||||
|
||||
cache = ValidationCache(max_size=10, ttl_seconds=60)
|
||||
|
||||
vr = ValidationResult(action_id="a1", decision=SafetyDecision.ALLOW)
|
||||
await cache.set("key1", vr)
|
||||
await cache.set("key2", vr)
|
||||
|
||||
await cache.clear()
|
||||
|
||||
assert await cache.get("key1") is None
|
||||
assert await cache.get("key2") is None
|
||||
|
||||
|
||||
class TestValidatorCaching:
|
||||
"""Tests for validator caching functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_hit(self) -> None:
|
||||
"""Test that cache is used for repeated validations."""
|
||||
validator = ActionValidator(cache_enabled=True, cache_ttl=60)
|
||||
|
||||
metadata = ActionMetadata(agent_id="test-agent", session_id="session-1")
|
||||
action = ActionRequest(
|
||||
action_type=ActionType.FILE_READ,
|
||||
tool_name="file_read",
|
||||
resource="/tmp/test.txt", # noqa: S108
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
# First call populates cache
|
||||
result1 = await validator.validate(action)
|
||||
# Second call should use cache
|
||||
result2 = await validator.validate(action)
|
||||
|
||||
assert result1.decision == result2.decision
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_cache(self) -> None:
|
||||
"""Test clearing the validation cache."""
|
||||
validator = ActionValidator(cache_enabled=True)
|
||||
|
||||
metadata = ActionMetadata(agent_id="test-agent", session_id="session-1")
|
||||
action = ActionRequest(
|
||||
action_type=ActionType.FILE_READ,
|
||||
tool_name="file_read",
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
await validator.validate(action)
|
||||
await validator.clear_cache()
|
||||
|
||||
# Cache should be empty now (no error)
|
||||
result = await validator.validate(action)
|
||||
assert result.decision == SafetyDecision.ALLOW
|
||||
|
||||
|
||||
class TestRuleMatching:
|
||||
"""Tests for rule matching edge cases."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_type_mismatch(self) -> None:
|
||||
"""Test that rule doesn't match when action type doesn't match."""
|
||||
validator = ActionValidator(cache_enabled=False)
|
||||
validator.add_rule(
|
||||
ValidationRule(
|
||||
name="file_only",
|
||||
action_types=[ActionType.FILE_READ],
|
||||
decision=SafetyDecision.DENY,
|
||||
)
|
||||
)
|
||||
|
||||
metadata = ActionMetadata(agent_id="test-agent")
|
||||
action = ActionRequest(
|
||||
action_type=ActionType.SHELL_COMMAND, # Different type
|
||||
tool_name="shell_exec",
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
result = await validator.validate(action)
|
||||
assert result.decision == SafetyDecision.ALLOW # Rule didn't match
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_pattern_no_tool_name(self) -> None:
|
||||
"""Test rule with tool pattern when action has no tool_name."""
|
||||
validator = ActionValidator(cache_enabled=False)
|
||||
validator.add_rule(
|
||||
create_deny_rule(
|
||||
name="deny_files",
|
||||
tool_patterns=["file_*"],
|
||||
)
|
||||
)
|
||||
|
||||
metadata = ActionMetadata(agent_id="test-agent")
|
||||
action = ActionRequest(
|
||||
action_type=ActionType.FILE_READ,
|
||||
tool_name=None, # No tool name
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
result = await validator.validate(action)
|
||||
assert result.decision == SafetyDecision.ALLOW # Rule didn't match
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resource_pattern_no_resource(self) -> None:
|
||||
"""Test rule with resource pattern when action has no resource."""
|
||||
validator = ActionValidator(cache_enabled=False)
|
||||
validator.add_rule(
|
||||
create_deny_rule(
|
||||
name="deny_secrets",
|
||||
resource_patterns=["/secret/*"],
|
||||
)
|
||||
)
|
||||
|
||||
metadata = ActionMetadata(agent_id="test-agent")
|
||||
action = ActionRequest(
|
||||
action_type=ActionType.FILE_READ,
|
||||
tool_name="file_read",
|
||||
resource=None, # No resource
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
result = await validator.validate(action)
|
||||
assert result.decision == SafetyDecision.ALLOW # Rule didn't match
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resource_pattern_no_match(self) -> None:
|
||||
"""Test rule with resource pattern that doesn't match."""
|
||||
validator = ActionValidator(cache_enabled=False)
|
||||
validator.add_rule(
|
||||
create_deny_rule(
|
||||
name="deny_secrets",
|
||||
resource_patterns=["/secret/*"],
|
||||
)
|
||||
)
|
||||
|
||||
metadata = ActionMetadata(agent_id="test-agent")
|
||||
action = ActionRequest(
|
||||
action_type=ActionType.FILE_READ,
|
||||
tool_name="file_read",
|
||||
resource="/public/file.txt", # Doesn't match
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
result = await validator.validate(action)
|
||||
assert result.decision == SafetyDecision.ALLOW # Pattern didn't match
|
||||
|
||||
|
||||
class TestPolicyLoading:
|
||||
"""Tests for policy loading edge cases."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_rules_from_policy_with_validation_rules(self) -> None:
|
||||
"""Test loading policy with explicit validation rules."""
|
||||
validator = ActionValidator(cache_enabled=False)
|
||||
|
||||
rule = ValidationRule(
|
||||
name="policy_rule",
|
||||
tool_patterns=["test_*"],
|
||||
decision=SafetyDecision.DENY,
|
||||
reason="From policy",
|
||||
)
|
||||
policy = SafetyPolicy(
|
||||
name="test",
|
||||
validation_rules=[rule],
|
||||
require_approval_for=[], # Clear defaults
|
||||
denied_tools=[], # Clear defaults
|
||||
)
|
||||
|
||||
validator.load_rules_from_policy(policy)
|
||||
|
||||
assert len(validator._rules) == 1
|
||||
assert validator._rules[0].name == "policy_rule"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_approval_all_pattern(self) -> None:
|
||||
"""Test loading policy with * approval pattern (all actions)."""
|
||||
validator = ActionValidator(cache_enabled=False)
|
||||
|
||||
policy = SafetyPolicy(
|
||||
name="test",
|
||||
require_approval_for=["*"], # All actions require approval
|
||||
denied_tools=[], # Clear defaults
|
||||
)
|
||||
|
||||
validator.load_rules_from_policy(policy)
|
||||
|
||||
approval_rules = [
|
||||
r for r in validator._rules if r.decision == SafetyDecision.REQUIRE_APPROVAL
|
||||
]
|
||||
assert len(approval_rules) == 1
|
||||
assert approval_rules[0].name == "require_approval_all"
|
||||
assert approval_rules[0].action_types == list(ActionType)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_with_policy_loads_rules(self) -> None:
|
||||
"""Test that validate() loads rules from policy if none exist."""
|
||||
validator = ActionValidator(cache_enabled=False)
|
||||
|
||||
policy = SafetyPolicy(
|
||||
name="test",
|
||||
denied_tools=["dangerous_*"],
|
||||
)
|
||||
|
||||
metadata = ActionMetadata(agent_id="test-agent")
|
||||
action = ActionRequest(
|
||||
action_type=ActionType.SHELL_COMMAND,
|
||||
tool_name="dangerous_exec",
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
# Validate with policy - should load rules
|
||||
result = await validator.validate(action, policy=policy)
|
||||
|
||||
assert result.decision == SafetyDecision.DENY
|
||||
|
||||
|
||||
class TestCacheKeyGeneration:
|
||||
"""Tests for cache key generation."""
|
||||
|
||||
def test_get_cache_key(self) -> None:
|
||||
"""Test cache key generation."""
|
||||
validator = ActionValidator(cache_enabled=True)
|
||||
|
||||
metadata = ActionMetadata(
|
||||
agent_id="test-agent",
|
||||
autonomy_level=AutonomyLevel.MILESTONE,
|
||||
)
|
||||
action = ActionRequest(
|
||||
action_type=ActionType.FILE_READ,
|
||||
tool_name="file_read",
|
||||
resource="/tmp/test.txt", # noqa: S108
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
key = validator._get_cache_key(action)
|
||||
|
||||
assert "file_read" in key
|
||||
assert "file_read" in key
|
||||
assert "/tmp/test.txt" in key # noqa: S108
|
||||
assert "test-agent" in key
|
||||
assert "milestone" in key
|
||||
|
||||
def test_get_cache_key_no_resource(self) -> None:
|
||||
"""Test cache key generation without resource."""
|
||||
validator = ActionValidator(cache_enabled=True)
|
||||
|
||||
metadata = ActionMetadata(agent_id="agent-1")
|
||||
action = ActionRequest(
|
||||
action_type=ActionType.SHELL_COMMAND,
|
||||
tool_name="shell_exec",
|
||||
resource=None,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
key = validator._get_cache_key(action)
|
||||
|
||||
# Should not error with None resource
|
||||
assert "shell" in key
|
||||
assert "agent-1" in key
|
||||
|
||||
|
||||
class TestHelperFunctions:
|
||||
"""Tests for rule creation helper functions."""
|
||||
|
||||
|
||||
@@ -48,6 +48,80 @@ services:
|
||||
- app-network
|
||||
restart: unless-stopped
|
||||
|
||||
# ==========================================================================
|
||||
# MCP Servers - Model Context Protocol servers for AI agent capabilities
|
||||
# ==========================================================================
|
||||
|
||||
mcp-llm-gateway:
|
||||
# REPLACE THIS with your actual image from your container registry
|
||||
image: YOUR_REGISTRY/YOUR_PROJECT_MCP_LLM_GATEWAY:latest
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
- LLM_GATEWAY_HOST=0.0.0.0
|
||||
- LLM_GATEWAY_PORT=8001
|
||||
- REDIS_URL=redis://redis:6379/1
|
||||
- ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY}
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
||||
- ENVIRONMENT=production
|
||||
depends_on:
|
||||
redis:
|
||||
condition: service_healthy
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:8001/health').raise_for_status()"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 10s
|
||||
networks:
|
||||
- app-network
|
||||
restart: unless-stopped
|
||||
deploy:
|
||||
resources:
|
||||
limits:
|
||||
cpus: '2.0'
|
||||
memory: 2G
|
||||
reservations:
|
||||
cpus: '0.5'
|
||||
memory: 512M
|
||||
|
||||
mcp-knowledge-base:
|
||||
# REPLACE THIS with your actual image from your container registry
|
||||
image: YOUR_REGISTRY/YOUR_PROJECT_MCP_KNOWLEDGE_BASE:latest
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
# KB_ prefix required by pydantic-settings config
|
||||
- KB_HOST=0.0.0.0
|
||||
- KB_PORT=8002
|
||||
- KB_DATABASE_URL=postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@db:5432/${POSTGRES_DB}
|
||||
- KB_REDIS_URL=redis://redis:6379/2
|
||||
- KB_LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
||||
- ENVIRONMENT=production
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:8002/health').raise_for_status()"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 10s
|
||||
networks:
|
||||
- app-network
|
||||
restart: unless-stopped
|
||||
deploy:
|
||||
resources:
|
||||
limits:
|
||||
cpus: '1.0'
|
||||
memory: 1G
|
||||
reservations:
|
||||
cpus: '0.25'
|
||||
memory: 256M
|
||||
|
||||
backend:
|
||||
# REPLACE THIS with your actual image from your container registry
|
||||
# Examples:
|
||||
@@ -64,11 +138,18 @@ services:
|
||||
- DEBUG=false
|
||||
- BACKEND_CORS_ORIGINS=${BACKEND_CORS_ORIGINS}
|
||||
- REDIS_URL=redis://redis:6379/0
|
||||
# MCP Server URLs
|
||||
- LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
|
||||
- KNOWLEDGE_BASE_URL=http://mcp-knowledge-base:8002
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
mcp-llm-gateway:
|
||||
condition: service_healthy
|
||||
mcp-knowledge-base:
|
||||
condition: service_healthy
|
||||
networks:
|
||||
- app-network
|
||||
restart: unless-stopped
|
||||
@@ -92,11 +173,18 @@ services:
|
||||
- DATABASE_URL=${DATABASE_URL}
|
||||
- REDIS_URL=redis://redis:6379/0
|
||||
- CELERY_QUEUE=agent
|
||||
# MCP Server URLs (agents need access to MCP)
|
||||
- LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
|
||||
- KNOWLEDGE_BASE_URL=http://mcp-knowledge-base:8002
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
mcp-llm-gateway:
|
||||
condition: service_healthy
|
||||
mcp-knowledge-base:
|
||||
condition: service_healthy
|
||||
networks:
|
||||
- app-network
|
||||
restart: unless-stopped
|
||||
|
||||
@@ -32,6 +32,70 @@ services:
|
||||
networks:
|
||||
- app-network
|
||||
|
||||
# ==========================================================================
|
||||
# MCP Servers - Model Context Protocol servers for AI agent capabilities
|
||||
# ==========================================================================
|
||||
|
||||
mcp-llm-gateway:
|
||||
build:
|
||||
context: ./mcp-servers/llm-gateway
|
||||
dockerfile: Dockerfile
|
||||
ports:
|
||||
- "8001:8001"
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
- LLM_GATEWAY_HOST=0.0.0.0
|
||||
- LLM_GATEWAY_PORT=8001
|
||||
- REDIS_URL=redis://redis:6379/1
|
||||
- ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY}
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
||||
- ENVIRONMENT=development
|
||||
depends_on:
|
||||
redis:
|
||||
condition: service_healthy
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:8001/health').raise_for_status()"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 10s
|
||||
networks:
|
||||
- app-network
|
||||
restart: unless-stopped
|
||||
|
||||
mcp-knowledge-base:
|
||||
build:
|
||||
context: ./mcp-servers/knowledge-base
|
||||
dockerfile: Dockerfile
|
||||
ports:
|
||||
- "8002:8002"
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
# KB_ prefix required by pydantic-settings config
|
||||
- KB_HOST=0.0.0.0
|
||||
- KB_PORT=8002
|
||||
- KB_DATABASE_URL=postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@db:5432/${POSTGRES_DB}
|
||||
- KB_REDIS_URL=redis://redis:6379/2
|
||||
- KB_LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
||||
- ENVIRONMENT=development
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:8002/health').raise_for_status()"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 10s
|
||||
networks:
|
||||
- app-network
|
||||
restart: unless-stopped
|
||||
|
||||
backend:
|
||||
build:
|
||||
context: ./backend
|
||||
@@ -52,11 +116,18 @@ services:
|
||||
- DEBUG=true
|
||||
- BACKEND_CORS_ORIGINS=${BACKEND_CORS_ORIGINS}
|
||||
- REDIS_URL=redis://redis:6379/0
|
||||
# MCP Server URLs
|
||||
- LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
|
||||
- KNOWLEDGE_BASE_URL=http://mcp-knowledge-base:8002
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
mcp-llm-gateway:
|
||||
condition: service_healthy
|
||||
mcp-knowledge-base:
|
||||
condition: service_healthy
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
||||
interval: 10s
|
||||
@@ -81,11 +152,18 @@ services:
|
||||
- DATABASE_URL=${DATABASE_URL}
|
||||
- REDIS_URL=redis://redis:6379/0
|
||||
- CELERY_QUEUE=agent
|
||||
# MCP Server URLs (agents need access to MCP)
|
||||
- LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
|
||||
- KNOWLEDGE_BASE_URL=http://mcp-knowledge-base:8002
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
mcp-llm-gateway:
|
||||
condition: service_healthy
|
||||
mcp-knowledge-base:
|
||||
condition: service_healthy
|
||||
networks:
|
||||
- app-network
|
||||
command: ["celery", "-A", "app.celery_app", "worker", "-Q", "agent", "-l", "info", "-c", "4"]
|
||||
|
||||
@@ -32,6 +32,82 @@ services:
|
||||
- app-network
|
||||
restart: unless-stopped
|
||||
|
||||
# ==========================================================================
|
||||
# MCP Servers - Model Context Protocol servers for AI agent capabilities
|
||||
# ==========================================================================
|
||||
|
||||
mcp-llm-gateway:
|
||||
build:
|
||||
context: ./mcp-servers/llm-gateway
|
||||
dockerfile: Dockerfile
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
- LLM_GATEWAY_HOST=0.0.0.0
|
||||
- LLM_GATEWAY_PORT=8001
|
||||
- REDIS_URL=redis://redis:6379/1
|
||||
- ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY}
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
||||
- ENVIRONMENT=production
|
||||
depends_on:
|
||||
redis:
|
||||
condition: service_healthy
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:8001/health').raise_for_status()"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 10s
|
||||
networks:
|
||||
- app-network
|
||||
restart: unless-stopped
|
||||
deploy:
|
||||
resources:
|
||||
limits:
|
||||
cpus: '2.0'
|
||||
memory: 2G
|
||||
reservations:
|
||||
cpus: '0.5'
|
||||
memory: 512M
|
||||
|
||||
mcp-knowledge-base:
|
||||
build:
|
||||
context: ./mcp-servers/knowledge-base
|
||||
dockerfile: Dockerfile
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
# KB_ prefix required by pydantic-settings config
|
||||
- KB_HOST=0.0.0.0
|
||||
- KB_PORT=8002
|
||||
- KB_DATABASE_URL=postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@db:5432/${POSTGRES_DB}
|
||||
- KB_REDIS_URL=redis://redis:6379/2
|
||||
- KB_LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
||||
- ENVIRONMENT=production
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:8002/health').raise_for_status()"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 10s
|
||||
networks:
|
||||
- app-network
|
||||
restart: unless-stopped
|
||||
deploy:
|
||||
resources:
|
||||
limits:
|
||||
cpus: '1.0'
|
||||
memory: 1G
|
||||
reservations:
|
||||
cpus: '0.25'
|
||||
memory: 256M
|
||||
|
||||
backend:
|
||||
build:
|
||||
context: ./backend
|
||||
@@ -48,11 +124,18 @@ services:
|
||||
- DEBUG=false
|
||||
- BACKEND_CORS_ORIGINS=${BACKEND_CORS_ORIGINS}
|
||||
- REDIS_URL=redis://redis:6379/0
|
||||
# MCP Server URLs
|
||||
- LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
|
||||
- KNOWLEDGE_BASE_URL=http://mcp-knowledge-base:8002
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
mcp-llm-gateway:
|
||||
condition: service_healthy
|
||||
mcp-knowledge-base:
|
||||
condition: service_healthy
|
||||
networks:
|
||||
- app-network
|
||||
restart: unless-stopped
|
||||
@@ -75,11 +158,18 @@ services:
|
||||
- DATABASE_URL=${DATABASE_URL}
|
||||
- REDIS_URL=redis://redis:6379/0
|
||||
- CELERY_QUEUE=agent
|
||||
# MCP Server URLs (agents need access to MCP)
|
||||
- LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
|
||||
- KNOWLEDGE_BASE_URL=http://mcp-knowledge-base:8002
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
mcp-llm-gateway:
|
||||
condition: service_healthy
|
||||
mcp-knowledge-base:
|
||||
condition: service_healthy
|
||||
networks:
|
||||
- app-network
|
||||
restart: unless-stopped
|
||||
|
||||
@@ -205,6 +205,69 @@ test(frontend): add unit tests for ProjectDashboard
|
||||
|
||||
---
|
||||
|
||||
## Pre-Commit Hooks
|
||||
|
||||
The repository includes pre-commit hooks that enforce validation before commits on protected branches.
|
||||
|
||||
### Setup
|
||||
|
||||
Enable the hooks by configuring git to use the `.githooks` directory:
|
||||
|
||||
```bash
|
||||
git config core.hooksPath .githooks
|
||||
```
|
||||
|
||||
This only needs to be done once per clone.
|
||||
|
||||
### What the Hooks Do
|
||||
|
||||
When committing to **protected branches** (`main`, `dev`):
|
||||
|
||||
| Condition | Action |
|
||||
|-----------|--------|
|
||||
| Backend files changed | Runs `make validate` in `/backend` |
|
||||
| Frontend files changed | Runs `npm run validate` in `/frontend` |
|
||||
| No relevant changes | Skips validation |
|
||||
|
||||
If validation fails, the commit is blocked with an error message.
|
||||
|
||||
When committing to **feature branches**:
|
||||
- Validation is skipped (allows WIP commits)
|
||||
- A message reminds you to run validation manually if needed
|
||||
|
||||
### Why Protected Branches Only?
|
||||
|
||||
The hooks only enforce validation on `main` and `dev` for good reasons:
|
||||
|
||||
1. **Feature branches are for iteration** - WIP commits, experimentation, and rapid prototyping shouldn't be blocked
|
||||
2. **Flexibility during development** - You can commit broken code to your feature branch while debugging
|
||||
3. **PRs catch issues** - The merge process ensures validation passes before reaching protected branches
|
||||
4. **Manual control** - You can always run `make validate` or `npm run validate` yourself
|
||||
|
||||
### Manual Validation
|
||||
|
||||
Even on feature branches, you should validate before creating a PR:
|
||||
|
||||
```bash
|
||||
# Backend
|
||||
cd backend && make validate
|
||||
|
||||
# Frontend
|
||||
cd frontend && npm run validate
|
||||
```
|
||||
|
||||
### Bypassing Hooks (Emergency Only)
|
||||
|
||||
In rare cases where you need to bypass the hook:
|
||||
|
||||
```bash
|
||||
git commit --no-verify -m "message"
|
||||
```
|
||||
|
||||
**Use sparingly** - this defeats the purpose of the hooks.
|
||||
|
||||
---
|
||||
|
||||
## Documentation Updates
|
||||
|
||||
- Keep `docs/architecture/IMPLEMENTATION_ROADMAP.md` updated
|
||||
@@ -314,8 +377,11 @@ Do NOT use parallel agents when:
|
||||
| Action | Command/Location |
|
||||
|--------|-----------------|
|
||||
| Create branch | `git checkout -b feature/<issue>-<desc>` |
|
||||
| Enable pre-commit hooks | `git config core.hooksPath .githooks` |
|
||||
| Run backend tests | `IS_TEST=True uv run pytest` |
|
||||
| Run frontend tests | `npm test` |
|
||||
| Backend validation | `cd backend && make validate` |
|
||||
| Frontend validation | `cd frontend && npm run validate` |
|
||||
| Check types (backend) | `uv run mypy src/` |
|
||||
| Check types (frontend) | `npm run type-check` |
|
||||
| Lint (backend) | `uv run ruff check src/` |
|
||||
|
||||
@@ -386,10 +386,24 @@ describe('ActivityFeed', () => {
|
||||
});
|
||||
|
||||
it('shows event count in group header', () => {
|
||||
render(<ActivityFeed {...defaultProps} />);
|
||||
// Create fresh "today" events to avoid timezone/day boundary issues
|
||||
const todayEvents: ProjectEvent[] = [
|
||||
createMockEvent({
|
||||
id: 'today-event-1',
|
||||
type: EventType.APPROVAL_REQUESTED,
|
||||
timestamp: new Date().toISOString(),
|
||||
}),
|
||||
createMockEvent({
|
||||
id: 'today-event-2',
|
||||
type: EventType.AGENT_MESSAGE,
|
||||
timestamp: new Date().toISOString(),
|
||||
}),
|
||||
];
|
||||
|
||||
render(<ActivityFeed {...defaultProps} events={todayEvents} />);
|
||||
|
||||
const todayGroup = screen.getByTestId('event-group-today');
|
||||
// Today has 2 events in our mock data
|
||||
// Today has 2 events
|
||||
expect(within(todayGroup).getByText('2')).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
79
mcp-servers/knowledge-base/Makefile
Normal file
79
mcp-servers/knowledge-base/Makefile
Normal file
@@ -0,0 +1,79 @@
|
||||
.PHONY: help install install-dev lint lint-fix format type-check test test-cov validate clean run
|
||||
|
||||
# Default target
|
||||
help:
|
||||
@echo "Knowledge Base MCP Server - Development Commands"
|
||||
@echo ""
|
||||
@echo "Setup:"
|
||||
@echo " make install - Install production dependencies"
|
||||
@echo " make install-dev - Install development dependencies"
|
||||
@echo ""
|
||||
@echo "Quality Checks:"
|
||||
@echo " make lint - Run Ruff linter"
|
||||
@echo " make lint-fix - Run Ruff linter with auto-fix"
|
||||
@echo " make format - Format code with Ruff"
|
||||
@echo " make type-check - Run mypy type checker"
|
||||
@echo ""
|
||||
@echo "Testing:"
|
||||
@echo " make test - Run pytest"
|
||||
@echo " make test-cov - Run pytest with coverage"
|
||||
@echo ""
|
||||
@echo "All-in-one:"
|
||||
@echo " make validate - Run lint, type-check, and tests"
|
||||
@echo ""
|
||||
@echo "Running:"
|
||||
@echo " make run - Run the server locally"
|
||||
@echo ""
|
||||
@echo "Cleanup:"
|
||||
@echo " make clean - Remove cache and build artifacts"
|
||||
|
||||
# Setup
|
||||
install:
|
||||
@echo "Installing production dependencies..."
|
||||
@uv pip install -e .
|
||||
|
||||
install-dev:
|
||||
@echo "Installing development dependencies..."
|
||||
@uv pip install -e ".[dev]"
|
||||
|
||||
# Quality checks
|
||||
lint:
|
||||
@echo "Running Ruff linter..."
|
||||
@uv run ruff check .
|
||||
|
||||
lint-fix:
|
||||
@echo "Running Ruff linter with auto-fix..."
|
||||
@uv run ruff check --fix .
|
||||
|
||||
format:
|
||||
@echo "Formatting code..."
|
||||
@uv run ruff format .
|
||||
|
||||
type-check:
|
||||
@echo "Running mypy..."
|
||||
@uv run mypy . --ignore-missing-imports
|
||||
|
||||
# Testing
|
||||
test:
|
||||
@echo "Running tests..."
|
||||
@uv run pytest tests/ -v
|
||||
|
||||
test-cov:
|
||||
@echo "Running tests with coverage..."
|
||||
@uv run pytest tests/ -v --cov=. --cov-report=term-missing --cov-report=html
|
||||
|
||||
# All-in-one validation
|
||||
validate: lint type-check test
|
||||
@echo "All validations passed!"
|
||||
|
||||
# Running
|
||||
run:
|
||||
@echo "Starting Knowledge Base server..."
|
||||
@uv run python server.py
|
||||
|
||||
# Cleanup
|
||||
clean:
|
||||
@echo "Cleaning up..."
|
||||
@rm -rf __pycache__ .pytest_cache .mypy_cache .ruff_cache .coverage htmlcov
|
||||
@find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
|
||||
@find . -type f -name "*.pyc" -delete 2>/dev/null || true
|
||||
@@ -328,7 +328,7 @@ class CollectionManager:
|
||||
"source_path": chunk.source_path or source_path,
|
||||
"start_line": chunk.start_line,
|
||||
"end_line": chunk.end_line,
|
||||
"file_type": (chunk.file_type or file_type).value if (chunk.file_type or file_type) else None,
|
||||
"file_type": effective_file_type.value if (effective_file_type := chunk.file_type or file_type) else None,
|
||||
}
|
||||
embeddings_data.append((
|
||||
chunk.content,
|
||||
|
||||
@@ -284,41 +284,40 @@ class DatabaseManager:
|
||||
)
|
||||
|
||||
try:
|
||||
async with self.acquire() as conn:
|
||||
async with self.acquire() as conn, conn.transaction():
|
||||
# Wrap in transaction for all-or-nothing batch semantics
|
||||
async with conn.transaction():
|
||||
for project_id, collection, content, embedding, chunk_type, metadata in embeddings:
|
||||
content_hash = self.compute_content_hash(content)
|
||||
source_path = metadata.get("source_path")
|
||||
start_line = metadata.get("start_line")
|
||||
end_line = metadata.get("end_line")
|
||||
file_type = metadata.get("file_type")
|
||||
for project_id, collection, content, embedding, chunk_type, metadata in embeddings:
|
||||
content_hash = self.compute_content_hash(content)
|
||||
source_path = metadata.get("source_path")
|
||||
start_line = metadata.get("start_line")
|
||||
end_line = metadata.get("end_line")
|
||||
file_type = metadata.get("file_type")
|
||||
|
||||
embedding_id = await conn.fetchval(
|
||||
"""
|
||||
INSERT INTO knowledge_embeddings
|
||||
(project_id, collection, content, embedding, chunk_type,
|
||||
source_path, start_line, end_line, file_type, metadata,
|
||||
content_hash, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
|
||||
ON CONFLICT DO NOTHING
|
||||
RETURNING id
|
||||
""",
|
||||
project_id,
|
||||
collection,
|
||||
content,
|
||||
embedding,
|
||||
chunk_type.value,
|
||||
source_path,
|
||||
start_line,
|
||||
end_line,
|
||||
file_type,
|
||||
metadata,
|
||||
content_hash,
|
||||
expires_at,
|
||||
)
|
||||
if embedding_id:
|
||||
ids.append(str(embedding_id))
|
||||
embedding_id = await conn.fetchval(
|
||||
"""
|
||||
INSERT INTO knowledge_embeddings
|
||||
(project_id, collection, content, embedding, chunk_type,
|
||||
source_path, start_line, end_line, file_type, metadata,
|
||||
content_hash, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
|
||||
ON CONFLICT DO NOTHING
|
||||
RETURNING id
|
||||
""",
|
||||
project_id,
|
||||
collection,
|
||||
content,
|
||||
embedding,
|
||||
chunk_type.value,
|
||||
source_path,
|
||||
start_line,
|
||||
end_line,
|
||||
file_type,
|
||||
metadata,
|
||||
content_hash,
|
||||
expires_at,
|
||||
)
|
||||
if embedding_id:
|
||||
ids.append(str(embedding_id))
|
||||
|
||||
logger.info(f"Stored {len(ids)} embeddings in batch")
|
||||
return ids
|
||||
@@ -566,59 +565,58 @@ class DatabaseManager:
|
||||
)
|
||||
|
||||
try:
|
||||
async with self.acquire() as conn:
|
||||
async with self.acquire() as conn, conn.transaction():
|
||||
# Use transaction for atomic replace
|
||||
async with conn.transaction():
|
||||
# First, delete existing embeddings for this source
|
||||
delete_result = await conn.execute(
|
||||
# First, delete existing embeddings for this source
|
||||
delete_result = await conn.execute(
|
||||
"""
|
||||
DELETE FROM knowledge_embeddings
|
||||
WHERE project_id = $1 AND source_path = $2 AND collection = $3
|
||||
""",
|
||||
project_id,
|
||||
source_path,
|
||||
collection,
|
||||
)
|
||||
deleted_count = int(delete_result.split()[-1])
|
||||
|
||||
# Then insert new embeddings
|
||||
new_ids = []
|
||||
for content, embedding, chunk_type, metadata in embeddings:
|
||||
content_hash = self.compute_content_hash(content)
|
||||
start_line = metadata.get("start_line")
|
||||
end_line = metadata.get("end_line")
|
||||
file_type = metadata.get("file_type")
|
||||
|
||||
embedding_id = await conn.fetchval(
|
||||
"""
|
||||
DELETE FROM knowledge_embeddings
|
||||
WHERE project_id = $1 AND source_path = $2 AND collection = $3
|
||||
INSERT INTO knowledge_embeddings
|
||||
(project_id, collection, content, embedding, chunk_type,
|
||||
source_path, start_line, end_line, file_type, metadata,
|
||||
content_hash, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
|
||||
RETURNING id
|
||||
""",
|
||||
project_id,
|
||||
source_path,
|
||||
collection,
|
||||
content,
|
||||
embedding,
|
||||
chunk_type.value,
|
||||
source_path,
|
||||
start_line,
|
||||
end_line,
|
||||
file_type,
|
||||
metadata,
|
||||
content_hash,
|
||||
expires_at,
|
||||
)
|
||||
deleted_count = int(delete_result.split()[-1])
|
||||
if embedding_id:
|
||||
new_ids.append(str(embedding_id))
|
||||
|
||||
# Then insert new embeddings
|
||||
new_ids = []
|
||||
for content, embedding, chunk_type, metadata in embeddings:
|
||||
content_hash = self.compute_content_hash(content)
|
||||
start_line = metadata.get("start_line")
|
||||
end_line = metadata.get("end_line")
|
||||
file_type = metadata.get("file_type")
|
||||
|
||||
embedding_id = await conn.fetchval(
|
||||
"""
|
||||
INSERT INTO knowledge_embeddings
|
||||
(project_id, collection, content, embedding, chunk_type,
|
||||
source_path, start_line, end_line, file_type, metadata,
|
||||
content_hash, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
|
||||
RETURNING id
|
||||
""",
|
||||
project_id,
|
||||
collection,
|
||||
content,
|
||||
embedding,
|
||||
chunk_type.value,
|
||||
source_path,
|
||||
start_line,
|
||||
end_line,
|
||||
file_type,
|
||||
metadata,
|
||||
content_hash,
|
||||
expires_at,
|
||||
)
|
||||
if embedding_id:
|
||||
new_ids.append(str(embedding_id))
|
||||
|
||||
logger.info(
|
||||
f"Replaced source {source_path}: deleted {deleted_count}, "
|
||||
f"inserted {len(new_ids)} embeddings"
|
||||
)
|
||||
return deleted_count, new_ids
|
||||
logger.info(
|
||||
f"Replaced source {source_path}: deleted {deleted_count}, "
|
||||
f"inserted {len(new_ids)} embeddings"
|
||||
)
|
||||
return deleted_count, new_ids
|
||||
|
||||
except asyncpg.PostgresError as e:
|
||||
logger.error(f"Replace source error: {e}")
|
||||
|
||||
@@ -193,7 +193,7 @@ async def health_check() -> dict[str, Any]:
|
||||
# Check Redis cache (non-critical - degraded without it)
|
||||
try:
|
||||
if _embeddings and _embeddings._redis:
|
||||
await _embeddings._redis.ping()
|
||||
await _embeddings._redis.ping() # type: ignore[misc]
|
||||
status["dependencies"]["redis"] = "connected"
|
||||
else:
|
||||
status["dependencies"]["redis"] = "not initialized"
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
"""Tests for server module and MCP tools."""
|
||||
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
@@ -1,39 +1,25 @@
|
||||
# Syndarix LLM Gateway MCP Server
|
||||
# Multi-stage build for minimal image size
|
||||
FROM python:3.12-slim
|
||||
|
||||
# Build stage
|
||||
FROM python:3.12-slim AS builder
|
||||
WORKDIR /app
|
||||
|
||||
# Install system dependencies (needed for tiktoken regex compilation)
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install uv for fast package management
|
||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
|
||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /bin/uv
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy dependency files
|
||||
# Copy project files
|
||||
COPY pyproject.toml ./
|
||||
COPY *.py ./
|
||||
|
||||
# Create virtual environment and install dependencies
|
||||
RUN uv venv /app/.venv
|
||||
ENV PATH="/app/.venv/bin:$PATH"
|
||||
RUN uv pip install -e .
|
||||
|
||||
# Runtime stage
|
||||
FROM python:3.12-slim AS runtime
|
||||
# Install dependencies to system Python
|
||||
RUN uv pip install --system --no-cache .
|
||||
|
||||
# Create non-root user for security
|
||||
RUN groupadd --gid 1000 appgroup && \
|
||||
useradd --uid 1000 --gid appgroup --shell /bin/bash --create-home appuser
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy virtual environment from builder
|
||||
COPY --from=builder /app/.venv /app/.venv
|
||||
ENV PATH="/app/.venv/bin:$PATH"
|
||||
|
||||
# Copy application code
|
||||
COPY --chown=appuser:appgroup . .
|
||||
|
||||
# Switch to non-root user
|
||||
RUN useradd --create-home --shell /bin/bash appuser
|
||||
USER appuser
|
||||
|
||||
# Environment variables
|
||||
@@ -47,7 +33,7 @@ EXPOSE 8001
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||
CMD python -c "import httpx; httpx.get('http://localhost:8001/health').raise_for_status()" || exit 1
|
||||
CMD python -c "import httpx; httpx.get('http://localhost:8001/health').raise_for_status()"
|
||||
|
||||
# Run the server
|
||||
CMD ["python", "server.py"]
|
||||
|
||||
79
mcp-servers/llm-gateway/Makefile
Normal file
79
mcp-servers/llm-gateway/Makefile
Normal file
@@ -0,0 +1,79 @@
|
||||
.PHONY: help install install-dev lint lint-fix format type-check test test-cov validate clean run
|
||||
|
||||
# Default target
|
||||
help:
|
||||
@echo "LLM Gateway MCP Server - Development Commands"
|
||||
@echo ""
|
||||
@echo "Setup:"
|
||||
@echo " make install - Install production dependencies"
|
||||
@echo " make install-dev - Install development dependencies"
|
||||
@echo ""
|
||||
@echo "Quality Checks:"
|
||||
@echo " make lint - Run Ruff linter"
|
||||
@echo " make lint-fix - Run Ruff linter with auto-fix"
|
||||
@echo " make format - Format code with Ruff"
|
||||
@echo " make type-check - Run mypy type checker"
|
||||
@echo ""
|
||||
@echo "Testing:"
|
||||
@echo " make test - Run pytest"
|
||||
@echo " make test-cov - Run pytest with coverage"
|
||||
@echo ""
|
||||
@echo "All-in-one:"
|
||||
@echo " make validate - Run lint, type-check, and tests"
|
||||
@echo ""
|
||||
@echo "Running:"
|
||||
@echo " make run - Run the server locally"
|
||||
@echo ""
|
||||
@echo "Cleanup:"
|
||||
@echo " make clean - Remove cache and build artifacts"
|
||||
|
||||
# Setup
|
||||
install:
|
||||
@echo "Installing production dependencies..."
|
||||
@uv pip install -e .
|
||||
|
||||
install-dev:
|
||||
@echo "Installing development dependencies..."
|
||||
@uv pip install -e ".[dev]"
|
||||
|
||||
# Quality checks
|
||||
lint:
|
||||
@echo "Running Ruff linter..."
|
||||
@uv run ruff check .
|
||||
|
||||
lint-fix:
|
||||
@echo "Running Ruff linter with auto-fix..."
|
||||
@uv run ruff check --fix .
|
||||
|
||||
format:
|
||||
@echo "Formatting code..."
|
||||
@uv run ruff format .
|
||||
|
||||
type-check:
|
||||
@echo "Running mypy..."
|
||||
@uv run mypy . --ignore-missing-imports
|
||||
|
||||
# Testing
|
||||
test:
|
||||
@echo "Running tests..."
|
||||
@uv run pytest tests/ -v
|
||||
|
||||
test-cov:
|
||||
@echo "Running tests with coverage..."
|
||||
@uv run pytest tests/ -v --cov=. --cov-report=term-missing --cov-report=html
|
||||
|
||||
# All-in-one validation
|
||||
validate: lint type-check test
|
||||
@echo "All validations passed!"
|
||||
|
||||
# Running
|
||||
run:
|
||||
@echo "Starting LLM Gateway server..."
|
||||
@uv run python server.py
|
||||
|
||||
# Cleanup
|
||||
clean:
|
||||
@echo "Cleaning up..."
|
||||
@rm -rf __pycache__ .pytest_cache .mypy_cache .ruff_cache .coverage htmlcov
|
||||
@find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
|
||||
@find . -type f -name "*.pyc" -delete 2>/dev/null || true
|
||||
@@ -110,14 +110,13 @@ class CircuitBreaker:
|
||||
"""
|
||||
if self._state == CircuitState.OPEN:
|
||||
time_in_open = time.time() - self._stats.state_changed_at
|
||||
if time_in_open >= self.recovery_timeout:
|
||||
# Only transition if still in OPEN state (double-check)
|
||||
if self._state == CircuitState.OPEN:
|
||||
self._transition_to(CircuitState.HALF_OPEN)
|
||||
logger.info(
|
||||
f"Circuit {self.name} transitioned to HALF_OPEN "
|
||||
f"after {time_in_open:.1f}s"
|
||||
)
|
||||
# Double-check state after time calculation (for thread safety)
|
||||
if time_in_open >= self.recovery_timeout and self._state == CircuitState.OPEN:
|
||||
self._transition_to(CircuitState.HALF_OPEN)
|
||||
logger.info(
|
||||
f"Circuit {self.name} transitioned to HALF_OPEN "
|
||||
f"after {time_in_open:.1f}s"
|
||||
)
|
||||
|
||||
def _transition_to(self, new_state: CircuitState) -> None:
|
||||
"""Transition to a new state."""
|
||||
|
||||
Reference in New Issue
Block a user