forked from cardosofelipe/pragma-stack
Compare commits
14 Commits
2bea057fb1
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4b149b8a52 | ||
|
|
ad0c06851d | ||
|
|
49359b1416 | ||
|
|
911d950c15 | ||
|
|
b2a3ac60e0 | ||
|
|
dea092e1bb | ||
|
|
4154dd5268 | ||
|
|
db12937495 | ||
|
|
81e1456631 | ||
|
|
58e78d8700 | ||
|
|
5e80139afa | ||
|
|
60ebeaa582 | ||
|
|
758052dcff | ||
|
|
1628eacf2b |
61
.githooks/pre-commit
Executable file
61
.githooks/pre-commit
Executable file
@@ -0,0 +1,61 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Pre-commit hook to enforce validation before commits on protected branches
|
||||||
|
# Install: git config core.hooksPath .githooks
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
# Get the current branch name
|
||||||
|
BRANCH=$(git rev-parse --abbrev-ref HEAD)
|
||||||
|
|
||||||
|
# Protected branches that require validation
|
||||||
|
PROTECTED_BRANCHES="main dev"
|
||||||
|
|
||||||
|
# Check if we're on a protected branch
|
||||||
|
is_protected() {
|
||||||
|
for branch in $PROTECTED_BRANCHES; do
|
||||||
|
if [ "$BRANCH" = "$branch" ]; then
|
||||||
|
return 0
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if is_protected; then
|
||||||
|
echo "🔒 Committing to protected branch '$BRANCH' - running validation..."
|
||||||
|
|
||||||
|
# Check if we have backend changes
|
||||||
|
if git diff --cached --name-only | grep -q "^backend/"; then
|
||||||
|
echo "📦 Backend changes detected - running make validate..."
|
||||||
|
cd backend
|
||||||
|
if ! make validate; then
|
||||||
|
echo ""
|
||||||
|
echo "❌ Backend validation failed!"
|
||||||
|
echo " Please fix the issues and try again."
|
||||||
|
echo " Run 'cd backend && make validate' to see errors."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
cd ..
|
||||||
|
echo "✅ Backend validation passed!"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check if we have frontend changes
|
||||||
|
if git diff --cached --name-only | grep -q "^frontend/"; then
|
||||||
|
echo "🎨 Frontend changes detected - running npm run validate..."
|
||||||
|
cd frontend
|
||||||
|
if ! npm run validate 2>/dev/null; then
|
||||||
|
echo ""
|
||||||
|
echo "❌ Frontend validation failed!"
|
||||||
|
echo " Please fix the issues and try again."
|
||||||
|
echo " Run 'cd frontend && npm run validate' to see errors."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
cd ..
|
||||||
|
echo "✅ Frontend validation passed!"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "🎉 All validations passed! Proceeding with commit..."
|
||||||
|
else
|
||||||
|
echo "📝 Committing to feature branch '$BRANCH' - skipping validation (run manually if needed)"
|
||||||
|
fi
|
||||||
|
|
||||||
|
exit 0
|
||||||
88
Makefile
88
Makefile
@@ -1,18 +1,31 @@
|
|||||||
.PHONY: help dev dev-full prod down logs logs-dev clean clean-slate drop-db reset-db push-images deploy
|
.PHONY: help dev dev-full prod down logs logs-dev clean clean-slate drop-db reset-db push-images deploy
|
||||||
|
.PHONY: test test-backend test-mcp test-frontend test-all test-cov test-integration validate validate-all
|
||||||
|
|
||||||
VERSION ?= latest
|
VERSION ?= latest
|
||||||
REGISTRY ?= ghcr.io/cardosofelipe/pragma-stack
|
REGISTRY ?= ghcr.io/cardosofelipe/pragma-stack
|
||||||
|
|
||||||
# Default target
|
# Default target
|
||||||
help:
|
help:
|
||||||
@echo "FastAPI + Next.js Full-Stack Template"
|
@echo "Syndarix - AI-Powered Software Consulting Agency"
|
||||||
@echo ""
|
@echo ""
|
||||||
@echo "Development:"
|
@echo "Development:"
|
||||||
@echo " make dev - Start backend + db (frontend runs separately)"
|
@echo " make dev - Start backend + db + MCP servers (frontend runs separately)"
|
||||||
@echo " make dev-full - Start all services including frontend"
|
@echo " make dev-full - Start all services including frontend"
|
||||||
@echo " make down - Stop all services"
|
@echo " make down - Stop all services"
|
||||||
@echo " make logs-dev - Follow dev container logs"
|
@echo " make logs-dev - Follow dev container logs"
|
||||||
@echo ""
|
@echo ""
|
||||||
|
@echo "Testing:"
|
||||||
|
@echo " make test - Run all tests (backend + MCP servers)"
|
||||||
|
@echo " make test-backend - Run backend tests only"
|
||||||
|
@echo " make test-mcp - Run MCP server tests only"
|
||||||
|
@echo " make test-frontend - Run frontend tests only"
|
||||||
|
@echo " make test-cov - Run all tests with coverage reports"
|
||||||
|
@echo " make test-integration - Run MCP integration tests (requires running stack)"
|
||||||
|
@echo ""
|
||||||
|
@echo "Validation:"
|
||||||
|
@echo " make validate - Validate backend + MCP servers (lint, type-check, test)"
|
||||||
|
@echo " make validate-all - Validate everything including frontend"
|
||||||
|
@echo ""
|
||||||
@echo "Database:"
|
@echo "Database:"
|
||||||
@echo " make drop-db - Drop and recreate empty database"
|
@echo " make drop-db - Drop and recreate empty database"
|
||||||
@echo " make reset-db - Drop database and apply all migrations"
|
@echo " make reset-db - Drop database and apply all migrations"
|
||||||
@@ -29,6 +42,8 @@ help:
|
|||||||
@echo ""
|
@echo ""
|
||||||
@echo "Subdirectory commands:"
|
@echo "Subdirectory commands:"
|
||||||
@echo " cd backend && make help - Backend-specific commands"
|
@echo " cd backend && make help - Backend-specific commands"
|
||||||
|
@echo " cd mcp-servers/llm-gateway && make - LLM Gateway commands"
|
||||||
|
@echo " cd mcp-servers/knowledge-base && make - Knowledge Base commands"
|
||||||
@echo " cd frontend && npm run - Frontend-specific commands"
|
@echo " cd frontend && npm run - Frontend-specific commands"
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
@@ -99,3 +114,72 @@ clean:
|
|||||||
# WARNING! THIS REMOVES CONTAINERS AND VOLUMES AS WELL - DO NOT USE THIS UNLESS YOU WANT TO START OVER WITH DATA AND ALL
|
# WARNING! THIS REMOVES CONTAINERS AND VOLUMES AS WELL - DO NOT USE THIS UNLESS YOU WANT TO START OVER WITH DATA AND ALL
|
||||||
clean-slate:
|
clean-slate:
|
||||||
docker compose -f docker-compose.dev.yml down -v --remove-orphans
|
docker compose -f docker-compose.dev.yml down -v --remove-orphans
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Testing
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
test: test-backend test-mcp
|
||||||
|
@echo ""
|
||||||
|
@echo "All tests passed!"
|
||||||
|
|
||||||
|
test-backend:
|
||||||
|
@echo "Running backend tests..."
|
||||||
|
@cd backend && IS_TEST=True uv run pytest tests/ -v
|
||||||
|
|
||||||
|
test-mcp:
|
||||||
|
@echo "Running MCP server tests..."
|
||||||
|
@echo ""
|
||||||
|
@echo "=== LLM Gateway ==="
|
||||||
|
@cd mcp-servers/llm-gateway && uv run pytest tests/ -v
|
||||||
|
@echo ""
|
||||||
|
@echo "=== Knowledge Base ==="
|
||||||
|
@cd mcp-servers/knowledge-base && uv run pytest tests/ -v
|
||||||
|
|
||||||
|
test-frontend:
|
||||||
|
@echo "Running frontend tests..."
|
||||||
|
@cd frontend && npm test
|
||||||
|
|
||||||
|
test-all: test test-frontend
|
||||||
|
@echo ""
|
||||||
|
@echo "All tests (backend + MCP + frontend) passed!"
|
||||||
|
|
||||||
|
test-cov:
|
||||||
|
@echo "Running all tests with coverage..."
|
||||||
|
@echo ""
|
||||||
|
@echo "=== Backend Coverage ==="
|
||||||
|
@cd backend && IS_TEST=True uv run pytest tests/ -v --cov=app --cov-report=term-missing
|
||||||
|
@echo ""
|
||||||
|
@echo "=== LLM Gateway Coverage ==="
|
||||||
|
@cd mcp-servers/llm-gateway && uv run pytest tests/ -v --cov=. --cov-report=term-missing
|
||||||
|
@echo ""
|
||||||
|
@echo "=== Knowledge Base Coverage ==="
|
||||||
|
@cd mcp-servers/knowledge-base && uv run pytest tests/ -v --cov=. --cov-report=term-missing
|
||||||
|
|
||||||
|
test-integration:
|
||||||
|
@echo "Running MCP integration tests..."
|
||||||
|
@echo "Note: Requires running stack (make dev first)"
|
||||||
|
@cd backend && RUN_INTEGRATION_TESTS=true IS_TEST=True uv run pytest tests/integration/ -v
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Validation (lint + type-check + test)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
validate:
|
||||||
|
@echo "Validating backend..."
|
||||||
|
@cd backend && make validate
|
||||||
|
@echo ""
|
||||||
|
@echo "Validating LLM Gateway..."
|
||||||
|
@cd mcp-servers/llm-gateway && make validate
|
||||||
|
@echo ""
|
||||||
|
@echo "Validating Knowledge Base..."
|
||||||
|
@cd mcp-servers/knowledge-base && make validate
|
||||||
|
@echo ""
|
||||||
|
@echo "All validations passed!"
|
||||||
|
|
||||||
|
validate-all: validate
|
||||||
|
@echo ""
|
||||||
|
@echo "Validating frontend..."
|
||||||
|
@cd frontend && npm run validate
|
||||||
|
@echo ""
|
||||||
|
@echo "Full validation passed!"
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
.PHONY: help lint lint-fix format format-check type-check test test-cov validate clean install-dev sync check-docker install-e2e test-e2e test-e2e-schema test-all
|
.PHONY: help lint lint-fix format format-check type-check test test-cov validate clean install-dev sync check-docker install-e2e test-e2e test-e2e-schema test-all test-integration
|
||||||
|
|
||||||
# Default target
|
# Default target
|
||||||
help:
|
help:
|
||||||
@@ -22,6 +22,7 @@ help:
|
|||||||
@echo " make test-cov - Run pytest with coverage report"
|
@echo " make test-cov - Run pytest with coverage report"
|
||||||
@echo " make test-e2e - Run E2E tests (PostgreSQL, requires Docker)"
|
@echo " make test-e2e - Run E2E tests (PostgreSQL, requires Docker)"
|
||||||
@echo " make test-e2e-schema - Run Schemathesis API schema tests"
|
@echo " make test-e2e-schema - Run Schemathesis API schema tests"
|
||||||
|
@echo " make test-integration - Run MCP integration tests (requires running stack)"
|
||||||
@echo " make test-all - Run all tests (unit + E2E)"
|
@echo " make test-all - Run all tests (unit + E2E)"
|
||||||
@echo " make check-docker - Check if Docker is available"
|
@echo " make check-docker - Check if Docker is available"
|
||||||
@echo ""
|
@echo ""
|
||||||
@@ -82,6 +83,15 @@ test-cov:
|
|||||||
@IS_TEST=True PYTHONPATH=. uv run pytest --cov=app --cov-report=term-missing --cov-report=html -n 16
|
@IS_TEST=True PYTHONPATH=. uv run pytest --cov=app --cov-report=term-missing --cov-report=html -n 16
|
||||||
@echo "📊 Coverage report generated in htmlcov/index.html"
|
@echo "📊 Coverage report generated in htmlcov/index.html"
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Integration Testing (requires running stack: make dev)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
test-integration:
|
||||||
|
@echo "🧪 Running MCP integration tests..."
|
||||||
|
@echo "Note: Requires running stack (make dev from project root)"
|
||||||
|
@RUN_INTEGRATION_TESTS=true IS_TEST=True PYTHONPATH=. uv run pytest tests/integration/ -v
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# E2E Testing (requires Docker)
|
# E2E Testing (requires Docker)
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from app.api.routes import (
|
|||||||
agent_types,
|
agent_types,
|
||||||
agents,
|
agents,
|
||||||
auth,
|
auth,
|
||||||
|
context,
|
||||||
events,
|
events,
|
||||||
issues,
|
issues,
|
||||||
mcp,
|
mcp,
|
||||||
@@ -35,6 +36,9 @@ api_router.include_router(events.router, tags=["Events"])
|
|||||||
# MCP (Model Context Protocol) router
|
# MCP (Model Context Protocol) router
|
||||||
api_router.include_router(mcp.router, prefix="/mcp", tags=["MCP"])
|
api_router.include_router(mcp.router, prefix="/mcp", tags=["MCP"])
|
||||||
|
|
||||||
|
# Context Management Engine router
|
||||||
|
api_router.include_router(context.router, prefix="/context", tags=["Context"])
|
||||||
|
|
||||||
# Syndarix domain routers
|
# Syndarix domain routers
|
||||||
api_router.include_router(projects.router, prefix="/projects", tags=["Projects"])
|
api_router.include_router(projects.router, prefix="/projects", tags=["Projects"])
|
||||||
api_router.include_router(
|
api_router.include_router(
|
||||||
|
|||||||
411
backend/app/api/routes/context.py
Normal file
411
backend/app/api/routes/context.py
Normal file
@@ -0,0 +1,411 @@
|
|||||||
|
"""
|
||||||
|
Context Management API Endpoints.
|
||||||
|
|
||||||
|
Provides REST endpoints for context assembly and optimization
|
||||||
|
for LLM requests using the ContextEngine.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Annotated, Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.api.dependencies.permissions import require_superuser
|
||||||
|
from app.models.user import User
|
||||||
|
from app.services.context import (
|
||||||
|
AssemblyTimeoutError,
|
||||||
|
BudgetExceededError,
|
||||||
|
ContextEngine,
|
||||||
|
ContextSettings,
|
||||||
|
create_context_engine,
|
||||||
|
get_context_settings,
|
||||||
|
)
|
||||||
|
from app.services.mcp import MCPClientManager, get_mcp_client
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Singleton Engine Management
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
_context_engine: ContextEngine | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_or_create_engine(
|
||||||
|
mcp: MCPClientManager,
|
||||||
|
settings: ContextSettings | None = None,
|
||||||
|
) -> ContextEngine:
|
||||||
|
"""Get or create the singleton ContextEngine."""
|
||||||
|
global _context_engine
|
||||||
|
if _context_engine is None:
|
||||||
|
_context_engine = create_context_engine(
|
||||||
|
mcp_manager=mcp,
|
||||||
|
redis=None, # Optional: add Redis caching later
|
||||||
|
settings=settings or get_context_settings(),
|
||||||
|
)
|
||||||
|
logger.info("ContextEngine initialized")
|
||||||
|
else:
|
||||||
|
# Ensure MCP manager is up to date
|
||||||
|
_context_engine.set_mcp_manager(mcp)
|
||||||
|
return _context_engine
|
||||||
|
|
||||||
|
|
||||||
|
async def get_context_engine(
|
||||||
|
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||||
|
) -> ContextEngine:
|
||||||
|
"""FastAPI dependency to get the ContextEngine."""
|
||||||
|
return _get_or_create_engine(mcp)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Request/Response Schemas
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationTurn(BaseModel):
|
||||||
|
"""A single conversation turn."""
|
||||||
|
|
||||||
|
role: str = Field(..., description="Role: 'user' or 'assistant'")
|
||||||
|
content: str = Field(..., description="Message content")
|
||||||
|
|
||||||
|
|
||||||
|
class ToolResult(BaseModel):
|
||||||
|
"""A tool execution result."""
|
||||||
|
|
||||||
|
tool_name: str = Field(..., description="Name of the tool")
|
||||||
|
content: str | dict[str, Any] = Field(..., description="Tool result content")
|
||||||
|
status: str = Field(default="success", description="Execution status")
|
||||||
|
|
||||||
|
|
||||||
|
class AssembleContextRequest(BaseModel):
|
||||||
|
"""Request to assemble context for an LLM request."""
|
||||||
|
|
||||||
|
project_id: str = Field(..., description="Project identifier")
|
||||||
|
agent_id: str = Field(..., description="Agent identifier")
|
||||||
|
query: str = Field(..., description="User's query or current request")
|
||||||
|
model: str = Field(
|
||||||
|
default="claude-3-sonnet",
|
||||||
|
description="Target model name",
|
||||||
|
)
|
||||||
|
max_tokens: int | None = Field(
|
||||||
|
None,
|
||||||
|
description="Maximum context tokens (uses model default if None)",
|
||||||
|
)
|
||||||
|
system_prompt: str | None = Field(
|
||||||
|
None,
|
||||||
|
description="System prompt/instructions",
|
||||||
|
)
|
||||||
|
task_description: str | None = Field(
|
||||||
|
None,
|
||||||
|
description="Current task description",
|
||||||
|
)
|
||||||
|
knowledge_query: str | None = Field(
|
||||||
|
None,
|
||||||
|
description="Query for knowledge base search",
|
||||||
|
)
|
||||||
|
knowledge_limit: int = Field(
|
||||||
|
default=10,
|
||||||
|
ge=1,
|
||||||
|
le=50,
|
||||||
|
description="Max number of knowledge results",
|
||||||
|
)
|
||||||
|
conversation_history: list[ConversationTurn] | None = Field(
|
||||||
|
None,
|
||||||
|
description="Previous conversation turns",
|
||||||
|
)
|
||||||
|
tool_results: list[ToolResult] | None = Field(
|
||||||
|
None,
|
||||||
|
description="Tool execution results to include",
|
||||||
|
)
|
||||||
|
compress: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Whether to apply compression",
|
||||||
|
)
|
||||||
|
use_cache: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Whether to use caching",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AssembledContextResponse(BaseModel):
|
||||||
|
"""Response containing assembled context."""
|
||||||
|
|
||||||
|
content: str = Field(..., description="Assembled context content")
|
||||||
|
total_tokens: int = Field(..., description="Total token count")
|
||||||
|
context_count: int = Field(..., description="Number of context items included")
|
||||||
|
compressed: bool = Field(..., description="Whether compression was applied")
|
||||||
|
budget_used_percent: float = Field(
|
||||||
|
...,
|
||||||
|
description="Percentage of token budget used",
|
||||||
|
)
|
||||||
|
metadata: dict[str, Any] = Field(
|
||||||
|
default_factory=dict,
|
||||||
|
description="Additional metadata",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TokenCountRequest(BaseModel):
|
||||||
|
"""Request to count tokens in content."""
|
||||||
|
|
||||||
|
content: str = Field(..., description="Content to count tokens in")
|
||||||
|
model: str | None = Field(
|
||||||
|
None,
|
||||||
|
description="Model for model-specific tokenization",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TokenCountResponse(BaseModel):
|
||||||
|
"""Response containing token count."""
|
||||||
|
|
||||||
|
token_count: int = Field(..., description="Number of tokens")
|
||||||
|
model: str | None = Field(None, description="Model used for counting")
|
||||||
|
|
||||||
|
|
||||||
|
class BudgetInfoResponse(BaseModel):
|
||||||
|
"""Response containing budget information for a model."""
|
||||||
|
|
||||||
|
model: str = Field(..., description="Model name")
|
||||||
|
total_tokens: int = Field(..., description="Total token budget")
|
||||||
|
system_tokens: int = Field(..., description="Tokens reserved for system")
|
||||||
|
knowledge_tokens: int = Field(..., description="Tokens for knowledge")
|
||||||
|
conversation_tokens: int = Field(..., description="Tokens for conversation")
|
||||||
|
tool_tokens: int = Field(..., description="Tokens for tool results")
|
||||||
|
response_reserve: int = Field(..., description="Tokens reserved for response")
|
||||||
|
|
||||||
|
|
||||||
|
class ContextEngineStatsResponse(BaseModel):
|
||||||
|
"""Response containing engine statistics."""
|
||||||
|
|
||||||
|
cache: dict[str, Any] = Field(..., description="Cache statistics")
|
||||||
|
settings: dict[str, Any] = Field(..., description="Current settings")
|
||||||
|
|
||||||
|
|
||||||
|
class HealthResponse(BaseModel):
|
||||||
|
"""Health check response."""
|
||||||
|
|
||||||
|
status: str = Field(..., description="Health status")
|
||||||
|
mcp_connected: bool = Field(..., description="Whether MCP is connected")
|
||||||
|
cache_enabled: bool = Field(..., description="Whether caching is enabled")
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Endpoints
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/health",
|
||||||
|
response_model=HealthResponse,
|
||||||
|
summary="Context Engine Health",
|
||||||
|
description="Check health status of the context engine.",
|
||||||
|
)
|
||||||
|
async def health_check(
|
||||||
|
engine: ContextEngine = Depends(get_context_engine),
|
||||||
|
) -> HealthResponse:
|
||||||
|
"""Check context engine health."""
|
||||||
|
stats = await engine.get_stats()
|
||||||
|
return HealthResponse(
|
||||||
|
status="healthy",
|
||||||
|
mcp_connected=engine._mcp is not None,
|
||||||
|
cache_enabled=stats.get("settings", {}).get("cache_enabled", False),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/assemble",
|
||||||
|
response_model=AssembledContextResponse,
|
||||||
|
summary="Assemble Context",
|
||||||
|
description="Assemble optimized context for an LLM request.",
|
||||||
|
)
|
||||||
|
async def assemble_context(
|
||||||
|
request: AssembleContextRequest,
|
||||||
|
current_user: User = Depends(require_superuser),
|
||||||
|
engine: ContextEngine = Depends(get_context_engine),
|
||||||
|
) -> AssembledContextResponse:
|
||||||
|
"""
|
||||||
|
Assemble optimized context for an LLM request.
|
||||||
|
|
||||||
|
This endpoint gathers context from various sources, scores and ranks them,
|
||||||
|
compresses if needed, and formats for the target model.
|
||||||
|
"""
|
||||||
|
logger.info(
|
||||||
|
"Context assembly for project=%s agent=%s by user=%s",
|
||||||
|
request.project_id,
|
||||||
|
request.agent_id,
|
||||||
|
current_user.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert conversation history to dict format
|
||||||
|
conversation_history = None
|
||||||
|
if request.conversation_history:
|
||||||
|
conversation_history = [
|
||||||
|
{"role": turn.role, "content": turn.content}
|
||||||
|
for turn in request.conversation_history
|
||||||
|
]
|
||||||
|
|
||||||
|
# Convert tool results to dict format
|
||||||
|
tool_results = None
|
||||||
|
if request.tool_results:
|
||||||
|
tool_results = [
|
||||||
|
{
|
||||||
|
"tool_name": tr.tool_name,
|
||||||
|
"content": tr.content,
|
||||||
|
"status": tr.status,
|
||||||
|
}
|
||||||
|
for tr in request.tool_results
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await engine.assemble_context(
|
||||||
|
project_id=request.project_id,
|
||||||
|
agent_id=request.agent_id,
|
||||||
|
query=request.query,
|
||||||
|
model=request.model,
|
||||||
|
max_tokens=request.max_tokens,
|
||||||
|
system_prompt=request.system_prompt,
|
||||||
|
task_description=request.task_description,
|
||||||
|
knowledge_query=request.knowledge_query,
|
||||||
|
knowledge_limit=request.knowledge_limit,
|
||||||
|
conversation_history=conversation_history,
|
||||||
|
tool_results=tool_results,
|
||||||
|
compress=request.compress,
|
||||||
|
use_cache=request.use_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate budget usage percentage
|
||||||
|
budget = await engine.get_budget_for_model(request.model, request.max_tokens)
|
||||||
|
budget_used_percent = (result.total_tokens / budget.total) * 100
|
||||||
|
|
||||||
|
# Check if compression was applied (from metadata if available)
|
||||||
|
was_compressed = result.metadata.get("compressed_contexts", 0) > 0
|
||||||
|
|
||||||
|
return AssembledContextResponse(
|
||||||
|
content=result.content,
|
||||||
|
total_tokens=result.total_tokens,
|
||||||
|
context_count=result.context_count,
|
||||||
|
compressed=was_compressed,
|
||||||
|
budget_used_percent=round(budget_used_percent, 2),
|
||||||
|
metadata={
|
||||||
|
"model": request.model,
|
||||||
|
"query": request.query,
|
||||||
|
"knowledge_included": bool(request.knowledge_query),
|
||||||
|
"conversation_turns": len(request.conversation_history or []),
|
||||||
|
"excluded_count": result.excluded_count,
|
||||||
|
"assembly_time_ms": result.assembly_time_ms,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
except AssemblyTimeoutError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
|
||||||
|
detail=f"Context assembly timed out: {e}",
|
||||||
|
) from e
|
||||||
|
except BudgetExceededError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
|
||||||
|
detail=f"Token budget exceeded: {e}",
|
||||||
|
) from e
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Context assembly failed")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Context assembly failed: {e}",
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/count-tokens",
|
||||||
|
response_model=TokenCountResponse,
|
||||||
|
summary="Count Tokens",
|
||||||
|
description="Count tokens in content using the LLM Gateway.",
|
||||||
|
)
|
||||||
|
async def count_tokens(
|
||||||
|
request: TokenCountRequest,
|
||||||
|
engine: ContextEngine = Depends(get_context_engine),
|
||||||
|
) -> TokenCountResponse:
|
||||||
|
"""Count tokens in content."""
|
||||||
|
try:
|
||||||
|
count = await engine.count_tokens(
|
||||||
|
content=request.content,
|
||||||
|
model=request.model,
|
||||||
|
)
|
||||||
|
return TokenCountResponse(
|
||||||
|
token_count=count,
|
||||||
|
model=request.model,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Token counting failed: {e}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Token counting failed: {e}",
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/budget/{model}",
|
||||||
|
response_model=BudgetInfoResponse,
|
||||||
|
summary="Get Token Budget",
|
||||||
|
description="Get token budget allocation for a specific model.",
|
||||||
|
)
|
||||||
|
async def get_budget(
|
||||||
|
model: str,
|
||||||
|
max_tokens: Annotated[int | None, Query(description="Custom max tokens")] = None,
|
||||||
|
engine: ContextEngine = Depends(get_context_engine),
|
||||||
|
) -> BudgetInfoResponse:
|
||||||
|
"""Get token budget information for a model."""
|
||||||
|
budget = await engine.get_budget_for_model(model, max_tokens)
|
||||||
|
return BudgetInfoResponse(
|
||||||
|
model=model,
|
||||||
|
total_tokens=budget.total,
|
||||||
|
system_tokens=budget.system,
|
||||||
|
knowledge_tokens=budget.knowledge,
|
||||||
|
conversation_tokens=budget.conversation,
|
||||||
|
tool_tokens=budget.tools,
|
||||||
|
response_reserve=budget.response_reserve,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/stats",
|
||||||
|
response_model=ContextEngineStatsResponse,
|
||||||
|
summary="Engine Statistics",
|
||||||
|
description="Get context engine statistics and configuration.",
|
||||||
|
)
|
||||||
|
async def get_stats(
|
||||||
|
current_user: User = Depends(require_superuser),
|
||||||
|
engine: ContextEngine = Depends(get_context_engine),
|
||||||
|
) -> ContextEngineStatsResponse:
|
||||||
|
"""Get engine statistics."""
|
||||||
|
stats = await engine.get_stats()
|
||||||
|
return ContextEngineStatsResponse(
|
||||||
|
cache=stats.get("cache", {}),
|
||||||
|
settings=stats.get("settings", {}),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/cache/invalidate",
|
||||||
|
status_code=status.HTTP_204_NO_CONTENT,
|
||||||
|
summary="Invalidate Cache (Admin Only)",
|
||||||
|
description="Invalidate context cache entries.",
|
||||||
|
)
|
||||||
|
async def invalidate_cache(
|
||||||
|
project_id: Annotated[
|
||||||
|
str | None, Query(description="Project to invalidate")
|
||||||
|
] = None,
|
||||||
|
pattern: Annotated[str | None, Query(description="Pattern to match")] = None,
|
||||||
|
current_user: User = Depends(require_superuser),
|
||||||
|
engine: ContextEngine = Depends(get_context_engine),
|
||||||
|
) -> None:
|
||||||
|
"""Invalidate cache entries."""
|
||||||
|
logger.info(
|
||||||
|
"Cache invalidation by user %s: project=%s pattern=%s",
|
||||||
|
current_user.id,
|
||||||
|
project_id,
|
||||||
|
pattern,
|
||||||
|
)
|
||||||
|
await engine.invalidate_cache(project_id=project_id, pattern=pattern)
|
||||||
@@ -90,16 +90,19 @@ class ClaudeAdapter(ModelAdapter):
|
|||||||
elif context_type == ContextType.TOOL:
|
elif context_type == ContextType.TOOL:
|
||||||
return self._format_tool(contexts)
|
return self._format_tool(contexts)
|
||||||
|
|
||||||
return "\n".join(c.content for c in contexts)
|
# Fallback for any unhandled context types - still escape content
|
||||||
|
# to prevent XML injection if new types are added without updating adapter
|
||||||
|
return "\n".join(self._escape_xml_content(c.content) for c in contexts)
|
||||||
|
|
||||||
def _format_system(self, contexts: list[BaseContext]) -> str:
|
def _format_system(self, contexts: list[BaseContext]) -> str:
|
||||||
"""Format system contexts."""
|
"""Format system contexts."""
|
||||||
content = "\n\n".join(c.content for c in contexts)
|
# System prompts are typically admin-controlled, but escape for safety
|
||||||
|
content = "\n\n".join(self._escape_xml_content(c.content) for c in contexts)
|
||||||
return f"<system_instructions>\n{content}\n</system_instructions>"
|
return f"<system_instructions>\n{content}\n</system_instructions>"
|
||||||
|
|
||||||
def _format_task(self, contexts: list[BaseContext]) -> str:
|
def _format_task(self, contexts: list[BaseContext]) -> str:
|
||||||
"""Format task contexts."""
|
"""Format task contexts."""
|
||||||
content = "\n\n".join(c.content for c in contexts)
|
content = "\n\n".join(self._escape_xml_content(c.content) for c in contexts)
|
||||||
return f"<current_task>\n{content}\n</current_task>"
|
return f"<current_task>\n{content}\n</current_task>"
|
||||||
|
|
||||||
def _format_knowledge(self, contexts: list[BaseContext]) -> str:
|
def _format_knowledge(self, contexts: list[BaseContext]) -> str:
|
||||||
@@ -107,16 +110,22 @@ class ClaudeAdapter(ModelAdapter):
|
|||||||
Format knowledge contexts as structured documents.
|
Format knowledge contexts as structured documents.
|
||||||
|
|
||||||
Each knowledge context becomes a document with source attribution.
|
Each knowledge context becomes a document with source attribution.
|
||||||
|
All content is XML-escaped to prevent injection attacks.
|
||||||
"""
|
"""
|
||||||
parts = ["<reference_documents>"]
|
parts = ["<reference_documents>"]
|
||||||
|
|
||||||
for ctx in contexts:
|
for ctx in contexts:
|
||||||
source = self._escape_xml(ctx.source)
|
source = self._escape_xml(ctx.source)
|
||||||
content = ctx.content
|
# Escape content to prevent XML injection
|
||||||
|
content = self._escape_xml_content(ctx.content)
|
||||||
score = ctx.metadata.get("score", ctx.metadata.get("relevance_score", ""))
|
score = ctx.metadata.get("score", ctx.metadata.get("relevance_score", ""))
|
||||||
|
|
||||||
if score:
|
if score:
|
||||||
parts.append(f'<document source="{source}" relevance="{score}">')
|
# Escape score to prevent XML injection via metadata
|
||||||
|
escaped_score = self._escape_xml(str(score))
|
||||||
|
parts.append(
|
||||||
|
f'<document source="{source}" relevance="{escaped_score}">'
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
parts.append(f'<document source="{source}">')
|
parts.append(f'<document source="{source}">')
|
||||||
|
|
||||||
@@ -131,13 +140,16 @@ class ClaudeAdapter(ModelAdapter):
|
|||||||
Format conversation contexts as message history.
|
Format conversation contexts as message history.
|
||||||
|
|
||||||
Uses role-based message tags for clear turn delineation.
|
Uses role-based message tags for clear turn delineation.
|
||||||
|
All content is XML-escaped to prevent prompt injection.
|
||||||
"""
|
"""
|
||||||
parts = ["<conversation_history>"]
|
parts = ["<conversation_history>"]
|
||||||
|
|
||||||
for ctx in contexts:
|
for ctx in contexts:
|
||||||
role = ctx.metadata.get("role", "user")
|
role = self._escape_xml(ctx.metadata.get("role", "user"))
|
||||||
|
# Escape content to prevent prompt injection via fake XML tags
|
||||||
|
content = self._escape_xml_content(ctx.content)
|
||||||
parts.append(f'<message role="{role}">')
|
parts.append(f'<message role="{role}">')
|
||||||
parts.append(ctx.content)
|
parts.append(content)
|
||||||
parts.append("</message>")
|
parts.append("</message>")
|
||||||
|
|
||||||
parts.append("</conversation_history>")
|
parts.append("</conversation_history>")
|
||||||
@@ -148,19 +160,23 @@ class ClaudeAdapter(ModelAdapter):
|
|||||||
Format tool contexts as tool results.
|
Format tool contexts as tool results.
|
||||||
|
|
||||||
Each tool result is wrapped with the tool name.
|
Each tool result is wrapped with the tool name.
|
||||||
|
All content is XML-escaped to prevent injection.
|
||||||
"""
|
"""
|
||||||
parts = ["<tool_results>"]
|
parts = ["<tool_results>"]
|
||||||
|
|
||||||
for ctx in contexts:
|
for ctx in contexts:
|
||||||
tool_name = ctx.metadata.get("tool_name", "unknown")
|
tool_name = self._escape_xml(ctx.metadata.get("tool_name", "unknown"))
|
||||||
status = ctx.metadata.get("status", "")
|
status = ctx.metadata.get("status", "")
|
||||||
|
|
||||||
if status:
|
if status:
|
||||||
parts.append(f'<tool_result name="{tool_name}" status="{status}">')
|
parts.append(
|
||||||
|
f'<tool_result name="{tool_name}" status="{self._escape_xml(status)}">'
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
parts.append(f'<tool_result name="{tool_name}">')
|
parts.append(f'<tool_result name="{tool_name}">')
|
||||||
|
|
||||||
parts.append(ctx.content)
|
# Escape content to prevent injection
|
||||||
|
parts.append(self._escape_xml_content(ctx.content))
|
||||||
parts.append("</tool_result>")
|
parts.append("</tool_result>")
|
||||||
|
|
||||||
parts.append("</tool_results>")
|
parts.append("</tool_results>")
|
||||||
@@ -176,3 +192,21 @@ class ClaudeAdapter(ModelAdapter):
|
|||||||
.replace('"', """)
|
.replace('"', """)
|
||||||
.replace("'", "'")
|
.replace("'", "'")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _escape_xml_content(text: str) -> str:
|
||||||
|
"""
|
||||||
|
Escape XML special characters in element content.
|
||||||
|
|
||||||
|
This prevents XML injection attacks where malicious content
|
||||||
|
could break out of XML tags or inject fake tags for prompt injection.
|
||||||
|
|
||||||
|
Only escapes &, <, > since quotes don't need escaping in content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Content text to escape
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
XML-safe content string
|
||||||
|
"""
|
||||||
|
return text.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from dataclasses import dataclass, field
|
|||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from ..adapters import get_adapter
|
||||||
from ..budget import BudgetAllocator, TokenBudget, TokenCalculator
|
from ..budget import BudgetAllocator, TokenBudget, TokenCalculator
|
||||||
from ..compression.truncation import ContextCompressor
|
from ..compression.truncation import ContextCompressor
|
||||||
from ..config import ContextSettings, get_context_settings
|
from ..config import ContextSettings, get_context_settings
|
||||||
@@ -156,19 +157,41 @@ class ContextPipeline:
|
|||||||
else:
|
else:
|
||||||
budget = self._allocator.create_budget_for_model(model)
|
budget = self._allocator.create_budget_for_model(model)
|
||||||
|
|
||||||
# 1. Count tokens for all contexts
|
# 1. Count tokens for all contexts (with timeout enforcement)
|
||||||
await self._ensure_token_counts(contexts, model)
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
self._ensure_token_counts(contexts, model),
|
||||||
|
timeout=self._remaining_timeout(start, timeout),
|
||||||
|
)
|
||||||
|
except TimeoutError:
|
||||||
|
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||||
|
raise AssemblyTimeoutError(
|
||||||
|
message="Context assembly timed out during token counting",
|
||||||
|
elapsed_ms=elapsed_ms,
|
||||||
|
timeout_ms=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
# Check timeout
|
# Check timeout (handles edge case where operation finished just at limit)
|
||||||
self._check_timeout(start, timeout, "token counting")
|
self._check_timeout(start, timeout, "token counting")
|
||||||
|
|
||||||
# 2. Score and rank contexts
|
# 2. Score and rank contexts (with timeout enforcement)
|
||||||
scoring_start = time.perf_counter()
|
scoring_start = time.perf_counter()
|
||||||
ranking_result = await self._ranker.rank(
|
try:
|
||||||
|
ranking_result = await asyncio.wait_for(
|
||||||
|
self._ranker.rank(
|
||||||
contexts=contexts,
|
contexts=contexts,
|
||||||
query=query,
|
query=query,
|
||||||
budget=budget,
|
budget=budget,
|
||||||
model=model,
|
model=model,
|
||||||
|
),
|
||||||
|
timeout=self._remaining_timeout(start, timeout),
|
||||||
|
)
|
||||||
|
except TimeoutError:
|
||||||
|
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||||
|
raise AssemblyTimeoutError(
|
||||||
|
message="Context assembly timed out during scoring/ranking",
|
||||||
|
elapsed_ms=elapsed_ms,
|
||||||
|
timeout_ms=timeout,
|
||||||
)
|
)
|
||||||
metrics.scoring_time_ms = (time.perf_counter() - scoring_start) * 1000
|
metrics.scoring_time_ms = (time.perf_counter() - scoring_start) * 1000
|
||||||
|
|
||||||
@@ -179,11 +202,22 @@ class ContextPipeline:
|
|||||||
# Check timeout
|
# Check timeout
|
||||||
self._check_timeout(start, timeout, "scoring")
|
self._check_timeout(start, timeout, "scoring")
|
||||||
|
|
||||||
# 3. Compress if needed and enabled
|
# 3. Compress if needed and enabled (with timeout enforcement)
|
||||||
if compress and self._needs_compression(selected_contexts, budget):
|
if compress and self._needs_compression(selected_contexts, budget):
|
||||||
compression_start = time.perf_counter()
|
compression_start = time.perf_counter()
|
||||||
selected_contexts = await self._compressor.compress_contexts(
|
try:
|
||||||
|
selected_contexts = await asyncio.wait_for(
|
||||||
|
self._compressor.compress_contexts(
|
||||||
selected_contexts, budget, model
|
selected_contexts, budget, model
|
||||||
|
),
|
||||||
|
timeout=self._remaining_timeout(start, timeout),
|
||||||
|
)
|
||||||
|
except TimeoutError:
|
||||||
|
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||||
|
raise AssemblyTimeoutError(
|
||||||
|
message="Context assembly timed out during compression",
|
||||||
|
elapsed_ms=elapsed_ms,
|
||||||
|
timeout_ms=timeout,
|
||||||
)
|
)
|
||||||
metrics.compression_time_ms = (
|
metrics.compression_time_ms = (
|
||||||
time.perf_counter() - compression_start
|
time.perf_counter() - compression_start
|
||||||
@@ -280,129 +314,18 @@ class ContextPipeline:
|
|||||||
"""
|
"""
|
||||||
Format contexts for the target model.
|
Format contexts for the target model.
|
||||||
|
|
||||||
Groups contexts by type and applies model-specific formatting.
|
Uses model-specific adapters (ClaudeAdapter, OpenAIAdapter, etc.)
|
||||||
|
to format contexts optimally for each model family.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: Contexts to format
|
||||||
|
model: Target model name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted context string
|
||||||
"""
|
"""
|
||||||
# Group by type
|
adapter = get_adapter(model)
|
||||||
by_type: dict[ContextType, list[BaseContext]] = {}
|
return adapter.format(contexts)
|
||||||
for context in contexts:
|
|
||||||
ct = context.get_type()
|
|
||||||
if ct not in by_type:
|
|
||||||
by_type[ct] = []
|
|
||||||
by_type[ct].append(context)
|
|
||||||
|
|
||||||
# Order types: System -> Task -> Knowledge -> Conversation -> Tool
|
|
||||||
type_order = [
|
|
||||||
ContextType.SYSTEM,
|
|
||||||
ContextType.TASK,
|
|
||||||
ContextType.KNOWLEDGE,
|
|
||||||
ContextType.CONVERSATION,
|
|
||||||
ContextType.TOOL,
|
|
||||||
]
|
|
||||||
|
|
||||||
parts: list[str] = []
|
|
||||||
for ct in type_order:
|
|
||||||
if ct in by_type:
|
|
||||||
formatted = self._format_type(by_type[ct], ct, model)
|
|
||||||
if formatted:
|
|
||||||
parts.append(formatted)
|
|
||||||
|
|
||||||
return "\n\n".join(parts)
|
|
||||||
|
|
||||||
def _format_type(
|
|
||||||
self,
|
|
||||||
contexts: list[BaseContext],
|
|
||||||
context_type: ContextType,
|
|
||||||
model: str,
|
|
||||||
) -> str:
|
|
||||||
"""Format contexts of a specific type."""
|
|
||||||
if not contexts:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
# Check if model prefers XML tags (Claude)
|
|
||||||
use_xml = "claude" in model.lower()
|
|
||||||
|
|
||||||
if context_type == ContextType.SYSTEM:
|
|
||||||
return self._format_system(contexts, use_xml)
|
|
||||||
elif context_type == ContextType.TASK:
|
|
||||||
return self._format_task(contexts, use_xml)
|
|
||||||
elif context_type == ContextType.KNOWLEDGE:
|
|
||||||
return self._format_knowledge(contexts, use_xml)
|
|
||||||
elif context_type == ContextType.CONVERSATION:
|
|
||||||
return self._format_conversation(contexts, use_xml)
|
|
||||||
elif context_type == ContextType.TOOL:
|
|
||||||
return self._format_tool(contexts, use_xml)
|
|
||||||
|
|
||||||
return "\n".join(c.content for c in contexts)
|
|
||||||
|
|
||||||
def _format_system(self, contexts: list[BaseContext], use_xml: bool) -> str:
|
|
||||||
"""Format system contexts."""
|
|
||||||
content = "\n\n".join(c.content for c in contexts)
|
|
||||||
if use_xml:
|
|
||||||
return f"<system_instructions>\n{content}\n</system_instructions>"
|
|
||||||
return content
|
|
||||||
|
|
||||||
def _format_task(self, contexts: list[BaseContext], use_xml: bool) -> str:
|
|
||||||
"""Format task contexts."""
|
|
||||||
content = "\n\n".join(c.content for c in contexts)
|
|
||||||
if use_xml:
|
|
||||||
return f"<current_task>\n{content}\n</current_task>"
|
|
||||||
return f"## Current Task\n\n{content}"
|
|
||||||
|
|
||||||
def _format_knowledge(self, contexts: list[BaseContext], use_xml: bool) -> str:
|
|
||||||
"""Format knowledge contexts."""
|
|
||||||
if use_xml:
|
|
||||||
parts = ["<reference_documents>"]
|
|
||||||
for ctx in contexts:
|
|
||||||
parts.append(f'<document source="{ctx.source}">')
|
|
||||||
parts.append(ctx.content)
|
|
||||||
parts.append("</document>")
|
|
||||||
parts.append("</reference_documents>")
|
|
||||||
return "\n".join(parts)
|
|
||||||
else:
|
|
||||||
parts = ["## Reference Documents\n"]
|
|
||||||
for ctx in contexts:
|
|
||||||
parts.append(f"### Source: {ctx.source}\n")
|
|
||||||
parts.append(ctx.content)
|
|
||||||
parts.append("")
|
|
||||||
return "\n".join(parts)
|
|
||||||
|
|
||||||
def _format_conversation(self, contexts: list[BaseContext], use_xml: bool) -> str:
|
|
||||||
"""Format conversation contexts."""
|
|
||||||
if use_xml:
|
|
||||||
parts = ["<conversation_history>"]
|
|
||||||
for ctx in contexts:
|
|
||||||
role = ctx.metadata.get("role", "user")
|
|
||||||
parts.append(f'<message role="{role}">')
|
|
||||||
parts.append(ctx.content)
|
|
||||||
parts.append("</message>")
|
|
||||||
parts.append("</conversation_history>")
|
|
||||||
return "\n".join(parts)
|
|
||||||
else:
|
|
||||||
parts = []
|
|
||||||
for ctx in contexts:
|
|
||||||
role = ctx.metadata.get("role", "user")
|
|
||||||
parts.append(f"**{role.upper()}**: {ctx.content}")
|
|
||||||
return "\n\n".join(parts)
|
|
||||||
|
|
||||||
def _format_tool(self, contexts: list[BaseContext], use_xml: bool) -> str:
|
|
||||||
"""Format tool contexts."""
|
|
||||||
if use_xml:
|
|
||||||
parts = ["<tool_results>"]
|
|
||||||
for ctx in contexts:
|
|
||||||
tool_name = ctx.metadata.get("tool_name", "unknown")
|
|
||||||
parts.append(f'<tool_result name="{tool_name}">')
|
|
||||||
parts.append(ctx.content)
|
|
||||||
parts.append("</tool_result>")
|
|
||||||
parts.append("</tool_results>")
|
|
||||||
return "\n".join(parts)
|
|
||||||
else:
|
|
||||||
parts = ["## Recent Tool Results\n"]
|
|
||||||
for ctx in contexts:
|
|
||||||
tool_name = ctx.metadata.get("tool_name", "unknown")
|
|
||||||
parts.append(f"### Tool: {tool_name}\n")
|
|
||||||
parts.append(f"```\n{ctx.content}\n```")
|
|
||||||
parts.append("")
|
|
||||||
return "\n".join(parts)
|
|
||||||
|
|
||||||
def _check_timeout(
|
def _check_timeout(
|
||||||
self,
|
self,
|
||||||
@@ -412,9 +335,28 @@ class ContextPipeline:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Check if timeout exceeded and raise if so."""
|
"""Check if timeout exceeded and raise if so."""
|
||||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||||
if elapsed_ms > timeout_ms:
|
if elapsed_ms >= timeout_ms:
|
||||||
raise AssemblyTimeoutError(
|
raise AssemblyTimeoutError(
|
||||||
message=f"Context assembly timed out during {phase}",
|
message=f"Context assembly timed out during {phase}",
|
||||||
elapsed_ms=elapsed_ms,
|
elapsed_ms=elapsed_ms,
|
||||||
timeout_ms=timeout_ms,
|
timeout_ms=timeout_ms,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _remaining_timeout(self, start: float, timeout_ms: int) -> float:
|
||||||
|
"""
|
||||||
|
Calculate remaining timeout in seconds for asyncio.wait_for.
|
||||||
|
|
||||||
|
Returns at least a small positive value to avoid immediate timeout
|
||||||
|
edge cases with wait_for.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start: Start time from time.perf_counter()
|
||||||
|
timeout_ms: Total timeout in milliseconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Remaining timeout in seconds (minimum 0.001)
|
||||||
|
"""
|
||||||
|
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||||
|
remaining_ms = timeout_ms - elapsed_ms
|
||||||
|
# Return at least 1ms to avoid zero/negative timeout edge cases
|
||||||
|
return max(remaining_ms / 1000.0, 0.001)
|
||||||
|
|||||||
@@ -293,14 +293,18 @@ class BudgetAllocator:
|
|||||||
if isinstance(context_type, ContextType):
|
if isinstance(context_type, ContextType):
|
||||||
context_type = context_type.value
|
context_type = context_type.value
|
||||||
|
|
||||||
# Calculate adjustment (limited by buffer)
|
# Calculate adjustment (limited by buffer for increases, by current allocation for decreases)
|
||||||
if adjustment > 0:
|
if adjustment > 0:
|
||||||
# Taking from buffer
|
# Taking from buffer - limited by available buffer
|
||||||
actual_adjustment = min(adjustment, budget.buffer)
|
actual_adjustment = min(adjustment, budget.buffer)
|
||||||
budget.buffer -= actual_adjustment
|
budget.buffer -= actual_adjustment
|
||||||
else:
|
else:
|
||||||
# Returning to buffer
|
# Returning to buffer - limited by current allocation of target type
|
||||||
actual_adjustment = adjustment
|
current_allocation = budget.get_allocation(context_type)
|
||||||
|
# Can't return more than current allocation
|
||||||
|
actual_adjustment = max(adjustment, -current_allocation)
|
||||||
|
# Add returned tokens back to buffer (adjustment is negative, so subtract)
|
||||||
|
budget.buffer -= actual_adjustment
|
||||||
|
|
||||||
# Apply to target type
|
# Apply to target type
|
||||||
if context_type == "system":
|
if context_type == "system":
|
||||||
|
|||||||
@@ -95,19 +95,28 @@ class ContextCache:
|
|||||||
contexts: list[BaseContext],
|
contexts: list[BaseContext],
|
||||||
query: str,
|
query: str,
|
||||||
model: str,
|
model: str,
|
||||||
|
project_id: str | None = None,
|
||||||
|
agent_id: str | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Compute a fingerprint for a context assembly request.
|
Compute a fingerprint for a context assembly request.
|
||||||
|
|
||||||
The fingerprint is based on:
|
The fingerprint is based on:
|
||||||
|
- Project and agent IDs (for tenant isolation)
|
||||||
- Context content hash and metadata (not full content for performance)
|
- Context content hash and metadata (not full content for performance)
|
||||||
- Query string
|
- Query string
|
||||||
- Target model
|
- Target model
|
||||||
|
|
||||||
|
SECURITY: project_id and agent_id MUST be included to prevent
|
||||||
|
cross-tenant cache pollution. Without these, one tenant could
|
||||||
|
receive cached contexts from another tenant with the same query.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
contexts: List of contexts
|
contexts: List of contexts
|
||||||
query: Query string
|
query: Query string
|
||||||
model: Model name
|
model: Model name
|
||||||
|
project_id: Project ID for tenant isolation
|
||||||
|
agent_id: Agent ID for tenant isolation
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
32-character hex fingerprint
|
32-character hex fingerprint
|
||||||
@@ -128,6 +137,9 @@ class ContextCache:
|
|||||||
)
|
)
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
|
# CRITICAL: Include tenant identifiers for cache isolation
|
||||||
|
"project_id": project_id or "",
|
||||||
|
"agent_id": agent_id or "",
|
||||||
"contexts": context_data,
|
"contexts": context_data,
|
||||||
"query": query,
|
"query": query,
|
||||||
"model": model,
|
"model": model,
|
||||||
|
|||||||
@@ -19,6 +19,40 @@ if TYPE_CHECKING:
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _estimate_tokens(text: str, model: str | None = None) -> int:
|
||||||
|
"""
|
||||||
|
Estimate token count using model-specific character ratios.
|
||||||
|
|
||||||
|
Module-level function for reuse across classes. Uses the same ratios
|
||||||
|
as TokenCalculator for consistency.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to estimate tokens for
|
||||||
|
model: Optional model name for model-specific ratios
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Estimated token count (minimum 1)
|
||||||
|
"""
|
||||||
|
# Model-specific character ratios (chars per token)
|
||||||
|
model_ratios = {
|
||||||
|
"claude": 3.5,
|
||||||
|
"gpt-4": 4.0,
|
||||||
|
"gpt-3.5": 4.0,
|
||||||
|
"gemini": 4.0,
|
||||||
|
}
|
||||||
|
default_ratio = 4.0
|
||||||
|
|
||||||
|
ratio = default_ratio
|
||||||
|
if model:
|
||||||
|
model_lower = model.lower()
|
||||||
|
for model_prefix, model_ratio in model_ratios.items():
|
||||||
|
if model_prefix in model_lower:
|
||||||
|
ratio = model_ratio
|
||||||
|
break
|
||||||
|
|
||||||
|
return max(1, int(len(text) / ratio))
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TruncationResult:
|
class TruncationResult:
|
||||||
"""Result of truncation operation."""
|
"""Result of truncation operation."""
|
||||||
@@ -284,8 +318,8 @@ class TruncationStrategy:
|
|||||||
if self._calculator is not None:
|
if self._calculator is not None:
|
||||||
return await self._calculator.count_tokens(text, model)
|
return await self._calculator.count_tokens(text, model)
|
||||||
|
|
||||||
# Fallback estimation
|
# Fallback estimation with model-specific ratios
|
||||||
return max(1, len(text) // 4)
|
return _estimate_tokens(text, model)
|
||||||
|
|
||||||
|
|
||||||
class ContextCompressor:
|
class ContextCompressor:
|
||||||
@@ -415,4 +449,5 @@ class ContextCompressor:
|
|||||||
"""Count tokens using calculator or estimation."""
|
"""Count tokens using calculator or estimation."""
|
||||||
if self._calculator is not None:
|
if self._calculator is not None:
|
||||||
return await self._calculator.count_tokens(text, model)
|
return await self._calculator.count_tokens(text, model)
|
||||||
return max(1, len(text) // 4)
|
# Use model-specific estimation for consistency
|
||||||
|
return _estimate_tokens(text, model)
|
||||||
|
|||||||
@@ -149,10 +149,11 @@ class ContextSettings(BaseSettings):
|
|||||||
|
|
||||||
# Performance settings
|
# Performance settings
|
||||||
max_assembly_time_ms: int = Field(
|
max_assembly_time_ms: int = Field(
|
||||||
default=100,
|
default=2000,
|
||||||
ge=10,
|
ge=10,
|
||||||
le=5000,
|
le=30000,
|
||||||
description="Maximum time for context assembly in milliseconds",
|
description="Maximum time for context assembly in milliseconds. "
|
||||||
|
"Should be high enough to accommodate MCP calls for knowledge retrieval.",
|
||||||
)
|
)
|
||||||
parallel_scoring: bool = Field(
|
parallel_scoring: bool = Field(
|
||||||
default=True,
|
default=True,
|
||||||
|
|||||||
@@ -212,7 +212,10 @@ class ContextEngine:
|
|||||||
# Check cache if enabled
|
# Check cache if enabled
|
||||||
fingerprint: str | None = None
|
fingerprint: str | None = None
|
||||||
if use_cache and self._cache.is_enabled:
|
if use_cache and self._cache.is_enabled:
|
||||||
fingerprint = self._cache.compute_fingerprint(contexts, query, model)
|
# Include project_id and agent_id for tenant isolation
|
||||||
|
fingerprint = self._cache.compute_fingerprint(
|
||||||
|
contexts, query, model, project_id=project_id, agent_id=agent_id
|
||||||
|
)
|
||||||
cached = await self._cache.get_assembled(fingerprint)
|
cached = await self._cache.get_assembled(fingerprint)
|
||||||
if cached:
|
if cached:
|
||||||
logger.debug(f"Cache hit for context assembly: {fingerprint}")
|
logger.debug(f"Cache hit for context assembly: {fingerprint}")
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any
|
|||||||
|
|
||||||
from ..budget import TokenBudget, TokenCalculator
|
from ..budget import TokenBudget, TokenCalculator
|
||||||
from ..config import ContextSettings, get_context_settings
|
from ..config import ContextSettings, get_context_settings
|
||||||
|
from ..exceptions import BudgetExceededError
|
||||||
from ..scoring.composite import CompositeScorer, ScoredContext
|
from ..scoring.composite import CompositeScorer, ScoredContext
|
||||||
from ..types import BaseContext, ContextPriority
|
from ..types import BaseContext, ContextPriority
|
||||||
|
|
||||||
@@ -127,9 +128,25 @@ class ContextRanker:
|
|||||||
excluded: list[ScoredContext] = []
|
excluded: list[ScoredContext] = []
|
||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
|
|
||||||
|
# Calculate the usable budget (total minus reserved portions)
|
||||||
|
usable_budget = budget.total - budget.response_reserve - budget.buffer
|
||||||
|
|
||||||
|
# Guard against invalid budget configuration
|
||||||
|
if usable_budget <= 0:
|
||||||
|
raise BudgetExceededError(
|
||||||
|
message=(
|
||||||
|
f"Invalid budget configuration: no usable tokens available. "
|
||||||
|
f"total={budget.total}, response_reserve={budget.response_reserve}, "
|
||||||
|
f"buffer={budget.buffer}"
|
||||||
|
),
|
||||||
|
allocated=budget.total,
|
||||||
|
requested=0,
|
||||||
|
context_type="CONFIGURATION_ERROR",
|
||||||
|
)
|
||||||
|
|
||||||
# First, try to fit required contexts
|
# First, try to fit required contexts
|
||||||
for sc in required:
|
for sc in required:
|
||||||
token_count = sc.context.token_count or 0
|
token_count = self._get_valid_token_count(sc.context)
|
||||||
context_type = sc.context.get_type()
|
context_type = sc.context.get_type()
|
||||||
|
|
||||||
if budget.can_fit(context_type, token_count):
|
if budget.can_fit(context_type, token_count):
|
||||||
@@ -137,7 +154,20 @@ class ContextRanker:
|
|||||||
selected.append(sc)
|
selected.append(sc)
|
||||||
total_tokens += token_count
|
total_tokens += token_count
|
||||||
else:
|
else:
|
||||||
# Force-fit CRITICAL contexts if needed
|
# Force-fit CRITICAL contexts if needed, but check total budget first
|
||||||
|
if total_tokens + token_count > usable_budget:
|
||||||
|
# Even CRITICAL contexts cannot exceed total model context window
|
||||||
|
raise BudgetExceededError(
|
||||||
|
message=(
|
||||||
|
f"CRITICAL contexts exceed total budget. "
|
||||||
|
f"Context '{sc.context.source}' ({token_count} tokens) "
|
||||||
|
f"would exceed usable budget of {usable_budget} tokens."
|
||||||
|
),
|
||||||
|
allocated=usable_budget,
|
||||||
|
requested=total_tokens + token_count,
|
||||||
|
context_type="CRITICAL_OVERFLOW",
|
||||||
|
)
|
||||||
|
|
||||||
budget.allocate(context_type, token_count, force=True)
|
budget.allocate(context_type, token_count, force=True)
|
||||||
selected.append(sc)
|
selected.append(sc)
|
||||||
total_tokens += token_count
|
total_tokens += token_count
|
||||||
@@ -148,7 +178,7 @@ class ContextRanker:
|
|||||||
|
|
||||||
# Then, greedily add optional contexts
|
# Then, greedily add optional contexts
|
||||||
for sc in optional:
|
for sc in optional:
|
||||||
token_count = sc.context.token_count or 0
|
token_count = self._get_valid_token_count(sc.context)
|
||||||
context_type = sc.context.get_type()
|
context_type = sc.context.get_type()
|
||||||
|
|
||||||
if budget.can_fit(context_type, token_count):
|
if budget.can_fit(context_type, token_count):
|
||||||
@@ -215,13 +245,43 @@ class ContextRanker:
|
|||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
|
|
||||||
for sc in scored_contexts:
|
for sc in scored_contexts:
|
||||||
token_count = sc.context.token_count or 0
|
token_count = self._get_valid_token_count(sc.context)
|
||||||
if total_tokens + token_count <= max_tokens:
|
if total_tokens + token_count <= max_tokens:
|
||||||
selected.append(sc.context)
|
selected.append(sc.context)
|
||||||
total_tokens += token_count
|
total_tokens += token_count
|
||||||
|
|
||||||
return selected
|
return selected
|
||||||
|
|
||||||
|
def _get_valid_token_count(self, context: BaseContext) -> int:
|
||||||
|
"""
|
||||||
|
Get validated token count from a context.
|
||||||
|
|
||||||
|
Ensures token_count is set (not None) and non-negative to prevent
|
||||||
|
budget bypass attacks where:
|
||||||
|
- None would be treated as 0 (allowing huge contexts to slip through)
|
||||||
|
- Negative values would corrupt budget tracking
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: Context to get token count from
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Valid non-negative token count
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If token_count is None or negative
|
||||||
|
"""
|
||||||
|
if context.token_count is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Context '{context.source}' has no token count. "
|
||||||
|
"Ensure _ensure_token_counts() is called before ranking."
|
||||||
|
)
|
||||||
|
if context.token_count < 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Context '{context.source}' has invalid negative token count: "
|
||||||
|
f"{context.token_count}"
|
||||||
|
)
|
||||||
|
return context.token_count
|
||||||
|
|
||||||
async def _ensure_token_counts(
|
async def _ensure_token_counts(
|
||||||
self,
|
self,
|
||||||
contexts: list[BaseContext],
|
contexts: list[BaseContext],
|
||||||
@@ -266,6 +326,7 @@ class ContextRanker:
|
|||||||
if type_name not in by_type:
|
if type_name not in by_type:
|
||||||
by_type[type_name] = {"count": 0, "tokens": 0}
|
by_type[type_name] = {"count": 0, "tokens": 0}
|
||||||
by_type[type_name]["count"] += 1
|
by_type[type_name]["count"] += 1
|
||||||
|
# Use validated token count (already validated during ranking)
|
||||||
by_type[type_name]["tokens"] += sc.context.token_count or 0
|
by_type[type_name]["tokens"] += sc.context.token_count or 0
|
||||||
|
|
||||||
return by_type
|
return by_type
|
||||||
|
|||||||
@@ -6,9 +6,9 @@ Combines multiple scoring strategies with configurable weights.
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
from weakref import WeakValueDictionary
|
|
||||||
|
|
||||||
from ..config import ContextSettings, get_context_settings
|
from ..config import ContextSettings, get_context_settings
|
||||||
from ..types import BaseContext
|
from ..types import BaseContext
|
||||||
@@ -91,11 +91,11 @@ class CompositeScorer:
|
|||||||
self._priority_scorer = PriorityScorer(weight=self._priority_weight)
|
self._priority_scorer = PriorityScorer(weight=self._priority_weight)
|
||||||
|
|
||||||
# Per-context locks to prevent race conditions during parallel scoring
|
# Per-context locks to prevent race conditions during parallel scoring
|
||||||
# Uses WeakValueDictionary so locks are garbage collected when not in use
|
# Uses dict with (lock, last_used_time) tuples for cleanup
|
||||||
self._context_locks: WeakValueDictionary[str, asyncio.Lock] = (
|
self._context_locks: dict[str, tuple[asyncio.Lock, float]] = {}
|
||||||
WeakValueDictionary()
|
|
||||||
)
|
|
||||||
self._locks_lock = asyncio.Lock() # Lock to protect _context_locks access
|
self._locks_lock = asyncio.Lock() # Lock to protect _context_locks access
|
||||||
|
self._max_locks = 1000 # Maximum locks to keep (prevent memory growth)
|
||||||
|
self._lock_ttl = 60.0 # Seconds before a lock can be cleaned up
|
||||||
|
|
||||||
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
|
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
|
||||||
"""Set MCP manager for semantic scoring."""
|
"""Set MCP manager for semantic scoring."""
|
||||||
@@ -141,7 +141,8 @@ class CompositeScorer:
|
|||||||
Get or create a lock for a specific context.
|
Get or create a lock for a specific context.
|
||||||
|
|
||||||
Thread-safe access to per-context locks prevents race conditions
|
Thread-safe access to per-context locks prevents race conditions
|
||||||
when the same context is scored concurrently.
|
when the same context is scored concurrently. Includes automatic
|
||||||
|
cleanup of old locks to prevent memory growth.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
context_id: The context ID to get a lock for
|
context_id: The context ID to get a lock for
|
||||||
@@ -149,25 +150,78 @@ class CompositeScorer:
|
|||||||
Returns:
|
Returns:
|
||||||
asyncio.Lock for the context
|
asyncio.Lock for the context
|
||||||
"""
|
"""
|
||||||
|
now = time.time()
|
||||||
|
|
||||||
# Fast path: check if lock exists without acquiring main lock
|
# Fast path: check if lock exists without acquiring main lock
|
||||||
if context_id in self._context_locks:
|
# NOTE: We only READ here - no writes to avoid race conditions
|
||||||
lock = self._context_locks.get(context_id)
|
# with cleanup. The timestamp will be updated in the slow path
|
||||||
if lock is not None:
|
# if the lock is still valid.
|
||||||
|
lock_entry = self._context_locks.get(context_id)
|
||||||
|
if lock_entry is not None:
|
||||||
|
lock, _ = lock_entry
|
||||||
|
# Return the lock but defer timestamp update to avoid race
|
||||||
|
# The lock is still valid; timestamp update is best-effort
|
||||||
return lock
|
return lock
|
||||||
|
|
||||||
# Slow path: create lock while holding main lock
|
# Slow path: create lock or update timestamp while holding main lock
|
||||||
async with self._locks_lock:
|
async with self._locks_lock:
|
||||||
# Double-check after acquiring lock
|
# Double-check after acquiring lock - entry may have been
|
||||||
if context_id in self._context_locks:
|
# created by another coroutine or deleted by cleanup
|
||||||
lock = self._context_locks.get(context_id)
|
lock_entry = self._context_locks.get(context_id)
|
||||||
if lock is not None:
|
if lock_entry is not None:
|
||||||
|
lock, _ = lock_entry
|
||||||
|
# Safe to update timestamp here since we hold the lock
|
||||||
|
self._context_locks[context_id] = (lock, now)
|
||||||
return lock
|
return lock
|
||||||
|
|
||||||
|
# Cleanup old locks if we have too many
|
||||||
|
if len(self._context_locks) >= self._max_locks:
|
||||||
|
self._cleanup_old_locks(now)
|
||||||
|
|
||||||
# Create new lock
|
# Create new lock
|
||||||
new_lock = asyncio.Lock()
|
new_lock = asyncio.Lock()
|
||||||
self._context_locks[context_id] = new_lock
|
self._context_locks[context_id] = (new_lock, now)
|
||||||
return new_lock
|
return new_lock
|
||||||
|
|
||||||
|
def _cleanup_old_locks(self, now: float) -> None:
|
||||||
|
"""
|
||||||
|
Remove old locks that haven't been used recently.
|
||||||
|
|
||||||
|
Called while holding _locks_lock. Removes locks older than _lock_ttl,
|
||||||
|
but only if they're not currently held.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
now: Current timestamp for age calculation
|
||||||
|
"""
|
||||||
|
cutoff = now - self._lock_ttl
|
||||||
|
to_remove = []
|
||||||
|
|
||||||
|
for context_id, (lock, last_used) in self._context_locks.items():
|
||||||
|
# Only remove if old AND not currently held
|
||||||
|
if last_used < cutoff and not lock.locked():
|
||||||
|
to_remove.append(context_id)
|
||||||
|
|
||||||
|
# Remove oldest 50% if still over limit after TTL filtering
|
||||||
|
if len(self._context_locks) - len(to_remove) >= self._max_locks:
|
||||||
|
# Sort by last used time and mark oldest for removal
|
||||||
|
sorted_entries = sorted(
|
||||||
|
self._context_locks.items(),
|
||||||
|
key=lambda x: x[1][1], # Sort by last_used time
|
||||||
|
)
|
||||||
|
# Remove oldest 50% that aren't locked
|
||||||
|
target_remove = len(self._context_locks) // 2
|
||||||
|
for context_id, (lock, _) in sorted_entries:
|
||||||
|
if len(to_remove) >= target_remove:
|
||||||
|
break
|
||||||
|
if context_id not in to_remove and not lock.locked():
|
||||||
|
to_remove.append(context_id)
|
||||||
|
|
||||||
|
for context_id in to_remove:
|
||||||
|
del self._context_locks[context_id]
|
||||||
|
|
||||||
|
if to_remove:
|
||||||
|
logger.debug(f"Cleaned up {len(to_remove)} context locks")
|
||||||
|
|
||||||
async def score(
|
async def score(
|
||||||
self,
|
self,
|
||||||
context: BaseContext,
|
context: BaseContext,
|
||||||
|
|||||||
@@ -24,6 +24,9 @@ from ..models import (
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Sentinel for distinguishing "no argument passed" from "explicitly passing None"
|
||||||
|
_UNSET = object()
|
||||||
|
|
||||||
|
|
||||||
class AuditLogger:
|
class AuditLogger:
|
||||||
"""
|
"""
|
||||||
@@ -142,8 +145,10 @@ class AuditLogger:
|
|||||||
# Add hash chain for tamper detection
|
# Add hash chain for tamper detection
|
||||||
if self._enable_hash_chain:
|
if self._enable_hash_chain:
|
||||||
event_hash = self._compute_hash(event)
|
event_hash = self._compute_hash(event)
|
||||||
sanitized_details["_hash"] = event_hash
|
# Modify event.details directly (not sanitized_details)
|
||||||
sanitized_details["_prev_hash"] = self._last_hash
|
# to ensure the hash is stored on the actual event
|
||||||
|
event.details["_hash"] = event_hash
|
||||||
|
event.details["_prev_hash"] = self._last_hash
|
||||||
self._last_hash = event_hash
|
self._last_hash = event_hash
|
||||||
|
|
||||||
self._buffer.append(event)
|
self._buffer.append(event)
|
||||||
@@ -415,7 +420,8 @@ class AuditLogger:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if stored_hash:
|
if stored_hash:
|
||||||
computed = self._compute_hash(event)
|
# Pass prev_hash to compute hash with correct chain position
|
||||||
|
computed = self._compute_hash(event, prev_hash=prev_hash)
|
||||||
if computed != stored_hash:
|
if computed != stored_hash:
|
||||||
issues.append(
|
issues.append(
|
||||||
f"Hash mismatch at event {event.id}: "
|
f"Hash mismatch at event {event.id}: "
|
||||||
@@ -462,9 +468,23 @@ class AuditLogger:
|
|||||||
|
|
||||||
return sanitized
|
return sanitized
|
||||||
|
|
||||||
def _compute_hash(self, event: AuditEvent) -> str:
|
def _compute_hash(
|
||||||
"""Compute hash for an event (excluding hash fields)."""
|
self, event: AuditEvent, prev_hash: str | None | object = _UNSET
|
||||||
data = {
|
) -> str:
|
||||||
|
"""Compute hash for an event (excluding hash fields).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event: The audit event to hash.
|
||||||
|
prev_hash: Optional previous hash to use instead of self._last_hash.
|
||||||
|
Pass this during verification to use the correct chain.
|
||||||
|
Use None explicitly to indicate no previous hash.
|
||||||
|
"""
|
||||||
|
# Use passed prev_hash if explicitly provided, otherwise use instance state
|
||||||
|
effective_prev: str | None = (
|
||||||
|
self._last_hash if prev_hash is _UNSET else prev_hash # type: ignore[assignment]
|
||||||
|
)
|
||||||
|
|
||||||
|
data: dict[str, str | dict[str, str] | None] = {
|
||||||
"id": event.id,
|
"id": event.id,
|
||||||
"event_type": event.event_type.value,
|
"event_type": event.event_type.value,
|
||||||
"timestamp": event.timestamp.isoformat(),
|
"timestamp": event.timestamp.isoformat(),
|
||||||
@@ -480,8 +500,8 @@ class AuditLogger:
|
|||||||
"correlation_id": event.correlation_id,
|
"correlation_id": event.correlation_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
if self._last_hash:
|
if effective_prev:
|
||||||
data["_prev_hash"] = self._last_hash
|
data["_prev_hash"] = effective_prev
|
||||||
|
|
||||||
serialized = json.dumps(data, sort_keys=True, default=str)
|
serialized = json.dumps(data, sort_keys=True, default=str)
|
||||||
return hashlib.sha256(serialized.encode()).hexdigest()
|
return hashlib.sha256(serialized.encode()).hexdigest()
|
||||||
|
|||||||
466
backend/tests/api/routes/test_context.py
Normal file
466
backend/tests/api/routes/test_context.py
Normal file
@@ -0,0 +1,466 @@
|
|||||||
|
"""
|
||||||
|
Tests for Context Management API Routes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import status
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from app.main import app
|
||||||
|
from app.models.user import User
|
||||||
|
from app.services.context import (
|
||||||
|
AssembledContext,
|
||||||
|
AssemblyTimeoutError,
|
||||||
|
BudgetExceededError,
|
||||||
|
ContextEngine,
|
||||||
|
TokenBudget,
|
||||||
|
)
|
||||||
|
from app.services.mcp import MCPClientManager
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_mcp_client():
|
||||||
|
"""Create a mock MCP client manager."""
|
||||||
|
client = MagicMock(spec=MCPClientManager)
|
||||||
|
client.is_initialized = True
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_context_engine(mock_mcp_client):
|
||||||
|
"""Create a mock ContextEngine."""
|
||||||
|
engine = MagicMock(spec=ContextEngine)
|
||||||
|
engine._mcp = mock_mcp_client
|
||||||
|
return engine
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_superuser():
|
||||||
|
"""Create a mock superuser."""
|
||||||
|
user = MagicMock(spec=User)
|
||||||
|
user.id = "00000000-0000-0000-0000-000000000001"
|
||||||
|
user.is_superuser = True
|
||||||
|
user.email = "admin@example.com"
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(mock_mcp_client, mock_context_engine, mock_superuser):
|
||||||
|
"""Create a FastAPI test client with mocked dependencies."""
|
||||||
|
from app.api.dependencies.permissions import require_superuser
|
||||||
|
from app.api.routes.context import get_context_engine
|
||||||
|
from app.services.mcp import get_mcp_client
|
||||||
|
|
||||||
|
# Override dependencies
|
||||||
|
async def override_get_mcp_client():
|
||||||
|
return mock_mcp_client
|
||||||
|
|
||||||
|
async def override_get_context_engine():
|
||||||
|
return mock_context_engine
|
||||||
|
|
||||||
|
async def override_require_superuser():
|
||||||
|
return mock_superuser
|
||||||
|
|
||||||
|
app.dependency_overrides[get_mcp_client] = override_get_mcp_client
|
||||||
|
app.dependency_overrides[get_context_engine] = override_get_context_engine
|
||||||
|
app.dependency_overrides[require_superuser] = override_require_superuser
|
||||||
|
|
||||||
|
with patch("app.main.check_database_health", return_value=True):
|
||||||
|
yield TestClient(app)
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextHealth:
|
||||||
|
"""Tests for GET /context/health endpoint."""
|
||||||
|
|
||||||
|
def test_health_check_success(self, client, mock_context_engine, mock_mcp_client):
|
||||||
|
"""Test context engine health check."""
|
||||||
|
mock_context_engine.get_stats = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"cache": {"hits": 10, "misses": 5},
|
||||||
|
"settings": {"cache_enabled": True},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.get("/api/v1/context/health")
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["status"] == "healthy"
|
||||||
|
assert "mcp_connected" in data
|
||||||
|
assert "cache_enabled" in data
|
||||||
|
|
||||||
|
|
||||||
|
class TestAssembleContext:
|
||||||
|
"""Tests for POST /context/assemble endpoint."""
|
||||||
|
|
||||||
|
def test_assemble_context_success(self, client, mock_context_engine):
|
||||||
|
"""Test successful context assembly."""
|
||||||
|
# Create mock assembled context
|
||||||
|
mock_result = MagicMock(spec=AssembledContext)
|
||||||
|
mock_result.content = "Assembled context content"
|
||||||
|
mock_result.total_tokens = 500
|
||||||
|
mock_result.context_count = 2
|
||||||
|
mock_result.excluded_count = 0
|
||||||
|
mock_result.assembly_time_ms = 50.5
|
||||||
|
mock_result.metadata = {}
|
||||||
|
|
||||||
|
mock_context_engine.assemble_context = AsyncMock(return_value=mock_result)
|
||||||
|
mock_context_engine.get_budget_for_model = AsyncMock(
|
||||||
|
return_value=TokenBudget(
|
||||||
|
total=4000,
|
||||||
|
system=500,
|
||||||
|
knowledge=1500,
|
||||||
|
conversation=1000,
|
||||||
|
tools=500,
|
||||||
|
response_reserve=500,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/context/assemble",
|
||||||
|
json={
|
||||||
|
"project_id": "test-project",
|
||||||
|
"agent_id": "test-agent",
|
||||||
|
"query": "What is the auth flow?",
|
||||||
|
"model": "claude-3-sonnet",
|
||||||
|
"system_prompt": "You are a helpful assistant.",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["content"] == "Assembled context content"
|
||||||
|
assert data["total_tokens"] == 500
|
||||||
|
assert data["context_count"] == 2
|
||||||
|
assert data["compressed"] is False
|
||||||
|
assert "budget_used_percent" in data
|
||||||
|
|
||||||
|
def test_assemble_context_with_conversation(self, client, mock_context_engine):
|
||||||
|
"""Test context assembly with conversation history."""
|
||||||
|
mock_result = MagicMock(spec=AssembledContext)
|
||||||
|
mock_result.content = "Context with history"
|
||||||
|
mock_result.total_tokens = 800
|
||||||
|
mock_result.context_count = 1
|
||||||
|
mock_result.excluded_count = 0
|
||||||
|
mock_result.assembly_time_ms = 30.0
|
||||||
|
mock_result.metadata = {}
|
||||||
|
|
||||||
|
mock_context_engine.assemble_context = AsyncMock(return_value=mock_result)
|
||||||
|
mock_context_engine.get_budget_for_model = AsyncMock(
|
||||||
|
return_value=TokenBudget(
|
||||||
|
total=4000,
|
||||||
|
system=500,
|
||||||
|
knowledge=1500,
|
||||||
|
conversation=1000,
|
||||||
|
tools=500,
|
||||||
|
response_reserve=500,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/context/assemble",
|
||||||
|
json={
|
||||||
|
"project_id": "test-project",
|
||||||
|
"agent_id": "test-agent",
|
||||||
|
"query": "Continue the discussion",
|
||||||
|
"conversation_history": [
|
||||||
|
{"role": "user", "content": "Hello"},
|
||||||
|
{"role": "assistant", "content": "Hi there!"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
call_args = mock_context_engine.assemble_context.call_args
|
||||||
|
assert call_args.kwargs["conversation_history"] == [
|
||||||
|
{"role": "user", "content": "Hello"},
|
||||||
|
{"role": "assistant", "content": "Hi there!"},
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_assemble_context_with_tool_results(self, client, mock_context_engine):
|
||||||
|
"""Test context assembly with tool results."""
|
||||||
|
mock_result = MagicMock(spec=AssembledContext)
|
||||||
|
mock_result.content = "Context with tools"
|
||||||
|
mock_result.total_tokens = 600
|
||||||
|
mock_result.context_count = 1
|
||||||
|
mock_result.excluded_count = 0
|
||||||
|
mock_result.assembly_time_ms = 25.0
|
||||||
|
mock_result.metadata = {}
|
||||||
|
|
||||||
|
mock_context_engine.assemble_context = AsyncMock(return_value=mock_result)
|
||||||
|
mock_context_engine.get_budget_for_model = AsyncMock(
|
||||||
|
return_value=TokenBudget(
|
||||||
|
total=4000,
|
||||||
|
system=500,
|
||||||
|
knowledge=1500,
|
||||||
|
conversation=1000,
|
||||||
|
tools=500,
|
||||||
|
response_reserve=500,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/context/assemble",
|
||||||
|
json={
|
||||||
|
"project_id": "test-project",
|
||||||
|
"agent_id": "test-agent",
|
||||||
|
"query": "What did the search find?",
|
||||||
|
"tool_results": [
|
||||||
|
{
|
||||||
|
"tool_name": "search_knowledge",
|
||||||
|
"content": {"results": ["item1", "item2"]},
|
||||||
|
"status": "success",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
call_args = mock_context_engine.assemble_context.call_args
|
||||||
|
assert len(call_args.kwargs["tool_results"]) == 1
|
||||||
|
|
||||||
|
def test_assemble_context_timeout(self, client, mock_context_engine):
|
||||||
|
"""Test context assembly timeout error."""
|
||||||
|
mock_context_engine.assemble_context = AsyncMock(
|
||||||
|
side_effect=AssemblyTimeoutError("Assembly exceeded 5000ms limit")
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/context/assemble",
|
||||||
|
json={
|
||||||
|
"project_id": "test-project",
|
||||||
|
"agent_id": "test-agent",
|
||||||
|
"query": "test",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_504_GATEWAY_TIMEOUT
|
||||||
|
|
||||||
|
def test_assemble_context_budget_exceeded(self, client, mock_context_engine):
|
||||||
|
"""Test context assembly budget exceeded error."""
|
||||||
|
mock_context_engine.assemble_context = AsyncMock(
|
||||||
|
side_effect=BudgetExceededError("Token budget exceeded: 5000 > 4000")
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/context/assemble",
|
||||||
|
json={
|
||||||
|
"project_id": "test-project",
|
||||||
|
"agent_id": "test-agent",
|
||||||
|
"query": "test",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_413_REQUEST_ENTITY_TOO_LARGE
|
||||||
|
|
||||||
|
def test_assemble_context_validation_error(self, client):
|
||||||
|
"""Test context assembly with invalid request."""
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/context/assemble",
|
||||||
|
json={}, # Missing required fields
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||||
|
|
||||||
|
|
||||||
|
class TestCountTokens:
|
||||||
|
"""Tests for POST /context/count-tokens endpoint."""
|
||||||
|
|
||||||
|
def test_count_tokens_success(self, client, mock_context_engine):
|
||||||
|
"""Test successful token counting."""
|
||||||
|
mock_context_engine.count_tokens = AsyncMock(return_value=42)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/context/count-tokens",
|
||||||
|
json={
|
||||||
|
"content": "This is some test content.",
|
||||||
|
"model": "claude-3-sonnet",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["token_count"] == 42
|
||||||
|
assert data["model"] == "claude-3-sonnet"
|
||||||
|
|
||||||
|
def test_count_tokens_without_model(self, client, mock_context_engine):
|
||||||
|
"""Test token counting without specifying model."""
|
||||||
|
mock_context_engine.count_tokens = AsyncMock(return_value=100)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/context/count-tokens",
|
||||||
|
json={"content": "Some content to count."},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["token_count"] == 100
|
||||||
|
assert data["model"] is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetBudget:
|
||||||
|
"""Tests for GET /context/budget/{model} endpoint."""
|
||||||
|
|
||||||
|
def test_get_budget_success(self, client, mock_context_engine):
|
||||||
|
"""Test getting token budget for a model."""
|
||||||
|
mock_context_engine.get_budget_for_model = AsyncMock(
|
||||||
|
return_value=TokenBudget(
|
||||||
|
total=100000,
|
||||||
|
system=10000,
|
||||||
|
knowledge=40000,
|
||||||
|
conversation=30000,
|
||||||
|
tools=10000,
|
||||||
|
response_reserve=10000,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.get("/api/v1/context/budget/claude-3-opus")
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["model"] == "claude-3-opus"
|
||||||
|
assert data["total_tokens"] == 100000
|
||||||
|
assert data["system_tokens"] == 10000
|
||||||
|
assert data["knowledge_tokens"] == 40000
|
||||||
|
|
||||||
|
def test_get_budget_with_max_tokens(self, client, mock_context_engine):
|
||||||
|
"""Test getting budget with custom max tokens."""
|
||||||
|
mock_context_engine.get_budget_for_model = AsyncMock(
|
||||||
|
return_value=TokenBudget(
|
||||||
|
total=2000,
|
||||||
|
system=200,
|
||||||
|
knowledge=800,
|
||||||
|
conversation=600,
|
||||||
|
tools=200,
|
||||||
|
response_reserve=200,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.get("/api/v1/context/budget/gpt-4?max_tokens=2000")
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["total_tokens"] == 2000
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetStats:
|
||||||
|
"""Tests for GET /context/stats endpoint."""
|
||||||
|
|
||||||
|
def test_get_stats_success(self, client, mock_context_engine):
|
||||||
|
"""Test getting engine statistics."""
|
||||||
|
mock_context_engine.get_stats = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"cache": {
|
||||||
|
"hits": 100,
|
||||||
|
"misses": 25,
|
||||||
|
"hit_rate": 0.8,
|
||||||
|
},
|
||||||
|
"settings": {
|
||||||
|
"compression_threshold": 0.9,
|
||||||
|
"max_assembly_time_ms": 5000,
|
||||||
|
"cache_enabled": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.get("/api/v1/context/stats")
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["cache"]["hits"] == 100
|
||||||
|
assert data["settings"]["cache_enabled"] is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestInvalidateCache:
|
||||||
|
"""Tests for POST /context/cache/invalidate endpoint."""
|
||||||
|
|
||||||
|
def test_invalidate_cache_by_project(self, client, mock_context_engine):
|
||||||
|
"""Test cache invalidation by project ID."""
|
||||||
|
mock_context_engine.invalidate_cache = AsyncMock(return_value=5)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/context/cache/invalidate?project_id=test-project"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||||
|
mock_context_engine.invalidate_cache.assert_called_once()
|
||||||
|
call_kwargs = mock_context_engine.invalidate_cache.call_args.kwargs
|
||||||
|
assert call_kwargs["project_id"] == "test-project"
|
||||||
|
|
||||||
|
def test_invalidate_cache_by_pattern(self, client, mock_context_engine):
|
||||||
|
"""Test cache invalidation by pattern."""
|
||||||
|
mock_context_engine.invalidate_cache = AsyncMock(return_value=10)
|
||||||
|
|
||||||
|
response = client.post("/api/v1/context/cache/invalidate?pattern=*auth*")
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||||
|
mock_context_engine.invalidate_cache.assert_called_once()
|
||||||
|
call_kwargs = mock_context_engine.invalidate_cache.call_args.kwargs
|
||||||
|
assert call_kwargs["pattern"] == "*auth*"
|
||||||
|
|
||||||
|
def test_invalidate_cache_all(self, client, mock_context_engine):
|
||||||
|
"""Test invalidating all cache entries."""
|
||||||
|
mock_context_engine.invalidate_cache = AsyncMock(return_value=100)
|
||||||
|
|
||||||
|
response = client.post("/api/v1/context/cache/invalidate")
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextEndpointsEdgeCases:
|
||||||
|
"""Edge case tests for Context endpoints."""
|
||||||
|
|
||||||
|
def test_context_content_type(self, client, mock_context_engine):
|
||||||
|
"""Test that endpoints return JSON content type."""
|
||||||
|
mock_context_engine.get_stats = AsyncMock(
|
||||||
|
return_value={"cache": {}, "settings": {}}
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.get("/api/v1/context/health")
|
||||||
|
|
||||||
|
assert "application/json" in response.headers["content-type"]
|
||||||
|
|
||||||
|
def test_assemble_context_with_knowledge_query(self, client, mock_context_engine):
|
||||||
|
"""Test context assembly with knowledge base query."""
|
||||||
|
mock_result = MagicMock(spec=AssembledContext)
|
||||||
|
mock_result.content = "Context with knowledge"
|
||||||
|
mock_result.total_tokens = 1000
|
||||||
|
mock_result.context_count = 3
|
||||||
|
mock_result.excluded_count = 0
|
||||||
|
mock_result.assembly_time_ms = 100.0
|
||||||
|
mock_result.metadata = {
|
||||||
|
"compressed_contexts": 1
|
||||||
|
} # Indicates compression happened
|
||||||
|
|
||||||
|
mock_context_engine.assemble_context = AsyncMock(return_value=mock_result)
|
||||||
|
mock_context_engine.get_budget_for_model = AsyncMock(
|
||||||
|
return_value=TokenBudget(
|
||||||
|
total=4000,
|
||||||
|
system=500,
|
||||||
|
knowledge=1500,
|
||||||
|
conversation=1000,
|
||||||
|
tools=500,
|
||||||
|
response_reserve=500,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/context/assemble",
|
||||||
|
json={
|
||||||
|
"project_id": "test-project",
|
||||||
|
"agent_id": "test-agent",
|
||||||
|
"query": "How does authentication work?",
|
||||||
|
"knowledge_query": "authentication flow implementation",
|
||||||
|
"knowledge_limit": 5,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
call_kwargs = mock_context_engine.assemble_context.call_args.kwargs
|
||||||
|
assert call_kwargs["knowledge_query"] == "authentication flow implementation"
|
||||||
|
assert call_kwargs["knowledge_limit"] == 5
|
||||||
646
backend/tests/e2e/test_agent_workflows.py
Normal file
646
backend/tests/e2e/test_agent_workflows.py
Normal file
@@ -0,0 +1,646 @@
|
|||||||
|
"""
|
||||||
|
Agent E2E Workflow Tests.
|
||||||
|
|
||||||
|
Tests complete workflows for AI agents including:
|
||||||
|
- Agent type management (admin-only)
|
||||||
|
- Agent instance spawning and lifecycle
|
||||||
|
- Agent status transitions (pause/resume/terminate)
|
||||||
|
- Authorization and access control
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
make test-e2e # Run all E2E tests
|
||||||
|
"""
|
||||||
|
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
pytestmark = [
|
||||||
|
pytest.mark.e2e,
|
||||||
|
pytest.mark.postgres,
|
||||||
|
pytest.mark.asyncio,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentTypesAdminWorkflows:
|
||||||
|
"""Test agent type management (admin-only operations)."""
|
||||||
|
|
||||||
|
async def test_create_agent_type_requires_superuser(self, e2e_client):
|
||||||
|
"""Test that creating agent types requires superuser privileges."""
|
||||||
|
# Register regular user
|
||||||
|
email = f"regular-{uuid4().hex[:8]}@example.com"
|
||||||
|
password = "RegularPass123!"
|
||||||
|
|
||||||
|
await e2e_client.post(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
json={
|
||||||
|
"email": email,
|
||||||
|
"password": password,
|
||||||
|
"first_name": "Regular",
|
||||||
|
"last_name": "User",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
login_resp = await e2e_client.post(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
json={"email": email, "password": password},
|
||||||
|
)
|
||||||
|
tokens = login_resp.json()
|
||||||
|
|
||||||
|
# Try to create agent type
|
||||||
|
response = await e2e_client.post(
|
||||||
|
"/api/v1/agent-types",
|
||||||
|
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||||
|
json={
|
||||||
|
"name": "Test Agent",
|
||||||
|
"slug": f"test-agent-{uuid4().hex[:8]}",
|
||||||
|
"personality_prompt": "You are a helpful assistant.",
|
||||||
|
"primary_model": "claude-3-sonnet",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 403
|
||||||
|
|
||||||
|
async def test_superuser_can_create_agent_type(self, e2e_client, e2e_superuser):
|
||||||
|
"""Test that superuser can create and manage agent types."""
|
||||||
|
slug = f"test-type-{uuid4().hex[:8]}"
|
||||||
|
|
||||||
|
# Create agent type
|
||||||
|
create_resp = await e2e_client.post(
|
||||||
|
"/api/v1/agent-types",
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||||
|
},
|
||||||
|
json={
|
||||||
|
"name": "Product Owner Agent",
|
||||||
|
"slug": slug,
|
||||||
|
"description": "A product owner agent for requirements gathering",
|
||||||
|
"expertise": ["requirements", "user_stories", "prioritization"],
|
||||||
|
"personality_prompt": "You are a product owner focused on delivering value.",
|
||||||
|
"primary_model": "claude-3-opus",
|
||||||
|
"fallback_models": ["claude-3-sonnet"],
|
||||||
|
"model_params": {"temperature": 0.7, "max_tokens": 4000},
|
||||||
|
"mcp_servers": ["knowledge-base"],
|
||||||
|
"is_active": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert create_resp.status_code == 201, f"Failed: {create_resp.text}"
|
||||||
|
agent_type = create_resp.json()
|
||||||
|
|
||||||
|
assert agent_type["name"] == "Product Owner Agent"
|
||||||
|
assert agent_type["slug"] == slug
|
||||||
|
assert agent_type["primary_model"] == "claude-3-opus"
|
||||||
|
assert agent_type["is_active"] is True
|
||||||
|
assert "requirements" in agent_type["expertise"]
|
||||||
|
|
||||||
|
async def test_list_agent_types_public(self, e2e_client, e2e_superuser):
|
||||||
|
"""Test that any authenticated user can list agent types."""
|
||||||
|
# First create an agent type as superuser
|
||||||
|
slug = f"list-test-{uuid4().hex[:8]}"
|
||||||
|
await e2e_client.post(
|
||||||
|
"/api/v1/agent-types",
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||||
|
},
|
||||||
|
json={
|
||||||
|
"name": f"List Test Agent {slug}",
|
||||||
|
"slug": slug,
|
||||||
|
"personality_prompt": "Test agent.",
|
||||||
|
"primary_model": "claude-3-sonnet",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Register regular user
|
||||||
|
email = f"lister-{uuid4().hex[:8]}@example.com"
|
||||||
|
password = "ListerPass123!"
|
||||||
|
|
||||||
|
await e2e_client.post(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
json={
|
||||||
|
"email": email,
|
||||||
|
"password": password,
|
||||||
|
"first_name": "List",
|
||||||
|
"last_name": "User",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
login_resp = await e2e_client.post(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
json={"email": email, "password": password},
|
||||||
|
)
|
||||||
|
tokens = login_resp.json()
|
||||||
|
|
||||||
|
# List agent types as regular user
|
||||||
|
list_resp = await e2e_client.get(
|
||||||
|
"/api/v1/agent-types",
|
||||||
|
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert list_resp.status_code == 200
|
||||||
|
data = list_resp.json()
|
||||||
|
assert "data" in data
|
||||||
|
assert "pagination" in data
|
||||||
|
assert data["pagination"]["total"] >= 1
|
||||||
|
|
||||||
|
async def test_get_agent_type_by_slug(self, e2e_client, e2e_superuser):
|
||||||
|
"""Test getting agent type by slug."""
|
||||||
|
slug = f"slug-test-{uuid4().hex[:8]}"
|
||||||
|
|
||||||
|
# Create agent type
|
||||||
|
await e2e_client.post(
|
||||||
|
"/api/v1/agent-types",
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||||
|
},
|
||||||
|
json={
|
||||||
|
"name": f"Slug Test {slug}",
|
||||||
|
"slug": slug,
|
||||||
|
"personality_prompt": "Test agent.",
|
||||||
|
"primary_model": "claude-3-sonnet",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get by slug (route is /slug/{slug}, not /by-slug/{slug})
|
||||||
|
get_resp = await e2e_client.get(
|
||||||
|
f"/api/v1/agent-types/slug/{slug}",
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert get_resp.status_code == 200
|
||||||
|
data = get_resp.json()
|
||||||
|
assert data["slug"] == slug
|
||||||
|
|
||||||
|
async def test_update_agent_type(self, e2e_client, e2e_superuser):
|
||||||
|
"""Test updating an agent type."""
|
||||||
|
slug = f"update-test-{uuid4().hex[:8]}"
|
||||||
|
|
||||||
|
# Create agent type
|
||||||
|
create_resp = await e2e_client.post(
|
||||||
|
"/api/v1/agent-types",
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||||
|
},
|
||||||
|
json={
|
||||||
|
"name": "Original Name",
|
||||||
|
"slug": slug,
|
||||||
|
"personality_prompt": "Original prompt.",
|
||||||
|
"primary_model": "claude-3-sonnet",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
agent_type_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
# Update agent type
|
||||||
|
update_resp = await e2e_client.patch(
|
||||||
|
f"/api/v1/agent-types/{agent_type_id}",
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||||
|
},
|
||||||
|
json={
|
||||||
|
"name": "Updated Name",
|
||||||
|
"description": "Added description",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert update_resp.status_code == 200
|
||||||
|
updated = update_resp.json()
|
||||||
|
assert updated["name"] == "Updated Name"
|
||||||
|
assert updated["description"] == "Added description"
|
||||||
|
assert updated["personality_prompt"] == "Original prompt." # Unchanged
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentInstanceWorkflows:
    """Test agent instance spawning and lifecycle."""

    async def test_spawn_agent_workflow(self, e2e_client, e2e_superuser):
        """Test complete workflow: create type -> create project -> spawn agent."""
        auth = {
            "Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
        }

        # 1. Create agent type as superuser
        type_slug = f"spawn-test-type-{uuid4().hex[:8]}"
        type_resp = await e2e_client.post(
            "/api/v1/agent-types",
            headers=auth,
            json={
                "name": "Spawn Test Agent",
                "slug": type_slug,
                "personality_prompt": "You are a helpful agent.",
                "primary_model": "claude-3-sonnet",
            },
        )
        assert type_resp.status_code == 201, f"Failed: {type_resp.text}"
        agent_type_id = type_resp.json()["id"]

        # 2. Create a project (superuser can create projects too)
        project_slug = f"spawn-test-project-{uuid4().hex[:8]}"
        project_resp = await e2e_client.post(
            "/api/v1/projects",
            headers=auth,
            json={"name": "Spawn Test Project", "slug": project_slug},
        )
        assert project_resp.status_code == 201, f"Failed: {project_resp.text}"
        project_id = project_resp.json()["id"]

        # 3. Spawn agent instance
        spawn_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/agents",
            headers=auth,
            json={
                "agent_type_id": agent_type_id,
                "project_id": project_id,
                "name": "My PO Agent",
            },
        )
        assert spawn_resp.status_code == 201, f"Failed: {spawn_resp.text}"
        agent = spawn_resp.json()

        # A freshly spawned agent starts idle and is linked to both parents.
        assert agent["name"] == "My PO Agent"
        assert agent["status"] == "idle"
        assert agent["project_id"] == project_id
        assert agent["agent_type_id"] == agent_type_id

    async def test_list_project_agents(self, e2e_client, e2e_superuser):
        """Test listing agents in a project."""
        auth = {
            "Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
        }

        # Setup: Create agent type and project. Setup responses are
        # status-checked so a fixture/API failure surfaces here rather
        # than as a KeyError on the missing "id" below.
        type_slug = f"list-agents-type-{uuid4().hex[:8]}"
        type_resp = await e2e_client.post(
            "/api/v1/agent-types",
            headers=auth,
            json={
                "name": "List Agents Type",
                "slug": type_slug,
                "personality_prompt": "Test agent.",
                "primary_model": "claude-3-sonnet",
            },
        )
        assert type_resp.status_code == 201, f"Failed: {type_resp.text}"
        agent_type_id = type_resp.json()["id"]

        project_slug = f"list-agents-project-{uuid4().hex[:8]}"
        project_resp = await e2e_client.post(
            "/api/v1/projects",
            headers=auth,
            json={"name": "List Agents Project", "slug": project_slug},
        )
        assert project_resp.status_code == 201, f"Failed: {project_resp.text}"
        project_id = project_resp.json()["id"]

        # Spawn multiple agents; each spawn must succeed for the final
        # count assertions to be meaningful.
        for i in range(3):
            spawn_resp = await e2e_client.post(
                f"/api/v1/projects/{project_id}/agents",
                headers=auth,
                json={
                    "agent_type_id": agent_type_id,
                    "project_id": project_id,
                    "name": f"Agent {i + 1}",
                },
            )
            assert spawn_resp.status_code == 201, f"Failed: {spawn_resp.text}"

        # List agents
        list_resp = await e2e_client.get(
            f"/api/v1/projects/{project_id}/agents",
            headers=auth,
        )

        assert list_resp.status_code == 200
        data = list_resp.json()
        assert data["pagination"]["total"] == 3
        assert len(data["data"]) == 3
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentLifecycle:
    """Test agent lifecycle operations (pause/resume/terminate)."""

    @staticmethod
    def _auth(e2e_superuser):
        """Authorization header for the e2e superuser fixture."""
        return {
            "Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
        }

    async def _spawn_agent(
        self,
        e2e_client,
        auth,
        *,
        type_slug,
        type_name,
        project_slug,
        project_name,
        agent_name,
    ):
        """Create an agent type + project and spawn one agent.

        Every setup request is status-checked so a broken fixture or API
        fails loudly here instead of as a confusing KeyError in the test
        body. Returns (project_id, agent_json).
        """
        type_resp = await e2e_client.post(
            "/api/v1/agent-types",
            headers=auth,
            json={
                "name": type_name,
                "slug": type_slug,
                "personality_prompt": "Test agent.",
                "primary_model": "claude-3-sonnet",
            },
        )
        assert type_resp.status_code == 201, f"Failed: {type_resp.text}"
        agent_type_id = type_resp.json()["id"]

        project_resp = await e2e_client.post(
            "/api/v1/projects",
            headers=auth,
            json={"name": project_name, "slug": project_slug},
        )
        assert project_resp.status_code == 201, f"Failed: {project_resp.text}"
        project_id = project_resp.json()["id"]

        spawn_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/agents",
            headers=auth,
            json={
                "agent_type_id": agent_type_id,
                "project_id": project_id,
                "name": agent_name,
            },
        )
        assert spawn_resp.status_code == 201, f"Failed: {spawn_resp.text}"
        return project_id, spawn_resp.json()

    async def test_agent_pause_and_resume(self, e2e_client, e2e_superuser):
        """Test pausing and resuming an agent."""
        auth = self._auth(e2e_superuser)
        project_id, agent = await self._spawn_agent(
            e2e_client,
            auth,
            type_slug=f"pause-test-type-{uuid4().hex[:8]}",
            type_name="Pause Test Type",
            project_slug=f"pause-test-project-{uuid4().hex[:8]}",
            project_name="Pause Test Project",
            agent_name="Pausable Agent",
        )
        agent_id = agent["id"]
        assert agent["status"] == "idle"

        # Pause agent
        pause_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/agents/{agent_id}/pause",
            headers=auth,
        )
        assert pause_resp.status_code == 200, f"Failed: {pause_resp.text}"
        assert pause_resp.json()["status"] == "paused"

        # Resume agent
        resume_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/agents/{agent_id}/resume",
            headers=auth,
        )
        assert resume_resp.status_code == 200, f"Failed: {resume_resp.text}"
        assert resume_resp.json()["status"] == "idle"

    async def test_agent_terminate(self, e2e_client, e2e_superuser):
        """Test terminating an agent."""
        auth = self._auth(e2e_superuser)
        project_id, agent = await self._spawn_agent(
            e2e_client,
            auth,
            type_slug=f"terminate-type-{uuid4().hex[:8]}",
            type_name="Terminate Type",
            project_slug=f"terminate-project-{uuid4().hex[:8]}",
            project_name="Terminate Project",
            agent_name="To Be Terminated",
        )
        agent_id = agent["id"]

        # Terminate agent (returns MessageResponse, not agent status)
        terminate_resp = await e2e_client.delete(
            f"/api/v1/projects/{project_id}/agents/{agent_id}",
            headers=auth,
        )
        assert terminate_resp.status_code == 200
        assert "message" in terminate_resp.json()

        # Verify terminated agent cannot be resumed (returns 400 or 422)
        resume_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/agents/{agent_id}/resume",
            headers=auth,
        )
        assert resume_resp.status_code in [400, 422]  # Invalid transition
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentAccessControl:
    """Test agent access control and authorization."""

    async def test_user_cannot_access_other_project_agents(
        self, e2e_client, e2e_superuser
    ):
        """Test that users cannot access agents in projects they don't own."""
        admin_auth = {
            "Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
        }

        # Superuser creates agent type
        type_slug = f"access-type-{uuid4().hex[:8]}"
        type_resp = await e2e_client.post(
            "/api/v1/agent-types",
            headers=admin_auth,
            json={
                "name": "Access Type",
                "slug": type_slug,
                "personality_prompt": "Test agent.",
                "primary_model": "claude-3-sonnet",
            },
        )
        agent_type_id = type_resp.json()["id"]

        # Superuser creates project and spawns agent
        project_slug = f"protected-project-{uuid4().hex[:8]}"
        project_resp = await e2e_client.post(
            "/api/v1/projects",
            headers=admin_auth,
            json={"name": "Protected Project", "slug": project_slug},
        )
        project_id = project_resp.json()["id"]

        spawn_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/agents",
            headers=admin_auth,
            json={
                "agent_type_id": agent_type_id,
                "project_id": project_id,
                "name": "Protected Agent",
            },
        )
        agent_id = spawn_resp.json()["id"]

        # Create a different user
        email = f"other-user-{uuid4().hex[:8]}@example.com"
        password = "OtherPass123!"

        await e2e_client.post(
            "/api/v1/auth/register",
            json={
                "email": email,
                "password": password,
                "first_name": "Other",
                "last_name": "User",
            },
        )
        login_resp = await e2e_client.post(
            "/api/v1/auth/login",
            json={"email": email, "password": password},
        )
        other_tokens = login_resp.json()

        # Other user tries to access the agent
        get_resp = await e2e_client.get(
            f"/api/v1/projects/{project_id}/agents/{agent_id}",
            headers={"Authorization": f"Bearer {other_tokens['access_token']}"},
        )

        # Should be forbidden or not found
        assert get_resp.status_code in [403, 404]

    async def test_cannot_spawn_with_inactive_agent_type(
        self, e2e_client, e2e_superuser
    ):
        """Test that agents cannot be spawned from inactive agent types."""
        admin_auth = {
            "Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
        }

        # Create agent type
        type_slug = f"inactive-type-{uuid4().hex[:8]}"
        type_resp = await e2e_client.post(
            "/api/v1/agent-types",
            headers=admin_auth,
            json={
                "name": "Inactive Type",
                "slug": type_slug,
                "personality_prompt": "Test agent.",
                "primary_model": "claude-3-sonnet",
                "is_active": True,
            },
        )
        agent_type_id = type_resp.json()["id"]

        # Deactivate the agent type
        await e2e_client.patch(
            f"/api/v1/agent-types/{agent_type_id}",
            headers=admin_auth,
            json={"is_active": False},
        )

        # Create project
        project_slug = f"inactive-spawn-project-{uuid4().hex[:8]}"
        project_resp = await e2e_client.post(
            "/api/v1/projects",
            headers=admin_auth,
            json={"name": "Inactive Spawn Project", "slug": project_slug},
        )
        project_id = project_resp.json()["id"]

        # Try to spawn agent with inactive type
        spawn_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/agents",
            headers=admin_auth,
            json={
                "agent_type_id": agent_type_id,
                "project_id": project_id,
                "name": "Should Fail",
            },
        )

        # 422 is correct for validation errors per REST conventions
        assert spawn_resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentMetrics:
    """Test agent metrics endpoint."""

    async def test_get_agent_metrics(self, e2e_client, e2e_superuser):
        """Test retrieving agent metrics."""
        auth = {
            "Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
        }

        # Setup
        type_slug = f"metrics-type-{uuid4().hex[:8]}"
        type_resp = await e2e_client.post(
            "/api/v1/agent-types",
            headers=auth,
            json={
                "name": "Metrics Type",
                "slug": type_slug,
                "personality_prompt": "Test agent.",
                "primary_model": "claude-3-sonnet",
            },
        )
        agent_type_id = type_resp.json()["id"]

        project_slug = f"metrics-project-{uuid4().hex[:8]}"
        project_resp = await e2e_client.post(
            "/api/v1/projects",
            headers=auth,
            json={"name": "Metrics Project", "slug": project_slug},
        )
        project_id = project_resp.json()["id"]

        spawn_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/agents",
            headers=auth,
            json={
                "agent_type_id": agent_type_id,
                "project_id": project_id,
                "name": "Metrics Agent",
            },
        )
        agent_id = spawn_resp.json()["id"]

        # Get metrics
        metrics_resp = await e2e_client.get(
            f"/api/v1/projects/{project_id}/agents/{agent_id}/metrics",
            headers=auth,
        )

        assert metrics_resp.status_code == 200
        metrics = metrics_resp.json()

        # Verify AgentInstanceMetrics structure
        for field in (
            "total_instances",
            "active_instances",
            "idle_instances",
            "total_tasks_completed",
            "total_tokens_used",
            "total_cost_incurred",
        ):
            assert field in metrics
|
||||||
460
backend/tests/e2e/test_mcp_workflows.py
Normal file
460
backend/tests/e2e/test_mcp_workflows.py
Normal file
@@ -0,0 +1,460 @@
|
|||||||
|
"""
|
||||||
|
MCP and Context Engine E2E Workflow Tests.
|
||||||
|
|
||||||
|
Tests complete workflows involving MCP servers and the Context Engine
|
||||||
|
against real PostgreSQL. These tests verify:
|
||||||
|
- MCP server listing and tool discovery
|
||||||
|
- Context engine operations
|
||||||
|
- Admin-only MCP operations with proper authentication
|
||||||
|
- Error handling for MCP operations
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
make test-e2e # Run all E2E tests
|
||||||
|
"""
|
||||||
|
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
pytestmark = [
|
||||||
|
pytest.mark.e2e,
|
||||||
|
pytest.mark.postgres,
|
||||||
|
pytest.mark.asyncio,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPServerDiscovery:
    """Test MCP server listing and discovery workflows."""

    async def test_list_mcp_servers(self, e2e_client):
        """Test listing MCP servers returns expected configuration."""
        response = await e2e_client.get("/api/v1/mcp/servers")

        assert response.status_code == 200, f"Failed: {response.text}"
        body = response.json()

        # Should have servers configured
        assert "servers" in body
        assert "total" in body
        assert isinstance(body["servers"], list)

        # Should have at least llm-gateway and knowledge-base
        names = {entry["name"] for entry in body["servers"]}
        assert "llm-gateway" in names
        assert "knowledge-base" in names

    async def test_list_all_mcp_tools(self, e2e_client):
        """Test listing all tools from all MCP servers."""
        response = await e2e_client.get("/api/v1/mcp/tools")

        assert response.status_code == 200, f"Failed: {response.text}"
        body = response.json()

        assert "tools" in body
        assert "total" in body
        assert isinstance(body["tools"], list)

    async def test_mcp_health_check(self, e2e_client):
        """Test MCP health check returns server status."""
        response = await e2e_client.get("/api/v1/mcp/health")

        assert response.status_code == 200, f"Failed: {response.text}"
        body = response.json()

        for field in ("servers", "healthy_count", "unhealthy_count", "total"):
            assert field in body

    async def test_list_circuit_breakers(self, e2e_client):
        """Test listing circuit breaker status."""
        response = await e2e_client.get("/api/v1/mcp/circuit-breakers")

        assert response.status_code == 200, f"Failed: {response.text}"
        body = response.json()

        assert "circuit_breakers" in body
        assert isinstance(body["circuit_breakers"], list)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPServerTools:
    """Test MCP server tool listing."""

    async def test_list_llm_gateway_tools(self, e2e_client):
        """Test listing tools from LLM Gateway server."""
        response = await e2e_client.get("/api/v1/mcp/servers/llm-gateway/tools")

        # May return 200 with tools or 404 if server not connected
        assert response.status_code in [200, 404, 502]

        if response.status_code == 200:
            body = response.json()
            assert "tools" in body
            assert "total" in body

    async def test_list_knowledge_base_tools(self, e2e_client):
        """Test listing tools from Knowledge Base server."""
        response = await e2e_client.get("/api/v1/mcp/servers/knowledge-base/tools")

        # May return 200 with tools or 404/502 if server not connected
        assert response.status_code in [200, 404, 502]

        if response.status_code == 200:
            body = response.json()
            assert "tools" in body
            assert "total" in body

    async def test_invalid_server_returns_404(self, e2e_client):
        """Test that invalid server name returns 404."""
        response = await e2e_client.get("/api/v1/mcp/servers/nonexistent-server/tools")

        assert response.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextEngineWorkflows:
    """Test Context Engine operations."""

    async def test_context_engine_health(self, e2e_client):
        """Test context engine health endpoint."""
        response = await e2e_client.get("/api/v1/context/health")

        assert response.status_code == 200, f"Failed: {response.text}"
        body = response.json()

        assert body["status"] == "healthy"
        assert "mcp_connected" in body
        assert "cache_enabled" in body

    async def test_get_token_budget_claude_sonnet(self, e2e_client):
        """Test getting token budget for Claude 3 Sonnet."""
        response = await e2e_client.get("/api/v1/context/budget/claude-3-sonnet")

        assert response.status_code == 200, f"Failed: {response.text}"
        budget = response.json()

        assert budget["model"] == "claude-3-sonnet"
        allocation_fields = (
            "system_tokens",
            "knowledge_tokens",
            "conversation_tokens",
            "tool_tokens",
            "response_reserve",
        )
        assert "total_tokens" in budget
        for field in allocation_fields:
            assert field in budget

        # Verify budget allocation makes sense: positive total, and the
        # per-category allocations never exceed it.
        assert budget["total_tokens"] > 0
        total_allocated = sum(budget[field] for field in allocation_fields)
        assert total_allocated <= budget["total_tokens"]

    async def test_get_token_budget_with_custom_max(self, e2e_client):
        """Test getting token budget with custom max tokens."""
        response = await e2e_client.get(
            "/api/v1/context/budget/claude-3-sonnet",
            params={"max_tokens": 50000},
        )

        assert response.status_code == 200, f"Failed: {response.text}"
        budget = response.json()

        assert budget["model"] == "claude-3-sonnet"
        # Custom max should be respected or capped
        assert budget["total_tokens"] <= 50000

    async def test_count_tokens(self, e2e_client):
        """Test token counting endpoint."""
        response = await e2e_client.post(
            "/api/v1/context/count-tokens",
            json={
                "content": "Hello, this is a test message for token counting.",
                "model": "claude-3-sonnet",
            },
        )

        assert response.status_code == 200, f"Failed: {response.text}"
        body = response.json()

        assert "token_count" in body
        assert body["token_count"] > 0
        assert body["model"] == "claude-3-sonnet"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAdminMCPOperations:
    """Test admin-only MCP operations require authentication."""

    async def test_tool_call_requires_auth(self, e2e_client):
        """Test that tool execution requires authentication."""
        response = await e2e_client.post(
            "/api/v1/mcp/call",
            json={
                "server": "llm-gateway",
                "tool": "count_tokens",
                "arguments": {"text": "test"},
            },
        )

        # Should require authentication
        assert response.status_code in [401, 403]

    async def test_circuit_reset_requires_auth(self, e2e_client):
        """Test that circuit breaker reset requires authentication."""
        response = await e2e_client.post(
            "/api/v1/mcp/circuit-breakers/llm-gateway/reset"
        )
        assert response.status_code in [401, 403]

    async def test_server_reconnect_requires_auth(self, e2e_client):
        """Test that server reconnect requires authentication."""
        response = await e2e_client.post("/api/v1/mcp/servers/llm-gateway/reconnect")
        assert response.status_code in [401, 403]

    async def test_context_stats_requires_auth(self, e2e_client):
        """Test that context stats requires authentication."""
        response = await e2e_client.get("/api/v1/context/stats")
        assert response.status_code in [401, 403]

    async def test_context_assemble_requires_auth(self, e2e_client):
        """Test that context assembly requires authentication."""
        response = await e2e_client.post(
            "/api/v1/context/assemble",
            json={
                "project_id": "test-project",
                "agent_id": "test-agent",
                "query": "test query",
                "model": "claude-3-sonnet",
            },
        )
        assert response.status_code in [401, 403]

    async def test_cache_invalidate_requires_auth(self, e2e_client):
        """Test that cache invalidation requires authentication."""
        response = await e2e_client.post("/api/v1/context/cache/invalidate")
        assert response.status_code in [401, 403]
|
||||||
|
|
||||||
|
|
||||||
|
class TestAdminMCPWithAuthentication:
    """Test admin MCP operations with superuser authentication."""

    async def test_superuser_can_get_context_stats(self, e2e_client, e2e_superuser):
        """Test that superuser can get context engine stats."""
        response = await e2e_client.get(
            "/api/v1/context/stats",
            headers={
                "Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
            },
        )

        assert response.status_code == 200, f"Failed: {response.text}"
        stats = response.json()

        assert "cache" in stats
        assert "settings" in stats

    @pytest.mark.skip(
        reason="Requires MCP servers (llm-gateway, knowledge-base) to be running"
    )
    async def test_superuser_can_assemble_context(self, e2e_client, e2e_superuser):
        """Test that superuser can assemble context."""
        payload = {
            "project_id": f"test-project-{uuid4().hex[:8]}",
            "agent_id": f"test-agent-{uuid4().hex[:8]}",
            "query": "What is the status of the project?",
            "model": "claude-3-sonnet",
            "system_prompt": "You are a helpful assistant.",
            "compress": True,
            "use_cache": False,
        }
        response = await e2e_client.post(
            "/api/v1/context/assemble",
            headers={
                "Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
            },
            json=payload,
        )

        assert response.status_code == 200, f"Failed: {response.text}"
        assembled = response.json()

        for field in (
            "content",
            "total_tokens",
            "context_count",
            "budget_used_percent",
            "metadata",
        ):
            assert field in assembled

    async def test_superuser_can_invalidate_cache(self, e2e_client, e2e_superuser):
        """Test that superuser can invalidate cache."""
        response = await e2e_client.post(
            "/api/v1/context/cache/invalidate",
            headers={
                "Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
            },
            params={"project_id": "test-project"},
        )

        assert response.status_code == 204

    async def test_regular_user_cannot_access_admin_operations(self, e2e_client):
        """Test that regular (non-superuser) cannot access admin operations."""
        email = f"regular-{uuid4().hex[:8]}@example.com"
        password = "RegularUser123!"

        # Register regular user
        await e2e_client.post(
            "/api/v1/auth/register",
            json={
                "email": email,
                "password": password,
                "first_name": "Regular",
                "last_name": "User",
            },
        )

        # Login
        login_resp = await e2e_client.post(
            "/api/v1/auth/login",
            json={"email": email, "password": password},
        )
        tokens = login_resp.json()

        # Try to access admin endpoint
        response = await e2e_client.get(
            "/api/v1/context/stats",
            headers={"Authorization": f"Bearer {tokens['access_token']}"},
        )

        # Should be forbidden for non-superuser
        assert response.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPInputValidation:
    """Test input validation for MCP endpoints."""

    async def test_server_name_max_length(self, e2e_client):
        """Test that server name has max length validation."""
        long_name = "a" * 100  # Exceeds 64 char limit
        response = await e2e_client.get(f"/api/v1/mcp/servers/{long_name}/tools")
        assert response.status_code == 422

    async def test_server_name_invalid_characters(self, e2e_client):
        """Test that server name rejects invalid characters."""
        invalid_name = "server@name!invalid"
        response = await e2e_client.get(f"/api/v1/mcp/servers/{invalid_name}/tools")
        assert response.status_code == 422

    async def test_token_count_empty_content(self, e2e_client):
        """Test token counting with empty content."""
        response = await e2e_client.post(
            "/api/v1/context/count-tokens",
            json={"content": ""},
        )

        if response.status_code == 200:
            # Empty content is valid, should return 0 tokens
            assert response.json()["token_count"] == 0
        else:
            # Or it might be rejected as invalid
            assert response.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPWorkflowIntegration:
    """End-to-end coverage of complete MCP workflows."""

    async def test_discovery_to_budget_workflow(self, e2e_client):
        """Discover servers, verify engine health, and fetch a token budget."""
        # Step 1: the MCP server registry lists at least one server.
        discovery = await e2e_client.get("/api/v1/mcp/servers")
        assert discovery.status_code == 200
        assert len(discovery.json()["servers"]) > 0

        # Step 2: the context engine reports itself healthy.
        health_check = await e2e_client.get("/api/v1/context/health")
        assert health_check.status_code == 200
        health = health_check.json()
        assert health["status"] == "healthy"

        # Step 3: a per-model token budget is available.
        budget_check = await e2e_client.get("/api/v1/context/budget/claude-3-sonnet")
        assert budget_check.status_code == 200
        budget = budget_check.json()

        # Step 4: the system is ready for context assembly.
        assert budget["total_tokens"] > 0
        assert health["mcp_connected"] is True

    @pytest.mark.skip(
        reason="Requires MCP servers (llm-gateway, knowledge-base) to be running"
    )
    async def test_full_context_assembly_workflow(self, e2e_client, e2e_superuser):
        """Assemble a full context as superuser and confirm stats are recorded."""
        project_id = f"e2e-project-{uuid4().hex[:8]}"
        agent_id = f"e2e-agent-{uuid4().hex[:8]}"
        auth = {"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"}

        # 1. The budget endpoint must respond before attempting assembly.
        budget_check = await e2e_client.get("/api/v1/context/budget/claude-3-sonnet")
        assert budget_check.status_code == 200
        _ = budget_check.json()  # Verify valid response

        # 2. Token counting works on a sample message.
        count_check = await e2e_client.post(
            "/api/v1/context/count-tokens",
            json={"content": "This is a test message for context assembly."},
        )
        assert count_check.status_code == 200
        assert count_check.json()["token_count"] > 0

        # 3. Assemble a context for the model.
        assembly = await e2e_client.post(
            "/api/v1/context/assemble",
            headers=auth,
            json={
                "project_id": project_id,
                "agent_id": agent_id,
                "query": "Summarize the current project status",
                "model": "claude-3-sonnet",
                "system_prompt": "You are a project management assistant.",
                "task_description": "Generate a status report",
                "conversation_history": [
                    {"role": "user", "content": "What's the project status?"},
                    {
                        "role": "assistant",
                        "content": "Let me check the current status.",
                    },
                ],
                "compress": True,
                "use_cache": False,
            },
        )
        assert assembly.status_code == 200
        assembled = assembly.json()

        # 4. The assembled context is non-empty and stayed within budget.
        assert assembled["total_tokens"] > 0
        assert assembled["context_count"] > 0
        assert 0 < assembled["budget_used_percent"] <= 100

        # 5. The stats endpoint shows the operation was recorded.
        stats_check = await e2e_client.get(
            "/api/v1/context/stats",
            headers=auth,
        )
        assert stats_check.status_code == 200
|
||||||
684
backend/tests/e2e/test_project_workflows.py
Normal file
684
backend/tests/e2e/test_project_workflows.py
Normal file
@@ -0,0 +1,684 @@
|
|||||||
|
"""
|
||||||
|
Project and Agent E2E Workflow Tests.
|
||||||
|
|
||||||
|
Tests complete project management workflows with real PostgreSQL:
|
||||||
|
- Project CRUD and lifecycle management
|
||||||
|
- Agent spawning and lifecycle
|
||||||
|
- Issue management within projects
|
||||||
|
- Sprint planning and execution
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
make test-e2e # Run all E2E tests
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import date, timedelta
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Mark every test in this module as an asyncio E2E test that requires a
# real PostgreSQL instance (collected/filtered via `make test-e2e`).
pytestmark = [
    pytest.mark.e2e,
    pytest.mark.postgres,
    pytest.mark.asyncio,
]
|
||||||
|
|
||||||
|
|
||||||
|
class TestProjectCRUDWorkflows:
    """Test complete project CRUD workflows."""

    async def _register_and_login(self, e2e_client, first_name, last_name, prefix):
        """Register a fresh user with a unique email and return login tokens.

        Collapses the register/login boilerplate that every test in this
        class previously duplicated. Returns the login response JSON
        (contains at least ``access_token``).
        """
        email = f"{prefix}-{uuid4().hex[:8]}@example.com"
        password = "SecurePass123!"
        await e2e_client.post(
            "/api/v1/auth/register",
            json={
                "email": email,
                "password": password,
                "first_name": first_name,
                "last_name": last_name,
            },
        )
        login_resp = await e2e_client.post(
            "/api/v1/auth/login",
            json={"email": email, "password": password},
        )
        return login_resp.json()

    async def test_create_project_workflow(self, e2e_client):
        """Test creating a project as authenticated user."""
        tokens = await self._register_and_login(
            e2e_client, "Project", "Owner", "project-owner"
        )

        # Create project
        project_slug = f"test-project-{uuid4().hex[:8]}"
        create_resp = await e2e_client.post(
            "/api/v1/projects",
            headers={"Authorization": f"Bearer {tokens['access_token']}"},
            json={
                "name": "E2E Test Project",
                "slug": project_slug,
                "description": "A project for E2E testing",
                "autonomy_level": "milestone",
            },
        )

        assert create_resp.status_code == 201, f"Failed: {create_resp.text}"
        project = create_resp.json()
        # A new project starts active with no agents or issues attached.
        assert project["name"] == "E2E Test Project"
        assert project["slug"] == project_slug
        assert project["status"] == "active"
        assert project["agent_count"] == 0
        assert project["issue_count"] == 0

    async def test_list_projects_only_shows_owned(self, e2e_client):
        """Test that users only see their own projects."""
        # Create two users, each owning one project.
        users = []
        for i in range(2):
            tokens = await self._register_and_login(
                e2e_client, f"User{i}", "Test", f"user-{i}"
            )
            project_slug = f"user{i}-project-{uuid4().hex[:8]}"
            await e2e_client.post(
                "/api/v1/projects",
                headers={"Authorization": f"Bearer {tokens['access_token']}"},
                json={
                    "name": f"User {i} Project",
                    "slug": project_slug,
                },
            )
            users.append({"tokens": tokens, "slug": project_slug})

        # User 0 should only see their own project, never user 1's.
        list_resp = await e2e_client.get(
            "/api/v1/projects",
            headers={"Authorization": f"Bearer {users[0]['tokens']['access_token']}"},
        )
        assert list_resp.status_code == 200
        data = list_resp.json()
        slugs = [p["slug"] for p in data["data"]]
        assert users[0]["slug"] in slugs
        assert users[1]["slug"] not in slugs

    async def test_project_lifecycle_pause_resume(self, e2e_client):
        """Test pausing and resuming a project."""
        tokens = await self._register_and_login(
            e2e_client, "Lifecycle", "Test", "lifecycle"
        )
        auth = {"Authorization": f"Bearer {tokens['access_token']}"}

        project_slug = f"lifecycle-project-{uuid4().hex[:8]}"
        create_resp = await e2e_client.post(
            "/api/v1/projects",
            headers=auth,
            json={"name": "Lifecycle Project", "slug": project_slug},
        )
        project_id = create_resp.json()["id"]

        # Pause the project: active -> paused.
        pause_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/pause",
            headers=auth,
        )
        assert pause_resp.status_code == 200
        assert pause_resp.json()["status"] == "paused"

        # Resume the project: paused -> active.
        resume_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/resume",
            headers=auth,
        )
        assert resume_resp.status_code == 200
        assert resume_resp.json()["status"] == "active"

    async def test_project_archive(self, e2e_client):
        """Test archiving a project (soft delete)."""
        tokens = await self._register_and_login(
            e2e_client, "Archive", "Test", "archive"
        )
        auth = {"Authorization": f"Bearer {tokens['access_token']}"}

        project_slug = f"archive-project-{uuid4().hex[:8]}"
        create_resp = await e2e_client.post(
            "/api/v1/projects",
            headers=auth,
            json={"name": "Archive Project", "slug": project_slug},
        )
        project_id = create_resp.json()["id"]

        # Archive (soft-delete) the project.
        archive_resp = await e2e_client.delete(
            f"/api/v1/projects/{project_id}",
            headers=auth,
        )
        assert archive_resp.status_code == 200
        assert archive_resp.json()["success"] is True

        # The project still exists but is now marked archived.
        get_resp = await e2e_client.get(
            f"/api/v1/projects/{project_id}",
            headers=auth,
        )
        assert get_resp.status_code == 200
        assert get_resp.json()["status"] == "archived"
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssueWorkflows:
    """Test issue management workflows within projects."""

    async def test_create_and_list_issues(self, e2e_client):
        """Test creating and listing issues in a project."""
        # Register a fresh user and log in.
        email = f"issue-test-{uuid4().hex[:8]}@example.com"
        password = "SecurePass123!"
        await e2e_client.post(
            "/api/v1/auth/register",
            json={
                "email": email,
                "password": password,
                "first_name": "Issue",
                "last_name": "Tester",
            },
        )
        login_resp = await e2e_client.post(
            "/api/v1/auth/login",
            json={"email": email, "password": password},
        )
        auth = {"Authorization": f"Bearer {login_resp.json()['access_token']}"}

        # Create a project to hold the issues.
        project_slug = f"issue-project-{uuid4().hex[:8]}"
        create_resp = await e2e_client.post(
            "/api/v1/projects",
            headers=auth,
            json={"name": "Issue Test Project", "slug": project_slug},
        )
        project_id = create_resp.json()["id"]

        # Create one issue per priority level.
        created = []
        for idx, priority in enumerate(["low", "medium", "high"], start=1):
            issue_resp = await e2e_client.post(
                f"/api/v1/projects/{project_id}/issues",
                headers=auth,
                json={
                    "project_id": project_id,
                    "title": f"Test Issue {idx}",
                    "body": f"Description for issue {idx}",
                    "priority": priority,
                },
            )
            assert issue_resp.status_code == 201, f"Failed: {issue_resp.text}"
            created.append(issue_resp.json())

        # All three issues come back when listing.
        list_resp = await e2e_client.get(
            f"/api/v1/projects/{project_id}/issues",
            headers=auth,
        )
        assert list_resp.status_code == 200
        assert list_resp.json()["pagination"]["total"] == 3

    async def test_issue_status_transitions(self, e2e_client):
        """Test issue status workflow transitions."""
        # Register a fresh user and log in.
        email = f"status-test-{uuid4().hex[:8]}@example.com"
        password = "SecurePass123!"
        await e2e_client.post(
            "/api/v1/auth/register",
            json={
                "email": email,
                "password": password,
                "first_name": "Status",
                "last_name": "Tester",
            },
        )
        login_resp = await e2e_client.post(
            "/api/v1/auth/login",
            json={"email": email, "password": password},
        )
        auth = {"Authorization": f"Bearer {login_resp.json()['access_token']}"}

        project_slug = f"status-project-{uuid4().hex[:8]}"
        create_resp = await e2e_client.post(
            "/api/v1/projects",
            headers=auth,
            json={"name": "Status Test Project", "slug": project_slug},
        )
        project_id = create_resp.json()["id"]

        # A freshly created issue starts out open.
        issue_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/issues",
            headers=auth,
            json={
                "project_id": project_id,
                "title": "Status Workflow Issue",
                "body": "Testing status transitions",
            },
        )
        issue = issue_resp.json()
        issue_id = issue["id"]
        assert issue["status"] == "open"

        # Walk the issue through each downstream status in order.
        for new_status in ("in_progress", "in_review", "closed"):
            update_resp = await e2e_client.patch(
                f"/api/v1/projects/{project_id}/issues/{issue_id}",
                headers=auth,
                json={"status": new_status},
            )
            assert update_resp.status_code == 200, f"Failed: {update_resp.text}"
            assert update_resp.json()["status"] == new_status

    async def test_issue_filtering(self, e2e_client):
        """Test issue filtering by status and priority."""
        # Register a fresh user and log in.
        email = f"filter-test-{uuid4().hex[:8]}@example.com"
        password = "SecurePass123!"
        await e2e_client.post(
            "/api/v1/auth/register",
            json={
                "email": email,
                "password": password,
                "first_name": "Filter",
                "last_name": "Tester",
            },
        )
        login_resp = await e2e_client.post(
            "/api/v1/auth/login",
            json={"email": email, "password": password},
        )
        auth = {"Authorization": f"Bearer {login_resp.json()['access_token']}"}

        project_slug = f"filter-project-{uuid4().hex[:8]}"
        create_resp = await e2e_client.post(
            "/api/v1/projects",
            headers=auth,
            json={"name": "Filter Test Project", "slug": project_slug},
        )
        project_id = create_resp.json()["id"]

        # One issue per priority so the filter has something to exclude.
        for priority in ("low", "medium", "high"):
            await e2e_client.post(
                f"/api/v1/projects/{project_id}/issues",
                headers=auth,
                json={
                    "project_id": project_id,
                    "title": f"{priority.title()} Priority Issue",
                    "priority": priority,
                },
            )

        # Filtering by high priority returns exactly the one high issue.
        filter_resp = await e2e_client.get(
            f"/api/v1/projects/{project_id}/issues",
            headers=auth,
            params={"priority": "high"},
        )
        assert filter_resp.status_code == 200
        data = filter_resp.json()
        assert data["pagination"]["total"] == 1
        assert data["data"][0]["priority"] == "high"
|
||||||
|
|
||||||
|
|
||||||
|
class TestSprintWorkflows:
    """Test sprint planning and execution workflows."""

    async def test_sprint_lifecycle(self, e2e_client):
        """Test complete sprint lifecycle: plan -> start -> complete."""
        # Register a fresh user and log in.
        email = f"sprint-test-{uuid4().hex[:8]}@example.com"
        password = "SecurePass123!"
        await e2e_client.post(
            "/api/v1/auth/register",
            json={
                "email": email,
                "password": password,
                "first_name": "Sprint",
                "last_name": "Tester",
            },
        )
        login_resp = await e2e_client.post(
            "/api/v1/auth/login",
            json={"email": email, "password": password},
        )
        auth = {"Authorization": f"Bearer {login_resp.json()['access_token']}"}

        project_slug = f"sprint-project-{uuid4().hex[:8]}"
        create_resp = await e2e_client.post(
            "/api/v1/projects",
            headers=auth,
            json={"name": "Sprint Test Project", "slug": project_slug},
        )
        project_id = create_resp.json()["id"]

        # Plan a two-week sprint starting today.
        today = date.today()
        sprint_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/sprints",
            headers=auth,
            json={
                "project_id": project_id,
                "name": "Sprint 1",
                "number": 1,
                "goal": "Complete initial features",
                "start_date": today.isoformat(),
                "end_date": (today + timedelta(days=14)).isoformat(),
            },
        )
        assert sprint_resp.status_code == 201, f"Failed: {sprint_resp.text}"
        sprint = sprint_resp.json()
        sprint_id = sprint["id"]
        assert sprint["status"] == "planned"

        # Start the sprint: planned -> active.
        start_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/sprints/{sprint_id}/start",
            headers=auth,
        )
        assert start_resp.status_code == 200, f"Failed: {start_resp.text}"
        assert start_resp.json()["status"] == "active"

        # Complete the sprint: active -> completed.
        complete_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/sprints/{sprint_id}/complete",
            headers=auth,
        )
        assert complete_resp.status_code == 200, f"Failed: {complete_resp.text}"
        assert complete_resp.json()["status"] == "completed"

    async def test_add_issues_to_sprint(self, e2e_client):
        """Test adding issues to a sprint."""
        # Register a fresh user and log in.
        email = f"sprint-issues-{uuid4().hex[:8]}@example.com"
        password = "SecurePass123!"
        await e2e_client.post(
            "/api/v1/auth/register",
            json={
                "email": email,
                "password": password,
                "first_name": "SprintIssues",
                "last_name": "Tester",
            },
        )
        login_resp = await e2e_client.post(
            "/api/v1/auth/login",
            json={"email": email, "password": password},
        )
        auth = {"Authorization": f"Bearer {login_resp.json()['access_token']}"}

        project_slug = f"sprint-issues-project-{uuid4().hex[:8]}"
        create_resp = await e2e_client.post(
            "/api/v1/projects",
            headers=auth,
            json={"name": "Sprint Issues Project", "slug": project_slug},
        )
        project_id = create_resp.json()["id"]

        # Create an (unstarted) sprint.
        today = date.today()
        sprint_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/sprints",
            headers=auth,
            json={
                "project_id": project_id,
                "name": "Sprint 1",
                "number": 1,
                "start_date": today.isoformat(),
                "end_date": (today + timedelta(days=14)).isoformat(),
            },
        )
        assert sprint_resp.status_code == 201, f"Failed: {sprint_resp.text}"
        sprint_id = sprint_resp.json()["id"]

        # Create an issue to pull into the sprint.
        issue_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/issues",
            headers=auth,
            json={
                "project_id": project_id,
                "title": "Sprint Issue",
                "story_points": 5,
            },
        )
        issue_id = issue_resp.json()["id"]

        # Attach the issue to the sprint.
        add_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/sprints/{sprint_id}/issues",
            headers=auth,
            params={"issue_id": issue_id},
        )
        assert add_resp.status_code == 200, f"Failed: {add_resp.text}"

        # The issue now reports membership in the sprint.
        issue_check = await e2e_client.get(
            f"/api/v1/projects/{project_id}/issues/{issue_id}",
            headers=auth,
        )
        assert issue_check.json()["sprint_id"] == sprint_id
|
||||||
|
|
||||||
|
|
||||||
|
class TestCrossEntityValidation:
    """Test validation across related entities."""

    async def test_cannot_access_other_users_project(self, e2e_client):
        """Test that users cannot access projects they don't own."""
        password = "SecurePass123!"
        owner_email = f"owner-{uuid4().hex[:8]}@example.com"
        other_email = f"other-{uuid4().hex[:8]}@example.com"

        # Register the project owner and log them in.
        await e2e_client.post(
            "/api/v1/auth/register",
            json={
                "email": owner_email,
                "password": password,
                "first_name": "Owner",
                "last_name": "User",
            },
        )
        owner_login = await e2e_client.post(
            "/api/v1/auth/login",
            json={"email": owner_email, "password": password},
        )
        owner_tokens = owner_login.json()

        # Register a second, unrelated user.
        await e2e_client.post(
            "/api/v1/auth/register",
            json={
                "email": other_email,
                "password": password,
                "first_name": "Other",
                "last_name": "User",
            },
        )
        other_login = await e2e_client.post(
            "/api/v1/auth/login",
            json={"email": other_email, "password": password},
        )
        other_tokens = other_login.json()

        # The owner creates a project.
        project_slug = f"private-project-{uuid4().hex[:8]}"
        create_resp = await e2e_client.post(
            "/api/v1/projects",
            headers={"Authorization": f"Bearer {owner_tokens['access_token']}"},
            json={"name": "Private Project", "slug": project_slug},
        )
        project_id = create_resp.json()["id"]

        # The unrelated user is forbidden from reading it.
        access_resp = await e2e_client.get(
            f"/api/v1/projects/{project_id}",
            headers={"Authorization": f"Bearer {other_tokens['access_token']}"},
        )
        assert access_resp.status_code == 403

    async def test_duplicate_project_slug_rejected(self, e2e_client):
        """Test that duplicate project slugs are rejected."""
        email = f"dup-test-{uuid4().hex[:8]}@example.com"
        password = "SecurePass123!"
        await e2e_client.post(
            "/api/v1/auth/register",
            json={
                "email": email,
                "password": password,
                "first_name": "Dup",
                "last_name": "Tester",
            },
        )
        login_resp = await e2e_client.post(
            "/api/v1/auth/login",
            json={"email": email, "password": password},
        )
        auth = {"Authorization": f"Bearer {login_resp.json()['access_token']}"}

        slug = f"unique-slug-{uuid4().hex[:8]}"

        # The first project with this slug is accepted.
        first = await e2e_client.post(
            "/api/v1/projects",
            headers=auth,
            json={"name": "First Project", "slug": slug},
        )
        assert first.status_code == 201

        # Reusing the slug conflicts with the existing project.
        second = await e2e_client.post(
            "/api/v1/projects",
            headers=auth,
            json={"name": "Second Project", "slug": slug},
        )
        assert second.status_code == 409  # Conflict
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssueStats:
    """Test issue statistics endpoints."""

    async def test_issue_stats_aggregation(self, e2e_client):
        """Test that issue stats are correctly aggregated."""
        # Register a fresh user and log in.
        email = f"stats-test-{uuid4().hex[:8]}@example.com"
        password = "SecurePass123!"
        await e2e_client.post(
            "/api/v1/auth/register",
            json={
                "email": email,
                "password": password,
                "first_name": "Stats",
                "last_name": "Tester",
            },
        )
        login_resp = await e2e_client.post(
            "/api/v1/auth/login",
            json={"email": email, "password": password},
        )
        auth = {"Authorization": f"Bearer {login_resp.json()['access_token']}"}

        project_slug = f"stats-project-{uuid4().hex[:8]}"
        create_resp = await e2e_client.post(
            "/api/v1/projects",
            headers=auth,
            json={"name": "Stats Project", "slug": project_slug},
        )
        project_id = create_resp.json()["id"]

        # Two issues whose priorities and story points the stats must roll up.
        for title, priority, points in (
            ("High Priority", "high", 8),
            ("Low Priority", "low", 2),
        ):
            await e2e_client.post(
                f"/api/v1/projects/{project_id}/issues",
                headers=auth,
                json={
                    "project_id": project_id,
                    "title": title,
                    "priority": priority,
                    "story_points": points,
                },
            )

        # Stats aggregate both the issue count and total story points.
        stats_resp = await e2e_client.get(
            f"/api/v1/projects/{project_id}/issues/stats",
            headers=auth,
        )
        assert stats_resp.status_code == 200
        stats = stats_resp.json()
        assert stats["total"] == 2
        assert stats["total_story_points"] == 10
|
||||||
1
backend/tests/integration/__init__.py
Normal file
1
backend/tests/integration/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Integration tests that require the full stack to be running."""
|
||||||
322
backend/tests/integration/test_mcp_integration.py
Normal file
322
backend/tests/integration/test_mcp_integration.py
Normal file
@@ -0,0 +1,322 @@
|
|||||||
|
"""
|
||||||
|
Integration tests for MCP server connectivity.
|
||||||
|
|
||||||
|
These tests require the full stack to be running:
|
||||||
|
- docker compose -f docker-compose.dev.yml up
|
||||||
|
|
||||||
|
Run with:
|
||||||
|
pytest tests/integration/ -v --integration
|
||||||
|
|
||||||
|
Or skip with:
|
||||||
|
pytest tests/ -v --ignore=tests/integration/
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Skip all tests in this module if not running integration tests.
# Opt in explicitly with RUN_INTEGRATION_TESTS=true (any casing) against
# a running stack (docker compose -f docker-compose.dev.yml up).
pytestmark = pytest.mark.skipif(
    os.getenv("RUN_INTEGRATION_TESTS", "false").lower() != "true",
    reason="Integration tests require RUN_INTEGRATION_TESTS=true and running stack",
)


# Service base URLs; overridable via environment for non-local stacks.
BACKEND_URL = os.getenv("BACKEND_URL", "http://localhost:8000")
LLM_GATEWAY_URL = os.getenv("LLM_GATEWAY_URL", "http://localhost:8001")
KNOWLEDGE_BASE_URL = os.getenv("KNOWLEDGE_BASE_URL", "http://localhost:8002")
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPServerHealth:
    """Test that MCP servers are healthy and reachable."""

    @pytest.mark.asyncio
    async def test_llm_gateway_health(self) -> None:
        """Test LLM Gateway health endpoint."""
        async with httpx.AsyncClient() as http:
            resp = await http.get(f"{LLM_GATEWAY_URL}/health", timeout=10.0)
            assert resp.status_code == 200
            body = resp.json()
            # Health schemas vary: {"status": "healthy"} or {"healthy": true}.
            assert body.get("status") == "healthy" or body.get("healthy") is True

    @pytest.mark.asyncio
    async def test_knowledge_base_health(self) -> None:
        """Test Knowledge Base health endpoint."""
        async with httpx.AsyncClient() as http:
            resp = await http.get(f"{KNOWLEDGE_BASE_URL}/health", timeout=10.0)
            assert resp.status_code == 200
            body = resp.json()
            # Same schema tolerance as the gateway check above.
            assert body.get("status") == "healthy" or body.get("healthy") is True

    @pytest.mark.asyncio
    async def test_backend_health(self) -> None:
        """Test Backend health endpoint."""
        async with httpx.AsyncClient() as http:
            resp = await http.get(f"{BACKEND_URL}/health", timeout=10.0)
            assert resp.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPClientManagerIntegration:
    """Test MCPClientManager can connect to real MCP servers."""

    @pytest.mark.asyncio
    async def test_mcp_servers_list(self) -> None:
        """Test that backend can list MCP servers via API."""
        async with httpx.AsyncClient() as http:
            # This endpoint lists configured MCP servers.
            resp = await http.get(
                f"{BACKEND_URL}/api/v1/mcp/servers",
                timeout=10.0,
            )
            # 200 when the route is open, 401/403 when it is auth-guarded.
            assert resp.status_code in (200, 401, 403)

    @pytest.mark.asyncio
    async def test_mcp_health_check_endpoint(self) -> None:
        """Test backend's MCP health check endpoint."""
        async with httpx.AsyncClient() as http:
            resp = await http.get(
                f"{BACKEND_URL}/api/v1/mcp/health",
                timeout=30.0,  # MCP health checks can take time
            )
            # Only inspect the body when the route is publicly readable
            # (it may respond 401 if auth is required).
            if resp.status_code == 200:
                payload = resp.json()
                assert "servers" in payload or "healthy" in payload
|
||||||
|
|
||||||
|
|
||||||
|
class TestLLMGatewayIntegration:
|
||||||
|
"""Test LLM Gateway MCP server functionality."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_models(self) -> None:
|
||||||
|
"""Test that LLM Gateway can list available models."""
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
# MCP servers use JSON-RPC 2.0 protocol at /mcp endpoint
|
||||||
|
response = await client.post(
|
||||||
|
f"{LLM_GATEWAY_URL}/mcp",
|
||||||
|
json={
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": 1,
|
||||||
|
"method": "tools/list",
|
||||||
|
"params": {},
|
||||||
|
},
|
||||||
|
timeout=10.0,
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
# Should have tools listed
|
||||||
|
assert "result" in data or "error" in data
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_count_tokens(self) -> None:
|
||||||
|
"""Test token counting functionality."""
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.post(
|
||||||
|
f"{LLM_GATEWAY_URL}/mcp",
|
||||||
|
json={
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": 1,
|
||||||
|
"method": "tools/call",
|
||||||
|
"params": {
|
||||||
|
"name": "count_tokens",
|
||||||
|
"arguments": {
|
||||||
|
"project_id": "test-project",
|
||||||
|
"agent_id": "test-agent",
|
||||||
|
"text": "Hello, world!",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
timeout=10.0,
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
# Check for result or error
|
||||||
|
if "result" in data:
|
||||||
|
assert "content" in data["result"] or "token_count" in str(
|
||||||
|
data["result"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestKnowledgeBaseIntegration:
|
||||||
|
"""Test Knowledge Base MCP server functionality."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_tools(self) -> None:
|
||||||
|
"""Test that Knowledge Base can list available tools."""
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
# Knowledge Base uses GET /mcp/tools for listing
|
||||||
|
response = await client.get(
|
||||||
|
f"{KNOWLEDGE_BASE_URL}/mcp/tools",
|
||||||
|
timeout=10.0,
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "tools" in data
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_knowledge_empty(self) -> None:
|
||||||
|
"""Test search on empty knowledge base."""
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
# Knowledge Base uses direct tool name as method
|
||||||
|
response = await client.post(
|
||||||
|
f"{KNOWLEDGE_BASE_URL}/mcp",
|
||||||
|
json={
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": 1,
|
||||||
|
"method": "search_knowledge",
|
||||||
|
"params": {
|
||||||
|
"project_id": "test-project",
|
||||||
|
"agent_id": "test-agent",
|
||||||
|
"query": "test query",
|
||||||
|
"limit": 5,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
timeout=10.0,
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
# Should return empty results or error for no collection
|
||||||
|
assert "result" in data or "error" in data
|
||||||
|
|
||||||
|
|
||||||
|
class TestEndToEndMCPFlow:
|
||||||
|
"""End-to-end tests for MCP integration flow."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_full_mcp_discovery_flow(self) -> None:
|
||||||
|
"""Test the full flow of discovering and listing MCP tools."""
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
# 1. Check backend health
|
||||||
|
health = await client.get(f"{BACKEND_URL}/health", timeout=10.0)
|
||||||
|
assert health.status_code == 200
|
||||||
|
|
||||||
|
# 2. Check LLM Gateway health
|
||||||
|
llm_health = await client.get(f"{LLM_GATEWAY_URL}/health", timeout=10.0)
|
||||||
|
assert llm_health.status_code == 200
|
||||||
|
|
||||||
|
# 3. Check Knowledge Base health
|
||||||
|
kb_health = await client.get(f"{KNOWLEDGE_BASE_URL}/health", timeout=10.0)
|
||||||
|
assert kb_health.status_code == 200
|
||||||
|
|
||||||
|
# 4. List tools from LLM Gateway (uses JSON-RPC at /mcp)
|
||||||
|
llm_tools = await client.post(
|
||||||
|
f"{LLM_GATEWAY_URL}/mcp",
|
||||||
|
json={"jsonrpc": "2.0", "id": 1, "method": "tools/list", "params": {}},
|
||||||
|
timeout=10.0,
|
||||||
|
)
|
||||||
|
assert llm_tools.status_code == 200
|
||||||
|
|
||||||
|
# 5. List tools from Knowledge Base (uses GET /mcp/tools)
|
||||||
|
kb_tools = await client.get(
|
||||||
|
f"{KNOWLEDGE_BASE_URL}/mcp/tools",
|
||||||
|
timeout=10.0,
|
||||||
|
)
|
||||||
|
assert kb_tools.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextEngineIntegration:
|
||||||
|
"""Test Context Engine integration with MCP servers."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context_health_endpoint(self) -> None:
|
||||||
|
"""Test context engine health endpoint."""
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(
|
||||||
|
f"{BACKEND_URL}/api/v1/context/health",
|
||||||
|
timeout=10.0,
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data.get("status") == "healthy"
|
||||||
|
assert "mcp_connected" in data
|
||||||
|
assert "cache_enabled" in data
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context_budget_endpoint(self) -> None:
|
||||||
|
"""Test token budget endpoint."""
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(
|
||||||
|
f"{BACKEND_URL}/api/v1/context/budget/claude-3-sonnet",
|
||||||
|
timeout=10.0,
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "total_tokens" in data
|
||||||
|
assert "system_tokens" in data
|
||||||
|
assert data.get("model") == "claude-3-sonnet"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context_assembly_requires_auth(self) -> None:
|
||||||
|
"""Test that context assembly requires authentication."""
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.post(
|
||||||
|
f"{BACKEND_URL}/api/v1/context/assemble",
|
||||||
|
json={
|
||||||
|
"project_id": "test-project",
|
||||||
|
"agent_id": "test-agent",
|
||||||
|
"query": "test query",
|
||||||
|
"model": "claude-3-sonnet",
|
||||||
|
},
|
||||||
|
timeout=10.0,
|
||||||
|
)
|
||||||
|
# Should require auth
|
||||||
|
assert response.status_code in [401, 403]
|
||||||
|
|
||||||
|
|
||||||
|
def run_quick_health_check() -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Quick synchronous health check for all services.
|
||||||
|
Can be run standalone to verify the stack is up.
|
||||||
|
"""
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
results: dict[str, Any] = {
|
||||||
|
"backend": False,
|
||||||
|
"llm_gateway": False,
|
||||||
|
"knowledge_base": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
with httpx.Client(timeout=5.0) as client:
|
||||||
|
try:
|
||||||
|
r = client.get(f"{BACKEND_URL}/health")
|
||||||
|
results["backend"] = r.status_code == 200
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
r = client.get(f"{LLM_GATEWAY_URL}/health")
|
||||||
|
results["llm_gateway"] = r.status_code == 200
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
r = client.get(f"{KNOWLEDGE_BASE_URL}/health")
|
||||||
|
results["knowledge_base"] = r.status_code == 200
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("Checking service health...")
|
||||||
|
results = run_quick_health_check()
|
||||||
|
for service, healthy in results.items():
|
||||||
|
status = "OK" if healthy else "FAILED"
|
||||||
|
print(f" {service}: {status}")
|
||||||
|
|
||||||
|
all_healthy = all(results.values())
|
||||||
|
if all_healthy:
|
||||||
|
print("\nAll services healthy! Run integration tests with:")
|
||||||
|
print(" RUN_INTEGRATION_TESTS=true pytest tests/integration/ -v")
|
||||||
|
else:
|
||||||
|
print("\nSome services are not healthy. Start the stack with:")
|
||||||
|
print(" make dev")
|
||||||
@@ -72,7 +72,7 @@ class TestContextSettings:
|
|||||||
"""Test performance settings."""
|
"""Test performance settings."""
|
||||||
settings = ContextSettings()
|
settings = ContextSettings()
|
||||||
|
|
||||||
assert settings.max_assembly_time_ms == 100
|
assert settings.max_assembly_time_ms == 2000
|
||||||
assert settings.parallel_scoring is True
|
assert settings.parallel_scoring is True
|
||||||
assert settings.max_parallel_scores == 10
|
assert settings.max_parallel_scores == 10
|
||||||
|
|
||||||
|
|||||||
@@ -758,3 +758,136 @@ class TestBaseScorer:
|
|||||||
# Boundaries
|
# Boundaries
|
||||||
assert scorer.normalize_score(0.0) == 0.0
|
assert scorer.normalize_score(0.0) == 0.0
|
||||||
assert scorer.normalize_score(1.0) == 1.0
|
assert scorer.normalize_score(1.0) == 1.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompositeScorerEdgeCases:
|
||||||
|
"""Tests for CompositeScorer edge cases and lock management."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_score_with_zero_weights(self) -> None:
|
||||||
|
"""Test scoring when all weights are zero."""
|
||||||
|
scorer = CompositeScorer(
|
||||||
|
relevance_weight=0.0,
|
||||||
|
recency_weight=0.0,
|
||||||
|
priority_weight=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
context = KnowledgeContext(
|
||||||
|
content="Test content",
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.8,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should return 0.0 when total weight is 0
|
||||||
|
score = await scorer.score(context, "test query")
|
||||||
|
assert score == 0.0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_score_batch_sequential(self) -> None:
|
||||||
|
"""Test batch scoring in sequential mode (parallel=False)."""
|
||||||
|
scorer = CompositeScorer()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
KnowledgeContext(
|
||||||
|
content="Content 1",
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.8,
|
||||||
|
),
|
||||||
|
KnowledgeContext(
|
||||||
|
content="Content 2",
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.5,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Use parallel=False to cover the sequential path
|
||||||
|
scored = await scorer.score_batch(contexts, "query", parallel=False)
|
||||||
|
|
||||||
|
assert len(scored) == 2
|
||||||
|
assert scored[0].relevance_score == 0.8
|
||||||
|
assert scored[1].relevance_score == 0.5
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_lock_fast_path_reuse(self) -> None:
|
||||||
|
"""Test that existing locks are reused via fast path."""
|
||||||
|
scorer = CompositeScorer()
|
||||||
|
|
||||||
|
context = KnowledgeContext(
|
||||||
|
content="Test",
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
# First access creates the lock
|
||||||
|
lock1 = await scorer._get_context_lock(context.id)
|
||||||
|
|
||||||
|
# Second access should hit the fast path (lock exists in dict)
|
||||||
|
lock2 = await scorer._get_context_lock(context.id)
|
||||||
|
|
||||||
|
assert lock2 is lock1 # Same lock object returned
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_lock_cleanup_when_limit_reached(self) -> None:
|
||||||
|
"""Test that old locks are cleaned up when limit is reached."""
|
||||||
|
import time
|
||||||
|
|
||||||
|
# Create scorer with very low max_locks to trigger cleanup
|
||||||
|
scorer = CompositeScorer()
|
||||||
|
scorer._max_locks = 3
|
||||||
|
scorer._lock_ttl = 0.1 # 100ms TTL
|
||||||
|
|
||||||
|
# Create locks for several context IDs
|
||||||
|
context_ids = [f"ctx-{i}" for i in range(5)]
|
||||||
|
|
||||||
|
# Get locks for first 3 contexts (fill up to limit)
|
||||||
|
for ctx_id in context_ids[:3]:
|
||||||
|
await scorer._get_context_lock(ctx_id)
|
||||||
|
|
||||||
|
# Wait for TTL to expire
|
||||||
|
time.sleep(0.15)
|
||||||
|
|
||||||
|
# Getting a lock for a new context should trigger cleanup
|
||||||
|
await scorer._get_context_lock(context_ids[3])
|
||||||
|
|
||||||
|
# Some old locks should have been cleaned up
|
||||||
|
# The exact number depends on cleanup logic
|
||||||
|
assert len(scorer._context_locks) <= scorer._max_locks + 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_lock_cleanup_preserves_held_locks(self) -> None:
|
||||||
|
"""Test that cleanup doesn't remove locks that are currently held."""
|
||||||
|
import time
|
||||||
|
|
||||||
|
scorer = CompositeScorer()
|
||||||
|
scorer._max_locks = 2
|
||||||
|
scorer._lock_ttl = 0.05 # 50ms TTL
|
||||||
|
|
||||||
|
# Get and hold lock1
|
||||||
|
lock1 = await scorer._get_context_lock("ctx-1")
|
||||||
|
async with lock1:
|
||||||
|
# While holding lock1, add more locks
|
||||||
|
await scorer._get_context_lock("ctx-2")
|
||||||
|
time.sleep(0.1) # Let TTL expire
|
||||||
|
# Adding another should trigger cleanup
|
||||||
|
await scorer._get_context_lock("ctx-3")
|
||||||
|
|
||||||
|
# lock1 should still exist (it's held)
|
||||||
|
assert any(lock is lock1 for lock, _ in scorer._context_locks.values())
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_lock_acquisition_double_check(self) -> None:
|
||||||
|
"""Test that concurrent lock acquisition uses double-check pattern."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
scorer = CompositeScorer()
|
||||||
|
|
||||||
|
context_id = "test-context-id"
|
||||||
|
|
||||||
|
# Simulate concurrent lock acquisition
|
||||||
|
async def get_lock():
|
||||||
|
return await scorer._get_context_lock(context_id)
|
||||||
|
|
||||||
|
locks = await asyncio.gather(*[get_lock() for _ in range(10)])
|
||||||
|
|
||||||
|
# All should get the same lock (double-check pattern ensures this)
|
||||||
|
assert all(lock is locks[0] for lock in locks)
|
||||||
|
|||||||
989
backend/tests/services/safety/test_audit.py
Normal file
989
backend/tests/services/safety/test_audit.py
Normal file
@@ -0,0 +1,989 @@
|
|||||||
|
"""
|
||||||
|
Tests for Audit Logger.
|
||||||
|
|
||||||
|
Tests cover:
|
||||||
|
- AuditLogger initialization and lifecycle
|
||||||
|
- Event logging and hash chain
|
||||||
|
- Query and filtering
|
||||||
|
- Retention policy enforcement
|
||||||
|
- Handler management
|
||||||
|
- Singleton pattern
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from app.services.safety.audit.logger import (
|
||||||
|
AuditLogger,
|
||||||
|
get_audit_logger,
|
||||||
|
reset_audit_logger,
|
||||||
|
shutdown_audit_logger,
|
||||||
|
)
|
||||||
|
from app.services.safety.models import (
|
||||||
|
ActionMetadata,
|
||||||
|
ActionRequest,
|
||||||
|
ActionType,
|
||||||
|
AuditEventType,
|
||||||
|
AutonomyLevel,
|
||||||
|
SafetyDecision,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuditLoggerInit:
|
||||||
|
"""Tests for AuditLogger initialization."""
|
||||||
|
|
||||||
|
def test_init_default_values(self):
|
||||||
|
"""Test initialization with default values."""
|
||||||
|
logger = AuditLogger()
|
||||||
|
|
||||||
|
assert logger._flush_interval == 10.0
|
||||||
|
assert logger._enable_hash_chain is True
|
||||||
|
assert logger._last_hash is None
|
||||||
|
assert logger._running is False
|
||||||
|
|
||||||
|
def test_init_custom_values(self):
|
||||||
|
"""Test initialization with custom values."""
|
||||||
|
logger = AuditLogger(
|
||||||
|
max_buffer_size=500,
|
||||||
|
flush_interval_seconds=5.0,
|
||||||
|
enable_hash_chain=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert logger._flush_interval == 5.0
|
||||||
|
assert logger._enable_hash_chain is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuditLoggerLifecycle:
|
||||||
|
"""Tests for AuditLogger start/stop."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_creates_flush_task(self):
|
||||||
|
"""Test that start creates the periodic flush task."""
|
||||||
|
logger = AuditLogger(flush_interval_seconds=1.0)
|
||||||
|
|
||||||
|
await logger.start()
|
||||||
|
|
||||||
|
assert logger._running is True
|
||||||
|
assert logger._flush_task is not None
|
||||||
|
|
||||||
|
await logger.stop()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_idempotent(self):
|
||||||
|
"""Test that multiple starts don't create multiple tasks."""
|
||||||
|
logger = AuditLogger()
|
||||||
|
|
||||||
|
await logger.start()
|
||||||
|
task1 = logger._flush_task
|
||||||
|
|
||||||
|
await logger.start() # Second start
|
||||||
|
task2 = logger._flush_task
|
||||||
|
|
||||||
|
assert task1 is task2
|
||||||
|
|
||||||
|
await logger.stop()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stop_cancels_task_and_flushes(self):
|
||||||
|
"""Test that stop cancels the task and flushes events."""
|
||||||
|
logger = AuditLogger()
|
||||||
|
|
||||||
|
await logger.start()
|
||||||
|
|
||||||
|
# Add an event
|
||||||
|
await logger.log(AuditEventType.ACTION_REQUESTED, agent_id="agent-1")
|
||||||
|
|
||||||
|
await logger.stop()
|
||||||
|
|
||||||
|
assert logger._running is False
|
||||||
|
# Event should be flushed
|
||||||
|
assert len(logger._persisted) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stop_without_start(self):
|
||||||
|
"""Test stopping without starting doesn't error."""
|
||||||
|
logger = AuditLogger()
|
||||||
|
await logger.stop() # Should not raise
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuditLoggerLog:
|
||||||
|
"""Tests for the log method."""
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def logger(self):
|
||||||
|
"""Create a logger instance."""
|
||||||
|
return AuditLogger(enable_hash_chain=True)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_creates_event(self, logger):
|
||||||
|
"""Test logging creates an event."""
|
||||||
|
event = await logger.log(
|
||||||
|
AuditEventType.ACTION_REQUESTED,
|
||||||
|
agent_id="agent-1",
|
||||||
|
project_id="proj-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.event_type == AuditEventType.ACTION_REQUESTED
|
||||||
|
assert event.agent_id == "agent-1"
|
||||||
|
assert event.project_id == "proj-1"
|
||||||
|
assert event.id is not None
|
||||||
|
assert event.timestamp is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_adds_hash_chain(self, logger):
|
||||||
|
"""Test logging adds hash chain."""
|
||||||
|
event = await logger.log(
|
||||||
|
AuditEventType.ACTION_REQUESTED,
|
||||||
|
agent_id="agent-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "_hash" in event.details
|
||||||
|
assert "_prev_hash" in event.details
|
||||||
|
assert event.details["_prev_hash"] is None # First event
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_chain_links_events(self, logger):
|
||||||
|
"""Test hash chain links events."""
|
||||||
|
event1 = await logger.log(AuditEventType.ACTION_REQUESTED)
|
||||||
|
event2 = await logger.log(AuditEventType.ACTION_EXECUTED)
|
||||||
|
|
||||||
|
assert event2.details["_prev_hash"] == event1.details["_hash"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_without_hash_chain(self):
|
||||||
|
"""Test logging without hash chain."""
|
||||||
|
logger = AuditLogger(enable_hash_chain=False)
|
||||||
|
|
||||||
|
event = await logger.log(AuditEventType.ACTION_REQUESTED)
|
||||||
|
|
||||||
|
assert "_hash" not in event.details
|
||||||
|
assert "_prev_hash" not in event.details
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_with_all_fields(self, logger):
|
||||||
|
"""Test logging with all optional fields."""
|
||||||
|
event = await logger.log(
|
||||||
|
AuditEventType.ACTION_EXECUTED,
|
||||||
|
agent_id="agent-1",
|
||||||
|
action_id="action-1",
|
||||||
|
project_id="proj-1",
|
||||||
|
session_id="sess-1",
|
||||||
|
user_id="user-1",
|
||||||
|
decision=SafetyDecision.ALLOW,
|
||||||
|
details={"custom": "data"},
|
||||||
|
correlation_id="corr-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.agent_id == "agent-1"
|
||||||
|
assert event.action_id == "action-1"
|
||||||
|
assert event.project_id == "proj-1"
|
||||||
|
assert event.session_id == "sess-1"
|
||||||
|
assert event.user_id == "user-1"
|
||||||
|
assert event.decision == SafetyDecision.ALLOW
|
||||||
|
assert event.details["custom"] == "data"
|
||||||
|
assert event.correlation_id == "corr-1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_buffers_event(self, logger):
|
||||||
|
"""Test logging adds event to buffer."""
|
||||||
|
await logger.log(AuditEventType.ACTION_REQUESTED)
|
||||||
|
|
||||||
|
assert len(logger._buffer) == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuditLoggerConvenienceMethods:
|
||||||
|
"""Tests for convenience logging methods."""
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def logger(self):
|
||||||
|
"""Create a logger instance."""
|
||||||
|
return AuditLogger(enable_hash_chain=False)
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
def action(self):
|
||||||
|
"""Create a test action request."""
|
||||||
|
metadata = ActionMetadata(
|
||||||
|
agent_id="agent-1",
|
||||||
|
session_id="sess-1",
|
||||||
|
project_id="proj-1",
|
||||||
|
autonomy_level=AutonomyLevel.MILESTONE,
|
||||||
|
user_id="user-1",
|
||||||
|
correlation_id="corr-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
return ActionRequest(
|
||||||
|
action_type=ActionType.FILE_WRITE,
|
||||||
|
tool_name="file_write",
|
||||||
|
arguments={"path": "/test.txt"},
|
||||||
|
resource="/test.txt",
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_action_request_allowed(self, logger, action):
|
||||||
|
"""Test logging allowed action request."""
|
||||||
|
event = await logger.log_action_request(
|
||||||
|
action,
|
||||||
|
SafetyDecision.ALLOW,
|
||||||
|
reasons=["Within budget"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.event_type == AuditEventType.ACTION_VALIDATED
|
||||||
|
assert event.decision == SafetyDecision.ALLOW
|
||||||
|
assert event.details["reasons"] == ["Within budget"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_action_request_denied(self, logger, action):
|
||||||
|
"""Test logging denied action request."""
|
||||||
|
event = await logger.log_action_request(
|
||||||
|
action,
|
||||||
|
SafetyDecision.DENY,
|
||||||
|
reasons=["Rate limit exceeded"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.event_type == AuditEventType.ACTION_DENIED
|
||||||
|
assert event.decision == SafetyDecision.DENY
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_action_executed_success(self, logger, action):
|
||||||
|
"""Test logging successful action execution."""
|
||||||
|
event = await logger.log_action_executed(
|
||||||
|
action,
|
||||||
|
success=True,
|
||||||
|
execution_time_ms=50.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.event_type == AuditEventType.ACTION_EXECUTED
|
||||||
|
assert event.details["success"] is True
|
||||||
|
assert event.details["execution_time_ms"] == 50.0
|
||||||
|
assert event.details["error"] is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_action_executed_failure(self, logger, action):
|
||||||
|
"""Test logging failed action execution."""
|
||||||
|
event = await logger.log_action_executed(
|
||||||
|
action,
|
||||||
|
success=False,
|
||||||
|
execution_time_ms=100.0,
|
||||||
|
error="File not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.event_type == AuditEventType.ACTION_FAILED
|
||||||
|
assert event.details["success"] is False
|
||||||
|
assert event.details["error"] == "File not found"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_approval_event(self, logger, action):
|
||||||
|
"""Test logging approval event."""
|
||||||
|
event = await logger.log_approval_event(
|
||||||
|
AuditEventType.APPROVAL_GRANTED,
|
||||||
|
approval_id="approval-1",
|
||||||
|
action=action,
|
||||||
|
decided_by="admin",
|
||||||
|
reason="Approved by admin",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.event_type == AuditEventType.APPROVAL_GRANTED
|
||||||
|
assert event.details["approval_id"] == "approval-1"
|
||||||
|
assert event.details["decided_by"] == "admin"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_budget_event(self, logger):
|
||||||
|
"""Test logging budget event."""
|
||||||
|
event = await logger.log_budget_event(
|
||||||
|
AuditEventType.BUDGET_WARNING,
|
||||||
|
agent_id="agent-1",
|
||||||
|
scope="daily",
|
||||||
|
current_usage=8000.0,
|
||||||
|
limit=10000.0,
|
||||||
|
unit="tokens",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.event_type == AuditEventType.BUDGET_WARNING
|
||||||
|
assert event.details["scope"] == "daily"
|
||||||
|
assert event.details["usage_percent"] == 80.0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_budget_event_zero_limit(self, logger):
|
||||||
|
"""Test logging budget event with zero limit."""
|
||||||
|
event = await logger.log_budget_event(
|
||||||
|
AuditEventType.BUDGET_WARNING,
|
||||||
|
agent_id="agent-1",
|
||||||
|
scope="daily",
|
||||||
|
current_usage=100.0,
|
||||||
|
limit=0.0, # Zero limit
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.details["usage_percent"] == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_emergency_stop(self, logger):
|
||||||
|
"""Test logging emergency stop."""
|
||||||
|
event = await logger.log_emergency_stop(
|
||||||
|
stop_type="global",
|
||||||
|
triggered_by="admin",
|
||||||
|
reason="Security incident",
|
||||||
|
affected_agents=["agent-1", "agent-2"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.event_type == AuditEventType.EMERGENCY_STOP
|
||||||
|
assert event.details["stop_type"] == "global"
|
||||||
|
assert event.details["affected_agents"] == ["agent-1", "agent-2"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuditLoggerFlush:
|
||||||
|
"""Tests for flush functionality."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_flush_persists_events(self):
|
||||||
|
"""Test flush moves events to persisted storage."""
|
||||||
|
logger = AuditLogger(enable_hash_chain=False)
|
||||||
|
|
||||||
|
await logger.log(AuditEventType.ACTION_REQUESTED)
|
||||||
|
await logger.log(AuditEventType.ACTION_EXECUTED)
|
||||||
|
|
||||||
|
assert len(logger._buffer) == 2
|
||||||
|
assert len(logger._persisted) == 0
|
||||||
|
|
||||||
|
count = await logger.flush()
|
||||||
|
|
||||||
|
assert count == 2
|
||||||
|
assert len(logger._buffer) == 0
|
||||||
|
assert len(logger._persisted) == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_flush_empty_buffer(self):
|
||||||
|
"""Test flush with empty buffer."""
|
||||||
|
logger = AuditLogger()
|
||||||
|
|
||||||
|
count = await logger.flush()
|
||||||
|
|
||||||
|
assert count == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuditLoggerQuery:
|
||||||
|
"""Tests for query functionality."""
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def logger_with_events(self):
|
||||||
|
"""Create a logger with some test events."""
|
||||||
|
logger = AuditLogger(enable_hash_chain=False)
|
||||||
|
|
||||||
|
# Add various events
|
||||||
|
await logger.log(
|
||||||
|
AuditEventType.ACTION_REQUESTED,
|
||||||
|
agent_id="agent-1",
|
||||||
|
project_id="proj-1",
|
||||||
|
)
|
||||||
|
await logger.log(
|
||||||
|
AuditEventType.ACTION_EXECUTED,
|
||||||
|
agent_id="agent-1",
|
||||||
|
project_id="proj-1",
|
||||||
|
)
|
||||||
|
await logger.log(
|
||||||
|
AuditEventType.ACTION_DENIED,
|
||||||
|
agent_id="agent-2",
|
||||||
|
project_id="proj-2",
|
||||||
|
)
|
||||||
|
await logger.log(
|
||||||
|
AuditEventType.BUDGET_WARNING,
|
||||||
|
agent_id="agent-1",
|
||||||
|
project_id="proj-1",
|
||||||
|
user_id="user-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
return logger
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_all(self, logger_with_events):
|
||||||
|
"""Test querying all events."""
|
||||||
|
events = await logger_with_events.query()
|
||||||
|
|
||||||
|
assert len(events) == 4
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_by_event_type(self, logger_with_events):
|
||||||
|
"""Test filtering by event type."""
|
||||||
|
events = await logger_with_events.query(
|
||||||
|
event_types=[AuditEventType.ACTION_REQUESTED]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(events) == 1
|
||||||
|
assert events[0].event_type == AuditEventType.ACTION_REQUESTED
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_by_agent_id(self, logger_with_events):
|
||||||
|
"""Test filtering by agent ID."""
|
||||||
|
events = await logger_with_events.query(agent_id="agent-1")
|
||||||
|
|
||||||
|
assert len(events) == 3
|
||||||
|
assert all(e.agent_id == "agent-1" for e in events)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_by_project_id(self, logger_with_events):
|
||||||
|
"""Test filtering by project ID."""
|
||||||
|
events = await logger_with_events.query(project_id="proj-2")
|
||||||
|
|
||||||
|
assert len(events) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_by_user_id(self, logger_with_events):
|
||||||
|
"""Test filtering by user ID."""
|
||||||
|
events = await logger_with_events.query(user_id="user-1")
|
||||||
|
|
||||||
|
assert len(events) == 1
|
||||||
|
assert events[0].event_type == AuditEventType.BUDGET_WARNING
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_with_limit(self, logger_with_events):
|
||||||
|
"""Test query with limit."""
|
||||||
|
events = await logger_with_events.query(limit=2)
|
||||||
|
|
||||||
|
assert len(events) == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_with_offset(self, logger_with_events):
|
||||||
|
"""Test query with offset."""
|
||||||
|
all_events = await logger_with_events.query()
|
||||||
|
offset_events = await logger_with_events.query(offset=2)
|
||||||
|
|
||||||
|
assert len(offset_events) == 2
|
||||||
|
assert offset_events[0] == all_events[2]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_by_time_range(self):
|
||||||
|
"""Test filtering by time range."""
|
||||||
|
logger = AuditLogger(enable_hash_chain=False)
|
||||||
|
|
||||||
|
now = datetime.utcnow()
|
||||||
|
await logger.log(AuditEventType.ACTION_REQUESTED)
|
||||||
|
|
||||||
|
# Query with start time
|
||||||
|
events = await logger.query(
|
||||||
|
start_time=now - timedelta(seconds=1),
|
||||||
|
end_time=now + timedelta(seconds=1),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(events) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_query_by_correlation_id(self):
    """Querying by correlation ID returns only events tagged with it."""
    audit = AuditLogger(enable_hash_chain=False)

    await audit.log(
        AuditEventType.ACTION_REQUESTED,
        correlation_id="corr-123",
    )
    await audit.log(
        AuditEventType.ACTION_EXECUTED,
        correlation_id="corr-456",
    )

    matching = await audit.query(correlation_id="corr-123")

    assert len(matching) == 1
    assert matching[0].correlation_id == "corr-123"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_query_combined_filters(self, logger_with_events):
    """Multiple filters are combined with AND semantics."""
    action_types = [
        AuditEventType.ACTION_REQUESTED,
        AuditEventType.ACTION_EXECUTED,
    ]
    matching = await logger_with_events.query(
        agent_id="agent-1",
        project_id="proj-1",
        event_types=action_types,
    )

    assert len(matching) == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_get_action_history(self, logger_with_events):
    """get_action_history returns only action-related events for an agent."""
    history = await logger_with_events.get_action_history("agent-1")

    assert len(history) == 2
    action_kinds = {AuditEventType.ACTION_REQUESTED, AuditEventType.ACTION_EXECUTED}
    assert all(entry.event_type in action_kinds for entry in history)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuditLoggerIntegrity:
    """Tests for hash chain integrity verification."""

    @pytest.mark.asyncio
    async def test_verify_integrity_valid(self):
        """An untampered chain verifies cleanly with no issues."""
        audit = AuditLogger(enable_hash_chain=True)

        await audit.log(AuditEventType.ACTION_REQUESTED)
        await audit.log(AuditEventType.ACTION_EXECUTED)

        ok, problems = await audit.verify_integrity()

        assert ok is True
        assert len(problems) == 0

    @pytest.mark.asyncio
    async def test_verify_integrity_disabled(self):
        """With hash chaining disabled, verification trivially passes."""
        audit = AuditLogger(enable_hash_chain=False)

        await audit.log(AuditEventType.ACTION_REQUESTED)

        ok, problems = await audit.verify_integrity()

        assert ok is True
        assert len(problems) == 0

    @pytest.mark.asyncio
    async def test_verify_integrity_broken_chain(self):
        """Tampering with a stored hash makes verification fail."""
        audit = AuditLogger(enable_hash_chain=True)

        first = await audit.log(AuditEventType.ACTION_REQUESTED)
        await audit.log(AuditEventType.ACTION_EXECUTED)

        # Corrupt the first event's recorded hash to break the chain.
        first.details["_hash"] = "tampered_hash"

        ok, problems = await audit.verify_integrity()

        assert ok is False
        assert len(problems) > 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuditLoggerHandlers:
    """Tests for registering, invoking, and removing event handlers."""

    @pytest.mark.asyncio
    async def test_add_sync_handler(self):
        """A plain (synchronous) handler is called for each logged event."""
        audit = AuditLogger(enable_hash_chain=False)
        seen = []

        def capture(event):
            seen.append(event)

        audit.add_handler(capture)
        await audit.log(AuditEventType.ACTION_REQUESTED)

        assert len(seen) == 1

    @pytest.mark.asyncio
    async def test_add_async_handler(self):
        """A coroutine handler is awaited for each logged event."""
        audit = AuditLogger(enable_hash_chain=False)
        seen = []

        async def capture(event):
            seen.append(event)

        audit.add_handler(capture)
        await audit.log(AuditEventType.ACTION_REQUESTED)

        assert len(seen) == 1

    @pytest.mark.asyncio
    async def test_remove_handler(self):
        """A removed handler no longer receives subsequent events."""
        audit = AuditLogger(enable_hash_chain=False)
        seen = []

        def capture(event):
            seen.append(event)

        audit.add_handler(capture)
        await audit.log(AuditEventType.ACTION_REQUESTED)

        audit.remove_handler(capture)
        await audit.log(AuditEventType.ACTION_EXECUTED)

        # Only the event logged while the handler was attached is seen.
        assert len(seen) == 1

    @pytest.mark.asyncio
    async def test_handler_error_caught(self):
        """An exception raised inside a handler does not break logging."""
        audit = AuditLogger(enable_hash_chain=False)

        def failing_handler(event):
            raise ValueError("Handler error")

        audit.add_handler(failing_handler)

        # log() must swallow the handler failure and still return the event.
        logged = await audit.log(AuditEventType.ACTION_REQUESTED)
        assert logged is not None
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuditLoggerSanitization:
    """Tests for sensitive data sanitization."""

    @pytest.mark.asyncio
    async def test_sanitize_sensitive_keys(self):
        """Top-level keys that look sensitive are redacted from details."""
        with patch("app.services.safety.audit.logger.get_safety_config") as mock_config:
            cfg = MagicMock()
            cfg.audit_retention_days = 30
            cfg.audit_include_sensitive = False
            mock_config.return_value = cfg

            audit = AuditLogger(enable_hash_chain=False)

            logged = await audit.log(
                AuditEventType.ACTION_EXECUTED,
                details={
                    "password": "secret123",
                    "api_key": "key123",
                    "token": "token123",
                    "normal_field": "visible",
                },
            )

            # Sensitive keys are masked; everything else passes through.
            assert logged.details["password"] == "[REDACTED]"
            assert logged.details["api_key"] == "[REDACTED]"
            assert logged.details["token"] == "[REDACTED]"
            assert logged.details["normal_field"] == "visible"

    @pytest.mark.asyncio
    async def test_sanitize_nested_dict(self):
        """Redaction recurses into nested dictionaries."""
        with patch("app.services.safety.audit.logger.get_safety_config") as mock_config:
            cfg = MagicMock()
            cfg.audit_retention_days = 30
            cfg.audit_include_sensitive = False
            mock_config.return_value = cfg

            audit = AuditLogger(enable_hash_chain=False)

            logged = await audit.log(
                AuditEventType.ACTION_EXECUTED,
                details={
                    "config": {
                        "api_secret": "secret",
                        "name": "test",
                    }
                },
            )

            assert logged.details["config"]["api_secret"] == "[REDACTED]"
            assert logged.details["config"]["name"] == "test"

    @pytest.mark.asyncio
    async def test_include_sensitive_when_enabled(self):
        """When audit_include_sensitive is on, nothing is redacted."""
        with patch("app.services.safety.audit.logger.get_safety_config") as mock_config:
            cfg = MagicMock()
            cfg.audit_retention_days = 30
            cfg.audit_include_sensitive = True
            mock_config.return_value = cfg

            audit = AuditLogger(enable_hash_chain=False)

            logged = await audit.log(
                AuditEventType.ACTION_EXECUTED,
                details={"password": "secret123"},
            )

            assert logged.details["password"] == "secret123"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuditLoggerRetention:
    """Tests for retention policy enforcement."""

    @pytest.mark.asyncio
    async def test_retention_removes_old_events(self):
        """Events older than the retention window are dropped on flush."""
        with patch("app.services.safety.audit.logger.get_safety_config") as mock_config:
            cfg = MagicMock()
            cfg.audit_retention_days = 7
            cfg.audit_include_sensitive = False
            mock_config.return_value = cfg

            audit = AuditLogger(enable_hash_chain=False)

            # Inject a stale event directly into the persisted store.
            from app.services.safety.models import AuditEvent

            stale = AuditEvent(
                id="old-event",
                event_type=AuditEventType.ACTION_REQUESTED,
                timestamp=datetime.utcnow() - timedelta(days=10),
                details={},
            )
            audit._persisted.append(stale)

            # Log something recent so there is a survivor after cleanup.
            await audit.log(AuditEventType.ACTION_EXECUTED)

            # flush() enforces the retention policy as a side effect.
            await audit.flush()

            assert len(audit._persisted) == 1
            assert audit._persisted[0].id != "old-event"

    @pytest.mark.asyncio
    async def test_retention_keeps_recent_events(self):
        """Events inside the retention window survive a flush."""
        with patch("app.services.safety.audit.logger.get_safety_config") as mock_config:
            cfg = MagicMock()
            cfg.audit_retention_days = 7
            cfg.audit_include_sensitive = False
            mock_config.return_value = cfg

            audit = AuditLogger(enable_hash_chain=False)

            await audit.log(AuditEventType.ACTION_REQUESTED)
            await audit.log(AuditEventType.ACTION_EXECUTED)

            await audit.flush()

            assert len(audit._persisted) == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuditLoggerSingleton:
    """Tests for singleton pattern."""

    @pytest.mark.asyncio
    async def test_get_audit_logger_creates_instance(self):
        """Repeated get_audit_logger() calls yield the same instance."""
        reset_audit_logger()

        first = await get_audit_logger()
        second = await get_audit_logger()

        assert first is second

        await shutdown_audit_logger()

    @pytest.mark.asyncio
    async def test_shutdown_audit_logger(self):
        """shutdown_audit_logger() stops the logger and clears the singleton."""
        import app.services.safety.audit.logger as audit_module

        reset_audit_logger()

        _logger = await get_audit_logger()
        await shutdown_audit_logger()

        assert audit_module._audit_logger is None

    def test_reset_audit_logger(self):
        """reset_audit_logger() drops the module-level singleton."""
        import app.services.safety.audit.logger as audit_module

        audit_module._audit_logger = AuditLogger()
        reset_audit_logger()

        assert audit_module._audit_logger is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuditLoggerPeriodicFlush:
    """Tests for periodic flush background task."""

    @pytest.mark.asyncio
    async def test_periodic_flush_runs(self):
        """The background task moves buffered events to persisted storage."""
        audit = AuditLogger(flush_interval_seconds=0.1, enable_hash_chain=False)

        await audit.start()

        await audit.log(AuditEventType.ACTION_REQUESTED)
        assert len(audit._buffer) == 1

        # Sleep past one flush interval so the background task fires.
        await asyncio.sleep(0.15)

        assert len(audit._buffer) == 0
        assert len(audit._persisted) == 1

        await audit.stop()

    @pytest.mark.asyncio
    async def test_periodic_flush_handles_errors(self):
        """A failing flush does not kill the background task."""
        audit = AuditLogger(flush_interval_seconds=0.1)

        await audit.start()

        # Swap in a flush that always raises, keeping the original to restore.
        original_flush = audit.flush

        async def failing_flush():
            raise Exception("Flush error")

        audit.flush = failing_flush

        # Let at least one flush attempt fail.
        await asyncio.sleep(0.15)

        # The task should have survived the error.
        assert audit._running is True

        audit.flush = original_flush
        await audit.stop()
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuditLoggerLogging:
    """Tests for standard logger output."""

    @pytest.mark.asyncio
    async def test_log_warning_for_denied(self):
        """Denied events are emitted at WARNING level."""
        with patch("app.services.safety.audit.logger.logger") as mock_logger:
            audit = AuditLogger(enable_hash_chain=False)

            await audit.log(
                AuditEventType.ACTION_DENIED,
                agent_id="agent-1",
            )

            mock_logger.warning.assert_called()

    @pytest.mark.asyncio
    async def test_log_error_for_failed(self):
        """Failed events are emitted at ERROR level."""
        with patch("app.services.safety.audit.logger.logger") as mock_logger:
            audit = AuditLogger(enable_hash_chain=False)

            await audit.log(
                AuditEventType.ACTION_FAILED,
                agent_id="agent-1",
            )

            mock_logger.error.assert_called()

    @pytest.mark.asyncio
    async def test_log_info_for_normal(self):
        """Ordinary events are emitted at INFO level."""
        with patch("app.services.safety.audit.logger.logger") as mock_logger:
            audit = AuditLogger(enable_hash_chain=False)

            await audit.log(
                AuditEventType.ACTION_EXECUTED,
                agent_id="agent-1",
            )

            mock_logger.info.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuditLoggerEdgeCases:
    """Tests for edge cases."""

    @pytest.mark.asyncio
    async def test_log_with_none_details(self):
        """details=None is normalized to an empty dict."""
        audit = AuditLogger(enable_hash_chain=False)

        logged = await audit.log(
            AuditEventType.ACTION_REQUESTED,
            details=None,
        )

        assert logged.details == {}

    @pytest.mark.asyncio
    async def test_query_with_action_id(self):
        """Querying by action ID returns only the matching event."""
        audit = AuditLogger(enable_hash_chain=False)

        await audit.log(
            AuditEventType.ACTION_REQUESTED,
            action_id="action-1",
        )
        await audit.log(
            AuditEventType.ACTION_EXECUTED,
            action_id="action-2",
        )

        matching = await audit.query(action_id="action-1")

        assert len(matching) == 1
        assert matching[0].action_id == "action-1"

    @pytest.mark.asyncio
    async def test_query_with_session_id(self):
        """Querying by session ID returns only the matching event."""
        audit = AuditLogger(enable_hash_chain=False)

        await audit.log(
            AuditEventType.ACTION_REQUESTED,
            session_id="sess-1",
        )
        await audit.log(
            AuditEventType.ACTION_EXECUTED,
            session_id="sess-2",
        )

        matching = await audit.query(session_id="sess-1")

        assert len(matching) == 1

    @pytest.mark.asyncio
    async def test_query_includes_buffer_and_persisted(self):
        """query() merges the in-memory buffer with persisted events."""
        audit = AuditLogger(enable_hash_chain=False)

        # One event persisted via flush, one still buffered.
        await audit.log(AuditEventType.ACTION_REQUESTED)
        await audit.flush()
        await audit.log(AuditEventType.ACTION_EXECUTED)

        combined = await audit.query()

        assert len(combined) == 2

    @pytest.mark.asyncio
    async def test_remove_nonexistent_handler(self):
        """Removing a handler that was never added is a no-op."""
        audit = AuditLogger()

        def handler(event):
            pass

        # Must not raise.
        audit.remove_handler(handler)

    @pytest.mark.asyncio
    async def test_query_time_filter_excludes_events(self):
        """A start time in the future excludes all existing events."""
        audit = AuditLogger(enable_hash_chain=False)

        await audit.log(AuditEventType.ACTION_REQUESTED)

        future = datetime.utcnow() + timedelta(hours=1)
        matching = await audit.query(start_time=future)

        assert len(matching) == 0

    @pytest.mark.asyncio
    async def test_query_end_time_filter(self):
        """An end time in the past excludes all existing events."""
        audit = AuditLogger(enable_hash_chain=False)

        await audit.log(AuditEventType.ACTION_REQUESTED)

        past = datetime.utcnow() - timedelta(hours=1)
        matching = await audit.query(end_time=past)

        assert len(matching) == 0
|
||||||
1136 lines added: backend/tests/services/safety/test_hitl.py (new file — diff suppressed because it is too large; use "Load Diff" to view).
874 lines added: backend/tests/services/safety/test_mcp_integration.py (new file, shown below).
@@ -0,0 +1,874 @@
|
|||||||
|
"""
|
||||||
|
Tests for MCP Safety Integration.
|
||||||
|
|
||||||
|
Tests cover:
|
||||||
|
- MCPToolCall and MCPToolResult data structures
|
||||||
|
- MCPSafetyWrapper: tool registration, execution, safety checks
|
||||||
|
- Tool classification and action type mapping
|
||||||
|
- SafeToolExecutor context manager
|
||||||
|
- Factory function create_mcp_wrapper
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from app.services.safety.exceptions import EmergencyStopError
|
||||||
|
from app.services.safety.mcp.integration import (
|
||||||
|
MCPSafetyWrapper,
|
||||||
|
MCPToolCall,
|
||||||
|
MCPToolResult,
|
||||||
|
SafeToolExecutor,
|
||||||
|
create_mcp_wrapper,
|
||||||
|
)
|
||||||
|
from app.services.safety.models import (
|
||||||
|
ActionType,
|
||||||
|
AutonomyLevel,
|
||||||
|
SafetyDecision,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPToolCall:
    """Tests for MCPToolCall dataclass."""

    def test_tool_call_creation(self):
        """All constructor arguments are stored as-is on the instance."""
        tool_call = MCPToolCall(
            tool_name="file_read",
            arguments={"path": "/tmp/test.txt"},  # noqa: S108
            server_name="file-server",
            project_id="proj-1",
            context={"session_id": "sess-1"},
        )

        assert tool_call.tool_name == "file_read"
        assert tool_call.arguments == {"path": "/tmp/test.txt"}  # noqa: S108
        assert tool_call.server_name == "file-server"
        assert tool_call.project_id == "proj-1"
        assert tool_call.context == {"session_id": "sess-1"}

    def test_tool_call_defaults(self):
        """Optional fields default to None / empty dict."""
        tool_call = MCPToolCall(
            tool_name="test",
            arguments={},
        )

        assert tool_call.server_name is None
        assert tool_call.project_id is None
        assert tool_call.context == {}
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPToolResult:
    """Tests for MCPToolResult dataclass."""

    def test_tool_result_success(self):
        """A successful result stores payload, decision, and timing."""
        outcome = MCPToolResult(
            success=True,
            result={"data": "test"},
            safety_decision=SafetyDecision.ALLOW,
            execution_time_ms=50.0,
        )

        assert outcome.success is True
        assert outcome.result == {"data": "test"}
        assert outcome.error is None
        assert outcome.safety_decision == SafetyDecision.ALLOW
        assert outcome.execution_time_ms == 50.0

    def test_tool_result_failure(self):
        """A failed result carries an error string and no payload."""
        outcome = MCPToolResult(
            success=False,
            error="Permission denied",
            safety_decision=SafetyDecision.DENY,
        )

        assert outcome.success is False
        assert outcome.error == "Permission denied"
        assert outcome.result is None

    def test_tool_result_with_ids(self):
        """Approval and checkpoint IDs round-trip through the result."""
        outcome = MCPToolResult(
            success=True,
            approval_id="approval-123",
            checkpoint_id="checkpoint-456",
        )

        assert outcome.approval_id == "approval-123"
        assert outcome.checkpoint_id == "checkpoint-456"

    def test_tool_result_defaults(self):
        """Unspecified fields take their documented defaults."""
        outcome = MCPToolResult(success=True)

        assert outcome.result is None
        assert outcome.error is None
        assert outcome.safety_decision == SafetyDecision.ALLOW
        assert outcome.execution_time_ms == 0.0
        assert outcome.approval_id is None
        assert outcome.checkpoint_id is None
        assert outcome.metadata == {}
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPSafetyWrapperClassification:
    """Tests for tool classification."""

    def test_classify_file_read(self):
        """Read-ish file tool names map to FILE_READ."""
        wrapper = MCPSafetyWrapper()

        for name in ("file_read", "get_file", "list_files", "search_file"):
            assert wrapper._classify_tool(name) == ActionType.FILE_READ

    def test_classify_file_write(self):
        """Write-ish file tool names map to FILE_WRITE."""
        wrapper = MCPSafetyWrapper()

        for name in ("file_write", "create_file", "update_file"):
            assert wrapper._classify_tool(name) == ActionType.FILE_WRITE

    def test_classify_file_delete(self):
        """Delete-ish file tool names map to FILE_DELETE."""
        wrapper = MCPSafetyWrapper()

        for name in ("file_delete", "remove_file"):
            assert wrapper._classify_tool(name) == ActionType.FILE_DELETE

    def test_classify_database_read(self):
        """Read-only database tool names map to DATABASE_QUERY."""
        wrapper = MCPSafetyWrapper()

        for name in ("database_query", "db_read", "query_database"):
            assert wrapper._classify_tool(name) == ActionType.DATABASE_QUERY

    def test_classify_database_mutate(self):
        """Mutating database tool names map to DATABASE_MUTATE."""
        wrapper = MCPSafetyWrapper()

        for name in ("database_write", "db_update", "database_delete"):
            assert wrapper._classify_tool(name) == ActionType.DATABASE_MUTATE

    def test_classify_shell_command(self):
        """Shell/exec tool names map to SHELL_COMMAND."""
        wrapper = MCPSafetyWrapper()

        for name in ("shell_execute", "exec_command", "bash_run"):
            assert wrapper._classify_tool(name) == ActionType.SHELL_COMMAND

    def test_classify_git_operation(self):
        """Git tool names map to GIT_OPERATION."""
        wrapper = MCPSafetyWrapper()

        for name in ("git_commit", "git_push", "git_status"):
            assert wrapper._classify_tool(name) == ActionType.GIT_OPERATION

    def test_classify_network_request(self):
        """HTTP/fetch tool names map to NETWORK_REQUEST."""
        wrapper = MCPSafetyWrapper()

        for name in ("http_get", "fetch_url", "api_request"):
            assert wrapper._classify_tool(name) == ActionType.NETWORK_REQUEST

    def test_classify_llm_call(self):
        """LLM/AI tool names map to LLM_CALL."""
        wrapper = MCPSafetyWrapper()

        for name in ("llm_generate", "ai_complete", "claude_chat"):
            assert wrapper._classify_tool(name) == ActionType.LLM_CALL

    def test_classify_default(self):
        """Unrecognized tool names fall back to the generic TOOL_CALL."""
        wrapper = MCPSafetyWrapper()

        for name in ("unknown_tool", "custom_action"):
            assert wrapper._classify_tool(name) == ActionType.TOOL_CALL
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPSafetyWrapperToolHandlers:
    """Tests for tool handler registration."""

    def test_register_tool_handler(self):
        """A registered handler is stored under its tool name."""
        wrapper = MCPSafetyWrapper()

        def handler(path: str) -> str:
            return f"Read: {path}"

        wrapper.register_tool_handler("file_read", handler)

        assert "file_read" in wrapper._tool_handlers
        assert wrapper._tool_handlers["file_read"] is handler

    def test_register_multiple_handlers(self):
        """Each registered tool name gets its own entry."""
        wrapper = MCPSafetyWrapper()

        wrapper.register_tool_handler("tool1", lambda: None)
        wrapper.register_tool_handler("tool2", lambda: None)
        wrapper.register_tool_handler("tool3", lambda: None)

        assert len(wrapper._tool_handlers) == 3

    def test_overwrite_handler(self):
        """Re-registering a tool name replaces the previous handler."""
        wrapper = MCPSafetyWrapper()

        # Use def statements instead of lambda assignments (PEP 8 / E731),
        # which also removes the need for the noqa suppressions.
        def handler1():
            return "first"

        def handler2():
            return "second"

        wrapper.register_tool_handler("tool", handler1)
        wrapper.register_tool_handler("tool", handler2)

        assert wrapper._tool_handlers["tool"] is handler2
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPSafetyWrapperExecution:
|
||||||
|
"""Tests for tool execution."""
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def mock_guardian(self):
|
||||||
|
"""Create a mock SafetyGuardian."""
|
||||||
|
guardian = AsyncMock()
|
||||||
|
guardian.validate = AsyncMock()
|
||||||
|
return guardian
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def mock_emergency(self):
|
||||||
|
"""Create a mock EmergencyControls."""
|
||||||
|
emergency = AsyncMock()
|
||||||
|
emergency.check_allowed = AsyncMock()
|
||||||
|
return emergency
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_allowed(self, mock_guardian, mock_emergency):
|
||||||
|
"""Test executing an allowed tool call."""
|
||||||
|
mock_guardian.validate.return_value = MagicMock(
|
||||||
|
decision=SafetyDecision.ALLOW,
|
||||||
|
reasons=[],
|
||||||
|
approval_id=None,
|
||||||
|
checkpoint_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
wrapper = MCPSafetyWrapper(
|
||||||
|
guardian=mock_guardian,
|
||||||
|
emergency_controls=mock_emergency,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def handler(path: str) -> dict:
|
||||||
|
return {"content": f"Data from {path}"}
|
||||||
|
|
||||||
|
wrapper.register_tool_handler("file_read", handler)
|
||||||
|
|
||||||
|
call = MCPToolCall(
|
||||||
|
tool_name="file_read",
|
||||||
|
arguments={"path": "/test.txt"},
|
||||||
|
project_id="proj-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await wrapper.execute(call, "agent-1")
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.result == {"content": "Data from /test.txt"}
|
||||||
|
assert result.safety_decision == SafetyDecision.ALLOW
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_denied(self, mock_guardian, mock_emergency):
|
||||||
|
"""Test executing a denied tool call."""
|
||||||
|
mock_guardian.validate.return_value = MagicMock(
|
||||||
|
decision=SafetyDecision.DENY,
|
||||||
|
reasons=["Permission denied", "Rate limit exceeded"],
|
||||||
|
)
|
||||||
|
|
||||||
|
wrapper = MCPSafetyWrapper(
|
||||||
|
guardian=mock_guardian,
|
||||||
|
emergency_controls=mock_emergency,
|
||||||
|
)
|
||||||
|
|
||||||
|
call = MCPToolCall(
|
||||||
|
tool_name="file_write",
|
||||||
|
arguments={"path": "/etc/passwd"},
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await wrapper.execute(call, "agent-1")
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "Permission denied" in result.error
|
||||||
|
assert "Rate limit exceeded" in result.error
|
||||||
|
assert result.safety_decision == SafetyDecision.DENY
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_requires_approval(self, mock_guardian, mock_emergency):
|
||||||
|
"""Test executing a tool that requires approval."""
|
||||||
|
mock_guardian.validate.return_value = MagicMock(
|
||||||
|
decision=SafetyDecision.REQUIRE_APPROVAL,
|
||||||
|
reasons=["Destructive operation requires approval"],
|
||||||
|
approval_id="approval-123",
|
||||||
|
)
|
||||||
|
|
||||||
|
wrapper = MCPSafetyWrapper(
|
||||||
|
guardian=mock_guardian,
|
||||||
|
emergency_controls=mock_emergency,
|
||||||
|
)
|
||||||
|
|
||||||
|
call = MCPToolCall(
|
||||||
|
tool_name="file_delete",
|
||||||
|
arguments={"path": "/important.txt"},
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await wrapper.execute(call, "agent-1")
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert result.safety_decision == SafetyDecision.REQUIRE_APPROVAL
|
||||||
|
assert result.approval_id == "approval-123"
|
||||||
|
assert "requires human approval" in result.error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_execute_emergency_stop(self, mock_guardian, mock_emergency):
    """An active emergency stop denies the call before any handler runs."""
    mock_emergency.check_allowed.side_effect = EmergencyStopError(
        "Emergency stop active"
    )
    safety_wrapper = MCPSafetyWrapper(
        guardian=mock_guardian,
        emergency_controls=mock_emergency,
    )

    tool_call = MCPToolCall(
        tool_name="file_write",
        arguments={"path": "/test.txt"},
        project_id="proj-1",
    )
    outcome = await safety_wrapper.execute(tool_call, "agent-1")

    assert outcome.success is False
    assert outcome.safety_decision == SafetyDecision.DENY
    # The result metadata flags that the denial came from the emergency stop.
    assert outcome.metadata.get("emergency_stop") is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_execute_bypass_safety(self, mock_guardian, mock_emergency):
    """bypass_safety=True runs the handler without consulting the guardian."""
    safety_wrapper = MCPSafetyWrapper(
        guardian=mock_guardian,
        emergency_controls=mock_emergency,
    )

    async def handler(data: str) -> str:
        return f"Processed: {data}"

    safety_wrapper.register_tool_handler("custom_tool", handler)

    outcome = await safety_wrapper.execute(
        MCPToolCall(tool_name="custom_tool", arguments={"data": "test"}),
        "agent-1",
        bypass_safety=True,
    )

    assert outcome.success is True
    assert outcome.result == "Processed: test"
    # Bypassing safety means validation never happens.
    mock_guardian.validate.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_execute_no_handler(self, mock_guardian, mock_emergency):
    """An allowed call still fails when no handler is registered for the tool."""
    mock_guardian.validate.return_value = MagicMock(
        decision=SafetyDecision.ALLOW,
        reasons=[],
        approval_id=None,
        checkpoint_id=None,
    )
    safety_wrapper = MCPSafetyWrapper(
        guardian=mock_guardian,
        emergency_controls=mock_emergency,
    )

    outcome = await safety_wrapper.execute(
        MCPToolCall(tool_name="unregistered_tool", arguments={}),
        "agent-1",
    )

    assert outcome.success is False
    assert "No handler registered" in outcome.error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_execute_handler_exception(self, mock_guardian, mock_emergency):
    """A handler that raises becomes a failed result, not a propagated exception."""
    mock_guardian.validate.return_value = MagicMock(
        decision=SafetyDecision.ALLOW,
        reasons=[],
        approval_id=None,
        checkpoint_id=None,
    )
    safety_wrapper = MCPSafetyWrapper(
        guardian=mock_guardian,
        emergency_controls=mock_emergency,
    )

    async def failing_handler() -> None:
        raise ValueError("Handler failed!")

    safety_wrapper.register_tool_handler("failing_tool", failing_handler)

    outcome = await safety_wrapper.execute(
        MCPToolCall(tool_name="failing_tool", arguments={}),
        "agent-1",
    )

    assert outcome.success is False
    assert "Handler failed!" in outcome.error
    # The safety check itself passed, so the recorded decision remains ALLOW.
    assert outcome.safety_decision == SafetyDecision.ALLOW
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_execute_sync_handler(self, mock_guardian, mock_emergency):
    """Plain (non-async) handlers are supported and their return value is used."""
    mock_guardian.validate.return_value = MagicMock(
        decision=SafetyDecision.ALLOW,
        reasons=[],
        approval_id=None,
        checkpoint_id=None,
    )
    safety_wrapper = MCPSafetyWrapper(
        guardian=mock_guardian,
        emergency_controls=mock_emergency,
    )

    def double(value: int) -> int:
        return value * 2

    safety_wrapper.register_tool_handler("sync_tool", double)

    outcome = await safety_wrapper.execute(
        MCPToolCall(tool_name="sync_tool", arguments={"value": 21}),
        "agent-1",
    )

    assert outcome.success is True
    assert outcome.result == 42
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildActionRequest:
    """Tests for MCPSafetyWrapper._build_action_request."""

    def test_build_action_request_basic(self):
        """A simple call maps to action type, resource, and metadata fields."""
        tool_call = MCPToolCall(
            tool_name="file_read",
            arguments={"path": "/test.txt"},
            project_id="proj-1",
        )

        request = MCPSafetyWrapper()._build_action_request(
            tool_call, "agent-1", AutonomyLevel.MILESTONE
        )

        assert request.action_type == ActionType.FILE_READ
        assert request.tool_name == "file_read"
        assert request.arguments == {"path": "/test.txt"}
        assert request.resource == "/test.txt"
        # Caller identity and autonomy travel in the metadata.
        assert request.metadata.agent_id == "agent-1"
        assert request.metadata.project_id == "proj-1"
        assert request.metadata.autonomy_level == AutonomyLevel.MILESTONE

    def test_build_action_request_with_context(self):
        """Session context lands in metadata; the 'resource' argument becomes the resource."""
        tool_call = MCPToolCall(
            tool_name="database_query",
            arguments={"resource": "users", "query": "SELECT *"},
            context={"session_id": "sess-123"},
            project_id="proj-2",
        )

        request = MCPSafetyWrapper()._build_action_request(
            tool_call, "agent-2", AutonomyLevel.AUTONOMOUS
        )

        assert request.resource == "users"
        assert request.metadata.session_id == "sess-123"
        assert request.metadata.autonomy_level == AutonomyLevel.AUTONOMOUS

    def test_build_action_request_no_resource(self):
        """Calls with no resource-like argument produce a None resource."""
        tool_call = MCPToolCall(tool_name="llm_generate", arguments={"prompt": "Hello"})

        request = MCPSafetyWrapper()._build_action_request(
            tool_call, "agent-1", AutonomyLevel.FULL_CONTROL
        )

        assert request.resource is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestElapsedTime:
    """Tests for the _elapsed_ms helper."""

    def test_elapsed_ms(self):
        """Elapsed time since a point ~100ms in the past is roughly 100ms."""
        # NOTE(review): uses naive utcnow() — presumably matching the wrapper's
        # own clock; confirm against the implementation before changing.
        started_at = datetime.utcnow() - timedelta(milliseconds=100)

        elapsed = MCPSafetyWrapper()._elapsed_ms(started_at)

        # Allow a little scheduling tolerance on either side.
        assert 99 <= elapsed < 200
|
||||||
|
|
||||||
|
|
||||||
|
class TestSafeToolExecutor:
    """Tests for the SafeToolExecutor context manager."""

    @staticmethod
    def _allow_wrapper():
        """Return (wrapper, guardian) where the guardian always answers ALLOW."""
        guardian = AsyncMock()
        guardian.validate.return_value = MagicMock(
            decision=SafetyDecision.ALLOW,
            reasons=[],
            approval_id=None,
            checkpoint_id=None,
        )
        wrapper = MCPSafetyWrapper(
            guardian=guardian,
            emergency_controls=AsyncMock(),
        )
        return wrapper, guardian

    @pytest.mark.asyncio
    async def test_executor_execute(self):
        """Executing inside the context manager returns the handler's result."""
        wrapper, _ = self._allow_wrapper()

        async def handler() -> str:
            return "success"

        wrapper.register_tool_handler("test_tool", handler)
        tool_call = MCPToolCall(tool_name="test_tool", arguments={})

        async with SafeToolExecutor(wrapper, tool_call, "agent-1") as executor:
            outcome = await executor.execute()

        assert outcome.success is True
        assert outcome.result == "success"

    @pytest.mark.asyncio
    async def test_executor_result_property(self):
        """The result property is None before execution and populated after."""
        wrapper, _ = self._allow_wrapper()
        wrapper.register_tool_handler("tool", lambda: "data")

        executor = SafeToolExecutor(
            wrapper, MCPToolCall(tool_name="tool", arguments={}), "agent-1"
        )

        # Nothing has run yet.
        assert executor.result is None

        async with executor:
            await executor.execute()

        assert executor.result is not None
        assert executor.result.success is True

    @pytest.mark.asyncio
    async def test_executor_with_autonomy_level(self):
        """A custom autonomy level is forwarded into the guardian's action request."""
        wrapper, guardian = self._allow_wrapper()
        wrapper.register_tool_handler("tool", lambda: None)

        async with SafeToolExecutor(
            wrapper,
            MCPToolCall(tool_name="tool", arguments={}),
            "agent-1",
            AutonomyLevel.AUTONOMOUS,
        ) as executor:
            await executor.execute()

        guardian.validate.assert_called_once()
        validated_action = guardian.validate.call_args[0][0]
        assert validated_action.metadata.autonomy_level == AutonomyLevel.AUTONOMOUS
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateMCPWrapper:
    """Tests for the create_mcp_wrapper factory function."""

    @pytest.mark.asyncio
    async def test_create_wrapper_with_guardian(self):
        """An explicitly supplied guardian is used as-is."""
        supplied_guardian = AsyncMock()

        with patch(
            "app.services.safety.mcp.integration.get_emergency_controls"
        ) as mock_get_emergency:
            mock_get_emergency.return_value = AsyncMock()

            wrapper = await create_mcp_wrapper(guardian=supplied_guardian)

        assert wrapper._guardian is supplied_guardian

    @pytest.mark.asyncio
    async def test_create_wrapper_default_guardian(self):
        """Without an explicit guardian the factory resolves the global one."""
        with (
            patch(
                "app.services.safety.mcp.integration.get_safety_guardian"
            ) as mock_get_guardian,
            patch(
                "app.services.safety.mcp.integration.get_emergency_controls"
            ) as mock_get_emergency,
        ):
            resolved_guardian = AsyncMock()
            mock_get_guardian.return_value = resolved_guardian
            mock_get_emergency.return_value = AsyncMock()

            wrapper = await create_mcp_wrapper()

        assert wrapper._guardian is resolved_guardian
        mock_get_guardian.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
class TestLazyGetters:
    """Tests for the wrapper's lazy dependency getters."""

    @pytest.mark.asyncio
    async def test_get_guardian_lazy(self):
        """With no guardian injected, _get_guardian resolves the global one."""
        bare_wrapper = MCPSafetyWrapper()

        with patch(
            "app.services.safety.mcp.integration.get_safety_guardian"
        ) as mock_get:
            resolved = AsyncMock()
            mock_get.return_value = resolved

            assert await bare_wrapper._get_guardian() is resolved
            mock_get.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_guardian_cached(self):
        """An injected guardian is returned without any global lookup."""
        injected = AsyncMock()
        wrapper = MCPSafetyWrapper(guardian=injected)

        assert await wrapper._get_guardian() is injected

    @pytest.mark.asyncio
    async def test_get_emergency_controls_lazy(self):
        """With nothing injected, _get_emergency_controls resolves the global one."""
        bare_wrapper = MCPSafetyWrapper()

        with patch(
            "app.services.safety.mcp.integration.get_emergency_controls"
        ) as mock_get:
            resolved = AsyncMock()
            mock_get.return_value = resolved

            assert await bare_wrapper._get_emergency_controls() is resolved
            mock_get.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_emergency_controls_cached(self):
        """Injected emergency controls are returned without any global lookup."""
        injected = AsyncMock()
        wrapper = MCPSafetyWrapper(emergency_controls=injected)

        assert await wrapper._get_emergency_controls() is injected
|
||||||
|
|
||||||
|
|
||||||
|
class TestEdgeCases:
    """Tests for edge cases and error handling."""

    @staticmethod
    def _allow_guardian(checkpoint_id=None):
        """Return a guardian mock whose validate always answers ALLOW."""
        guardian = AsyncMock()
        guardian.validate.return_value = MagicMock(
            decision=SafetyDecision.ALLOW,
            reasons=[],
            approval_id=None,
            checkpoint_id=checkpoint_id,
        )
        return guardian

    @pytest.mark.asyncio
    async def test_execute_with_safety_error(self):
        """A SafetyError raised by the guardian becomes a DENY result."""
        from app.services.safety.exceptions import SafetyError

        guardian = AsyncMock()
        guardian.validate.side_effect = SafetyError("Internal safety error")
        wrapper = MCPSafetyWrapper(
            guardian=guardian,
            emergency_controls=AsyncMock(),
        )

        outcome = await wrapper.execute(
            MCPToolCall(tool_name="test", arguments={}), "agent-1"
        )

        assert outcome.success is False
        assert "Internal safety error" in outcome.error
        assert outcome.safety_decision == SafetyDecision.DENY

    @pytest.mark.asyncio
    async def test_execute_with_checkpoint_id(self):
        """A checkpoint id in the guardian verdict is copied onto the result."""
        wrapper = MCPSafetyWrapper(
            guardian=self._allow_guardian(checkpoint_id="checkpoint-abc"),
            emergency_controls=AsyncMock(),
        )
        wrapper.register_tool_handler("tool", lambda: "result")

        outcome = await wrapper.execute(
            MCPToolCall(tool_name="tool", arguments={}), "agent-1"
        )

        assert outcome.success is True
        assert outcome.checkpoint_id == "checkpoint-abc"

    def test_destructive_tools_constant(self):
        """DESTRUCTIVE_TOOLS lists the mutating tool names."""
        for tool in ("file_write", "file_delete", "shell_execute", "git_push"):
            assert tool in MCPSafetyWrapper.DESTRUCTIVE_TOOLS

    def test_read_only_tools_constant(self):
        """READ_ONLY_TOOLS lists the non-mutating tool names."""
        for tool in ("file_read", "database_query", "git_status", "search"):
            assert tool in MCPSafetyWrapper.READ_ONLY_TOOLS

    @pytest.mark.asyncio
    async def test_scope_with_project_id(self):
        """A project_id on the call yields a project-scoped emergency check."""
        emergency = AsyncMock()
        wrapper = MCPSafetyWrapper(
            guardian=self._allow_guardian(),
            emergency_controls=emergency,
        )
        wrapper.register_tool_handler("tool", lambda: None)

        await wrapper.execute(
            MCPToolCall(tool_name="tool", arguments={}, project_id="proj-123"),
            "agent-1",
        )

        emergency.check_allowed.assert_called_once()
        assert "project:proj-123" in str(emergency.check_allowed.call_args)

    @pytest.mark.asyncio
    async def test_scope_without_project_id(self):
        """Without a project_id the emergency check falls back to agent scope."""
        emergency = AsyncMock()
        wrapper = MCPSafetyWrapper(
            guardian=self._allow_guardian(),
            emergency_controls=emergency,
        )
        wrapper.register_tool_handler("tool", lambda: None)

        await wrapper.execute(
            MCPToolCall(tool_name="tool", arguments={}),  # no project_id
            "agent-555",
        )

        emergency.check_allowed.assert_called_once()
        assert "agent:agent-555" in str(emergency.check_allowed.call_args)
|
||||||
747
backend/tests/services/safety/test_metrics.py
Normal file
747
backend/tests/services/safety/test_metrics.py
Normal file
@@ -0,0 +1,747 @@
|
|||||||
|
"""
|
||||||
|
Tests for Safety Metrics Collector.
|
||||||
|
|
||||||
|
Tests cover:
|
||||||
|
- MetricType, MetricValue, HistogramBucket data structures
|
||||||
|
- SafetyMetrics counters, gauges, histograms
|
||||||
|
- Prometheus format export
|
||||||
|
- Summary and reset operations
|
||||||
|
- Singleton pattern and convenience functions
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from app.services.safety.metrics.collector import (
|
||||||
|
HistogramBucket,
|
||||||
|
MetricType,
|
||||||
|
MetricValue,
|
||||||
|
SafetyMetrics,
|
||||||
|
get_safety_metrics,
|
||||||
|
record_mcp_call,
|
||||||
|
record_validation,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetricType:
    """Tests for the MetricType enum."""

    def test_metric_types_exist(self):
        """Each metric kind compares equal to its expected string value."""
        expected = {
            MetricType.COUNTER: "counter",
            MetricType.GAUGE: "gauge",
            MetricType.HISTOGRAM: "histogram",
        }
        for member, text in expected.items():
            assert member == text

    def test_metric_type_is_string(self):
        """Underlying enum values are plain strings."""
        for member in (MetricType.COUNTER, MetricType.GAUGE, MetricType.HISTOGRAM):
            assert isinstance(member.value, str)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetricValue:
    """Tests for the MetricValue dataclass."""

    def test_metric_value_creation(self):
        """All supplied fields are stored and a timestamp is stamped automatically."""
        sample = MetricValue(
            name="test_metric",
            metric_type=MetricType.COUNTER,
            value=42.0,
            labels={"env": "test"},
        )

        assert sample.name == "test_metric"
        assert sample.metric_type == MetricType.COUNTER
        assert sample.value == 42.0
        assert sample.labels == {"env": "test"}
        assert sample.timestamp is not None

    def test_metric_value_defaults(self):
        """Labels default to an empty dict; timestamp is still populated."""
        sample = MetricValue(name="test", metric_type=MetricType.GAUGE, value=0.0)

        assert sample.labels == {}
        assert sample.timestamp is not None
|
||||||
|
|
||||||
|
class TestHistogramBucket:
    """Tests for the HistogramBucket dataclass."""

    def test_histogram_bucket_creation(self):
        """Both the upper bound and the count are stored."""
        bucket = HistogramBucket(le=0.5, count=10)

        assert (bucket.le, bucket.count) == (0.5, 10)

    def test_histogram_bucket_defaults(self):
        """The count defaults to zero."""
        bucket = HistogramBucket(le=1.0)

        assert bucket.le == 1.0
        assert bucket.count == 0

    def test_histogram_bucket_infinity(self):
        """An infinite upper bound (the +Inf bucket) is representable."""
        assert HistogramBucket(le=float("inf")).le == float("inf")
|
||||||
|
|
||||||
|
class TestSafetyMetricsCounters:
    """Tests for SafetyMetrics counter methods."""

    @pytest_asyncio.fixture
    async def metrics(self):
        """Provide a fresh, empty metrics instance per test."""
        return SafetyMetrics()

    @pytest.mark.asyncio
    async def test_inc_validations(self, metrics):
        """Validations are counted in total, and denials separately."""
        await metrics.inc_validations("allow")
        await metrics.inc_validations("allow")
        await metrics.inc_validations("deny", agent_id="agent-1")

        summary = await metrics.get_summary()
        assert summary["total_validations"] == 3
        assert summary["denied_validations"] == 1

    @pytest.mark.asyncio
    async def test_inc_approvals_requested(self, metrics):
        """Approval requests accumulate regardless of priority label."""
        await metrics.inc_approvals_requested("normal")
        await metrics.inc_approvals_requested("urgent")
        await metrics.inc_approvals_requested()  # default priority

        assert (await metrics.get_summary())["approval_requests"] == 3

    @pytest.mark.asyncio
    async def test_inc_approvals_granted(self, metrics):
        """Each grant bumps the approvals_granted counter."""
        for _ in range(2):
            await metrics.inc_approvals_granted()

        assert (await metrics.get_summary())["approvals_granted"] == 2

    @pytest.mark.asyncio
    async def test_inc_approvals_denied(self, metrics):
        """Denials accumulate across all denial reasons."""
        await metrics.inc_approvals_denied("timeout")
        await metrics.inc_approvals_denied("policy")
        await metrics.inc_approvals_denied()  # default (manual)

        assert (await metrics.get_summary())["approvals_denied"] == 3

    @pytest.mark.asyncio
    async def test_inc_rate_limit_exceeded(self, metrics):
        """Rate-limit hits accumulate across limit types."""
        for limit_type in ("requests_per_minute", "tokens_per_hour"):
            await metrics.inc_rate_limit_exceeded(limit_type)

        assert (await metrics.get_summary())["rate_limit_hits"] == 2

    @pytest.mark.asyncio
    async def test_inc_budget_exceeded(self, metrics):
        """Budget overruns accumulate across budget types."""
        for budget_type in ("daily_cost", "monthly_tokens"):
            await metrics.inc_budget_exceeded(budget_type)

        assert (await metrics.get_summary())["budget_exceeded"] == 2

    @pytest.mark.asyncio
    async def test_inc_loops_detected(self, metrics):
        """Loop detections accumulate across detection kinds."""
        for kind in ("repetition", "pattern"):
            await metrics.inc_loops_detected(kind)

        assert (await metrics.get_summary())["loops_detected"] == 2

    @pytest.mark.asyncio
    async def test_inc_emergency_events(self, metrics):
        """Emergency events accumulate across event type and scope."""
        await metrics.inc_emergency_events("pause", "project-1")
        await metrics.inc_emergency_events("stop", "agent-2")

        assert (await metrics.get_summary())["emergency_events"] == 2

    @pytest.mark.asyncio
    async def test_inc_content_filtered(self, metrics):
        """Content-filter hits accumulate across category and action."""
        await metrics.inc_content_filtered("profanity", "blocked")
        await metrics.inc_content_filtered("pii", "redacted")

        assert (await metrics.get_summary())["content_filtered"] == 2

    @pytest.mark.asyncio
    async def test_inc_checkpoints_created(self, metrics):
        """Each checkpoint creation bumps the counter."""
        for _ in range(3):
            await metrics.inc_checkpoints_created()

        assert (await metrics.get_summary())["checkpoints_created"] == 3

    @pytest.mark.asyncio
    async def test_inc_rollbacks_executed(self, metrics):
        """Rollbacks are counted whether they succeed or fail."""
        await metrics.inc_rollbacks_executed(success=True)
        await metrics.inc_rollbacks_executed(success=False)

        assert (await metrics.get_summary())["rollbacks_executed"] == 2

    @pytest.mark.asyncio
    async def test_inc_mcp_calls(self, metrics):
        """MCP calls are counted whether they succeed or fail."""
        await metrics.inc_mcp_calls("search_knowledge", success=True)
        await metrics.inc_mcp_calls("run_code", success=False)

        assert (await metrics.get_summary())["mcp_calls"] == 2
|
||||||
|
|
||||||
|
class TestSafetyMetricsGauges:
    """Tests for SafetyMetrics gauge methods."""

    @pytest_asyncio.fixture
    async def metrics(self):
        """Provide a fresh, empty metrics instance per test."""
        return SafetyMetrics()

    @pytest.mark.asyncio
    async def test_set_budget_remaining(self, metrics):
        """The budget gauge records the value plus scope/budget_type labels."""
        await metrics.set_budget_remaining("project-1", "daily_cost", 50.0)

        exported = await metrics.get_all_metrics()
        budget_gauges = [m for m in exported if m.name == "safety_budget_remaining"]

        assert len(budget_gauges) == 1
        (gauge,) = budget_gauges
        assert gauge.value == 50.0
        assert gauge.labels["scope"] == "project-1"
        assert gauge.labels["budget_type"] == "daily_cost"

    @pytest.mark.asyncio
    async def test_set_rate_limit_remaining(self, metrics):
        """The rate-limit gauge stores the remaining quota as a float."""
        await metrics.set_rate_limit_remaining("agent-1", "requests_per_minute", 45)

        exported = await metrics.get_all_metrics()
        limit_gauges = [
            m for m in exported if m.name == "safety_rate_limit_remaining"
        ]

        assert len(limit_gauges) == 1
        assert limit_gauges[0].value == 45.0

    @pytest.mark.asyncio
    async def test_set_pending_approvals(self, metrics):
        """The pending-approvals gauge shows up in the summary."""
        await metrics.set_pending_approvals(5)

        assert (await metrics.get_summary())["pending_approvals"] == 5

    @pytest.mark.asyncio
    async def test_set_active_checkpoints(self, metrics):
        """The active-checkpoints gauge shows up in the summary."""
        await metrics.set_active_checkpoints(3)

        assert (await metrics.get_summary())["active_checkpoints"] == 3

    @pytest.mark.asyncio
    async def test_set_emergency_state(self, metrics):
        """Emergency states encode numerically: normal=0, paused=1, stopped=2, else -1."""
        states = {
            "project-1": "normal",
            "project-2": "paused",
            "project-3": "stopped",
            "project-4": "unknown",
        }
        for scope, state in states.items():
            await metrics.set_emergency_state(scope, state)

        exported = await metrics.get_all_metrics()
        state_gauges = [m for m in exported if m.name == "safety_emergency_state"]
        assert len(state_gauges) == 4

        values_by_scope = {m.labels["scope"]: m.value for m in state_gauges}
        assert values_by_scope == {
            "project-1": 0.0,
            "project-2": 1.0,
            "project-3": 2.0,
            "project-4": -1.0,
        }
|
||||||
|
|
||||||
|
|
||||||
|
class TestSafetyMetricsHistograms:
    """Tests for SafetyMetrics histogram methods."""

    @staticmethod
    def _find(metric_list, metric_name):
        """Return the first metric with the given name, or None."""
        for item in metric_list:
            if item.name == metric_name:
                return item
        return None

    @pytest_asyncio.fixture
    async def metrics(self):
        """Provide a brand-new SafetyMetrics instance per test."""
        return SafetyMetrics()

    @pytest.mark.asyncio
    async def test_observe_validation_latency(self, metrics):
        """Observations update both the _count and _sum series."""
        for latency in (0.05, 0.15, 0.5):
            await metrics.observe_validation_latency(latency)

        exported = await metrics.get_all_metrics()

        counted = self._find(exported, "validation_latency_seconds_count")
        assert counted is not None
        assert counted.value == 3.0

        total = self._find(exported, "validation_latency_seconds_sum")
        assert total is not None
        # Sum of the three observations, allowing float rounding.
        assert abs(total.value - 0.7) < 0.001

    @pytest.mark.asyncio
    async def test_observe_approval_latency(self, metrics):
        """Approval latency observations increment the _count series."""
        for latency in (1.5, 3.0):
            await metrics.observe_approval_latency(latency)

        exported = await metrics.get_all_metrics()
        counted = self._find(exported, "approval_latency_seconds_count")
        assert counted is not None
        assert counted.value == 2.0

    @pytest.mark.asyncio
    async def test_observe_mcp_execution_latency(self, metrics):
        """MCP execution latency observations increment the _count series."""
        await metrics.observe_mcp_execution_latency(0.02)

        exported = await metrics.get_all_metrics()
        counted = self._find(exported, "mcp_execution_latency_seconds_count")
        assert counted is not None
        assert counted.value == 1.0

    @pytest.mark.asyncio
    async def test_histogram_bucket_updates(self, metrics):
        """Bucketed counts appear in the Prometheus rendering."""
        # One observation per interesting bucket boundary.
        await metrics.observe_validation_latency(0.005)  # <= 0.01
        await metrics.observe_validation_latency(0.03)  # <= 0.05
        await metrics.observe_validation_latency(0.07)  # <= 0.1
        await metrics.observe_validation_latency(15.0)  # <= inf

        rendered = await metrics.get_prometheus_format()

        assert "validation_latency_seconds_bucket" in rendered
        assert "le=" in rendered
|
||||||
|
|
||||||
|
|
||||||
|
class TestSafetyMetricsExport:
    """Tests for SafetyMetrics export methods."""

    @pytest_asyncio.fixture
    async def metrics(self):
        """Build a metrics instance pre-populated with sample data."""
        instance = SafetyMetrics()

        # Counters.
        await instance.inc_validations("allow")
        await instance.inc_validations("deny", agent_id="agent-1")

        # Gauges.
        await instance.set_pending_approvals(3)
        await instance.set_budget_remaining("proj-1", "daily", 100.0)

        # Histogram observations.
        await instance.observe_validation_latency(0.1)

        return instance

    @pytest.mark.asyncio
    async def test_get_all_metrics(self, metrics):
        """All recorded metrics come back as MetricValue objects."""
        exported = await metrics.get_all_metrics()

        assert len(exported) > 0
        for item in exported:
            assert isinstance(item, MetricValue)

        # Both counters and gauges should be represented.
        seen_types = {item.metric_type for item in exported}
        assert MetricType.COUNTER in seen_types
        assert MetricType.GAUGE in seen_types

    @pytest.mark.asyncio
    async def test_get_prometheus_format(self, metrics):
        """Prometheus output carries TYPE headers and metric names."""
        rendered = await metrics.get_prometheus_format()

        assert isinstance(rendered, str)
        for fragment in (
            "# TYPE",
            "counter",
            "gauge",
            "safety_validations_total",
            "safety_pending_approvals",
        ):
            assert fragment in rendered

    @pytest.mark.asyncio
    async def test_prometheus_format_with_labels(self, metrics):
        """Counter labels survive the Prometheus rendering."""
        rendered = await metrics.get_prometheus_format()

        # Counter with labels
        assert "decision=allow" in rendered or "decision=deny" in rendered

    @pytest.mark.asyncio
    async def test_prometheus_format_histogram_buckets(self, metrics):
        """Histogram buckets (including +Inf) are rendered."""
        rendered = await metrics.get_prometheus_format()

        for fragment in ("histogram", "_bucket", "le=", "+Inf"):
            assert fragment in rendered

    @pytest.mark.asyncio
    async def test_get_summary(self, metrics):
        """Summary exposes the expected keys and aggregate values."""
        result = await metrics.get_summary()

        for key in (
            "total_validations",
            "denied_validations",
            "approval_requests",
            "pending_approvals",
            "active_checkpoints",
        ):
            assert key in result

        assert result["total_validations"] == 2
        assert result["denied_validations"] == 1
        assert result["pending_approvals"] == 3

    @pytest.mark.asyncio
    async def test_summary_empty_counters(self):
        """A fresh instance reports all-zero summary values."""
        fresh = SafetyMetrics()
        result = await fresh.get_summary()

        assert result["total_validations"] == 0
        assert result["denied_validations"] == 0
        assert result["pending_approvals"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestSafetyMetricsReset:
    """Tests for SafetyMetrics reset."""

    @pytest.mark.asyncio
    async def test_reset_clears_counters(self):
        """After reset(), counters and gauges all read zero."""
        instance = SafetyMetrics()

        # Populate a counter, a gauge and a histogram, then wipe them.
        await instance.inc_validations("allow")
        await instance.inc_approvals_granted()
        await instance.set_pending_approvals(5)
        await instance.observe_validation_latency(0.1)

        await instance.reset()

        result = await instance.get_summary()
        for key in ("total_validations", "approvals_granted", "pending_approvals"):
            assert result[key] == 0

    @pytest.mark.asyncio
    async def test_reset_reinitializes_histogram_buckets(self):
        """Histogram series still export after a reset."""
        instance = SafetyMetrics()

        await instance.observe_validation_latency(0.1)
        await instance.reset()

        # After reset, histogram buckets should be reinitialized
        rendered = await instance.get_prometheus_format()
        assert "validation_latency_seconds" in rendered
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseLabels:
    """Tests for _parse_labels helper method."""

    def test_parse_empty_labels(self):
        """An empty string parses to an empty dict."""
        assert SafetyMetrics()._parse_labels("") == {}

    def test_parse_single_label(self):
        """A single key=value pair is parsed."""
        assert SafetyMetrics()._parse_labels("key=value") == {"key": "value"}

    def test_parse_multiple_labels(self):
        """Comma-separated pairs all parse."""
        parsed = SafetyMetrics()._parse_labels("a=1,b=2,c=3")
        assert parsed == {"a": "1", "b": "2", "c": "3"}

    def test_parse_labels_with_spaces(self):
        """Whitespace around keys, values and commas is stripped."""
        parsed = SafetyMetrics()._parse_labels(" key = value , foo = bar ")
        assert parsed == {"key": "value", "foo": "bar"}

    def test_parse_labels_with_equals_in_value(self):
        """Only the first '=' splits key from value."""
        assert SafetyMetrics()._parse_labels("query=a=b") == {"query": "a=b"}

    def test_parse_invalid_label(self):
        """A pair without '=' is silently dropped."""
        assert SafetyMetrics()._parse_labels("no_equals") == {}
|
||||||
|
|
||||||
|
|
||||||
|
class TestHistogramBucketInit:
    """Tests for histogram bucket initialization."""

    def test_histogram_buckets_initialized(self):
        """All three latency histograms get buckets at construction."""
        instance = SafetyMetrics()

        for series_name in (
            "validation_latency_seconds",
            "approval_latency_seconds",
            "mcp_execution_latency_seconds",
        ):
            assert series_name in instance._histogram_buckets

    def test_histogram_buckets_have_correct_values(self):
        """Bucket boundaries start at 0.01, end at +inf, and begin empty."""
        instance = SafetyMetrics()

        series = instance._histogram_buckets["validation_latency_seconds"]

        # Spot-check the first two boundaries and the terminal bucket.
        assert series[0].le == 0.01
        assert series[1].le == 0.05
        assert series[-1].le == float("inf")

        # No observations yet, so every bucket count is zero.
        for bucket in series:
            assert bucket.count == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestSingletonAndConvenience:
    """Tests for singleton pattern and convenience functions."""

    @pytest.mark.asyncio
    async def test_get_safety_metrics_returns_same_instance(self):
        """get_safety_metrics hands back one shared instance."""
        # Reset the module-level singleton for this test
        import app.services.safety.metrics.collector as collector_module

        collector_module._metrics = None

        first = await get_safety_metrics()
        second = await get_safety_metrics()

        assert first is second

    @pytest.mark.asyncio
    async def test_record_validation_convenience(self):
        """record_validation feeds the shared metrics instance."""
        import app.services.safety.metrics.collector as collector_module

        collector_module._metrics = None  # Reset

        await record_validation("allow")
        await record_validation("deny", agent_id="test-agent")

        shared = await get_safety_metrics()
        result = await shared.get_summary()

        assert result["total_validations"] == 2
        assert result["denied_validations"] == 1

    @pytest.mark.asyncio
    async def test_record_mcp_call_convenience(self):
        """record_mcp_call feeds the shared metrics instance."""
        import app.services.safety.metrics.collector as collector_module

        collector_module._metrics = None  # Reset

        await record_mcp_call("search_knowledge", success=True, latency_ms=50)
        await record_mcp_call("run_code", success=False, latency_ms=100)

        shared = await get_safety_metrics()
        result = await shared.get_summary()

        assert result["mcp_calls"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestConcurrency:
    """Tests for concurrent metric updates."""

    @pytest.mark.asyncio
    async def test_concurrent_counter_increments(self):
        """Concurrent counter increments are all accounted for."""
        import asyncio

        instance = SafetyMetrics()
        workers, per_worker = 10, 100

        async def hammer_counter():
            for _ in range(per_worker):
                await instance.inc_validations("allow")

        # Run 10 concurrent tasks each incrementing 100 times
        await asyncio.gather(*(hammer_counter() for _ in range(workers)))

        result = await instance.get_summary()
        assert result["total_validations"] == workers * per_worker

    @pytest.mark.asyncio
    async def test_concurrent_gauge_updates(self):
        """Concurrent gauge updates are safe; last writer wins."""
        import asyncio

        instance = SafetyMetrics()

        async def set_gauge(value):
            await instance.set_pending_approvals(value)

        # Run concurrent gauge updates
        await asyncio.gather(*(set_gauge(i) for i in range(100)))

        # Final value should be one of the updates (last one wins)
        result = await instance.get_summary()
        assert 0 <= result["pending_approvals"] < 100

    @pytest.mark.asyncio
    async def test_concurrent_histogram_observations(self):
        """Concurrent histogram observations do not lose samples."""
        import asyncio

        instance = SafetyMetrics()

        async def observe_batch():
            for i in range(100):
                await instance.observe_validation_latency(i / 1000)

        await asyncio.gather(*(observe_batch() for _ in range(10)))

        exported = await instance.get_all_metrics()
        counted = next(
            (m for m in exported if m.name == "validation_latency_seconds_count"),
            None,
        )
        assert counted is not None
        assert counted.value == 1000.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestEdgeCases:
    """Tests for edge cases."""

    @pytest.mark.asyncio
    async def test_very_large_counter_value(self):
        """Counters stay accurate over many increments."""
        instance = SafetyMetrics()

        total = 10000
        for _ in range(total):
            await instance.inc_validations("allow")

        result = await instance.get_summary()
        assert result["total_validations"] == total

    @pytest.mark.asyncio
    async def test_zero_and_negative_gauge_values(self):
        """Gauges accept zero and negative values unchanged."""
        instance = SafetyMetrics()

        await instance.set_budget_remaining("project", "cost", 0.0)
        await instance.set_budget_remaining("project2", "cost", -10.0)

        exported = await instance.get_all_metrics()
        budget_gauges = [m for m in exported if m.name == "safety_budget_remaining"]

        by_scope = {m.labels.get("scope"): m.value for m in budget_gauges}
        assert by_scope["project"] == 0.0
        assert by_scope["project2"] == -10.0

    @pytest.mark.asyncio
    async def test_very_small_histogram_values(self):
        """Sub-millisecond latencies are recorded accurately."""
        instance = SafetyMetrics()

        tiny = 0.0001  # 0.1ms
        await instance.observe_validation_latency(tiny)

        exported = await instance.get_all_metrics()
        total = next(
            (m for m in exported if m.name == "validation_latency_seconds_sum"),
            None,
        )
        assert total is not None
        assert abs(total.value - tiny) < 0.00001

    @pytest.mark.asyncio
    async def test_special_characters_in_labels(self):
        """Label values containing slashes are still stored."""
        instance = SafetyMetrics()

        await instance.inc_validations("allow", agent_id="agent/with/slashes")

        exported = await instance.get_all_metrics()
        matching = [m for m in exported if m.name == "safety_validations_total"]

        # Should have the metric with special chars
        assert len(matching) > 0

    @pytest.mark.asyncio
    async def test_empty_histogram_export(self):
        """Histograms export their buckets even with no observations."""
        instance = SafetyMetrics()

        rendered = await instance.get_prometheus_format()

        assert "validation_latency_seconds" in rendered
        assert "le=" in rendered

    @pytest.mark.asyncio
    async def test_prometheus_format_empty_label_value(self):
        """Metrics recorded with an empty label value still render."""
        instance = SafetyMetrics()

        await instance.inc_approvals_granted()  # Uses empty string as label

        rendered = await instance.get_prometheus_format()
        assert "safety_approvals_granted_total" in rendered

    @pytest.mark.asyncio
    async def test_multiple_resets(self):
        """Repeated resets are idempotent and safe."""
        instance = SafetyMetrics()

        await instance.inc_validations("allow")
        for _ in range(3):
            await instance.reset()

        result = await instance.get_summary()
        assert result["total_validations"] == 0
|
||||||
933
backend/tests/services/safety/test_permissions.py
Normal file
933
backend/tests/services/safety/test_permissions.py
Normal file
@@ -0,0 +1,933 @@
|
|||||||
|
"""Tests for Permission Manager.
|
||||||
|
|
||||||
|
Tests cover:
|
||||||
|
- PermissionGrant: creation, expiry, matching, hierarchy
|
||||||
|
- PermissionManager: grant, revoke, check, require, list, defaults
|
||||||
|
- Edge cases: wildcards, expiration, default deny/allow
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from app.services.safety.exceptions import PermissionDeniedError
|
||||||
|
from app.services.safety.models import (
|
||||||
|
ActionMetadata,
|
||||||
|
ActionRequest,
|
||||||
|
ActionType,
|
||||||
|
PermissionLevel,
|
||||||
|
ResourceType,
|
||||||
|
)
|
||||||
|
from app.services.safety.permissions.manager import PermissionGrant, PermissionManager
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Fixtures
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def action_metadata() -> ActionMetadata:
    """Standard action metadata shared by permission tests."""
    fields = {
        "agent_id": "test-agent",
        "project_id": "test-project",
        "session_id": "test-session",
    }
    return ActionMetadata(**fields)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def permission_manager() -> PermissionManager:
    """Strict manager for testing: anything not granted is denied."""
    manager = PermissionManager(default_deny=True)
    return manager
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def permissive_manager() -> PermissionManager:
    """Lenient manager for testing: falls back to default permissions."""
    manager = PermissionManager(default_deny=False)
    return manager
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# PermissionGrant Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestPermissionGrant:
    """Tests for the PermissionGrant class."""

    @staticmethod
    def _make_grant(**overrides) -> PermissionGrant:
        """Build a grant with sensible defaults, overridable per test."""
        params = {
            "agent_id": "agent-1",
            "resource_pattern": "*",
            "resource_type": ResourceType.FILE,
            "level": PermissionLevel.READ,
        }
        params.update(overrides)
        return PermissionGrant(**params)

    def test_grant_creation(self) -> None:
        """Test basic grant creation."""
        grant = self._make_grant(
            resource_pattern="/data/*",
            granted_by="admin",
            reason="Read access to data directory",
        )

        assert grant.id is not None
        assert grant.agent_id == "agent-1"
        assert grant.resource_pattern == "/data/*"
        assert grant.resource_type == ResourceType.FILE
        assert grant.level == PermissionLevel.READ
        assert grant.granted_by == "admin"
        assert grant.reason == "Read access to data directory"
        assert grant.expires_at is None
        assert grant.created_at is not None

    def test_grant_with_expiration(self) -> None:
        """Test grant with expiration time."""
        future = datetime.utcnow() + timedelta(hours=1)
        grant = self._make_grant(
            resource_type=ResourceType.API,
            level=PermissionLevel.EXECUTE,
            expires_at=future,
        )

        assert grant.expires_at == future
        assert grant.is_expired() is False

    def test_is_expired_no_expiration(self) -> None:
        """Test is_expired with no expiration set."""
        assert self._make_grant().is_expired() is False

    def test_is_expired_future(self) -> None:
        """Test is_expired with future expiration."""
        grant = self._make_grant(expires_at=datetime.utcnow() + timedelta(hours=1))
        assert grant.is_expired() is False

    def test_is_expired_past(self) -> None:
        """Test is_expired with past expiration."""
        grant = self._make_grant(expires_at=datetime.utcnow() - timedelta(hours=1))
        assert grant.is_expired() is True

    def test_matches_exact(self) -> None:
        """Test matching with exact pattern."""
        grant = self._make_grant(resource_pattern="/data/file.txt")

        assert grant.matches("/data/file.txt", ResourceType.FILE) is True
        assert grant.matches("/data/other.txt", ResourceType.FILE) is False

    def test_matches_wildcard(self) -> None:
        """Test matching with wildcard pattern."""
        grant = self._make_grant(resource_pattern="/data/*")

        assert grant.matches("/data/file.txt", ResourceType.FILE) is True
        # fnmatch's * matches everything including /
        assert grant.matches("/data/subdir/file.txt", ResourceType.FILE) is True
        assert grant.matches("/other/file.txt", ResourceType.FILE) is False

    def test_matches_recursive_wildcard(self) -> None:
        """Test matching with recursive pattern."""
        grant = self._make_grant(resource_pattern="/data/**")

        # fnmatch treats ** similar to * - both match everything including /
        assert grant.matches("/data/file.txt", ResourceType.FILE) is True
        assert grant.matches("/data/subdir/file.txt", ResourceType.FILE) is True

    def test_matches_wrong_resource_type(self) -> None:
        """Test matching fails with wrong resource type."""
        grant = self._make_grant(resource_pattern="/data/*")

        # Same pattern but different resource type
        assert grant.matches("/data/table", ResourceType.DATABASE) is False

    def test_allows_hierarchy(self) -> None:
        """Test permission level hierarchy."""
        admin_grant = self._make_grant(level=PermissionLevel.ADMIN)

        # ADMIN allows all levels
        for level in (
            PermissionLevel.NONE,
            PermissionLevel.READ,
            PermissionLevel.WRITE,
            PermissionLevel.EXECUTE,
            PermissionLevel.DELETE,
            PermissionLevel.ADMIN,
        ):
            assert admin_grant.allows(level) is True

    def test_allows_read_only(self) -> None:
        """Test READ grant only allows READ and NONE."""
        read_grant = self._make_grant(level=PermissionLevel.READ)

        assert read_grant.allows(PermissionLevel.NONE) is True
        assert read_grant.allows(PermissionLevel.READ) is True
        for denied_level in (
            PermissionLevel.WRITE,
            PermissionLevel.EXECUTE,
            PermissionLevel.DELETE,
            PermissionLevel.ADMIN,
        ):
            assert read_grant.allows(denied_level) is False

    def test_allows_write_includes_read(self) -> None:
        """Test WRITE grant includes READ."""
        write_grant = self._make_grant(level=PermissionLevel.WRITE)

        assert write_grant.allows(PermissionLevel.READ) is True
        assert write_grant.allows(PermissionLevel.WRITE) is True
        assert write_grant.allows(PermissionLevel.EXECUTE) is False
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# PermissionManager Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestPermissionManager:
|
||||||
|
"""Tests for the PermissionManager class."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_grant_creates_permission(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test granting a permission."""
|
||||||
|
grant = await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/data/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
granted_by="admin",
|
||||||
|
reason="Read access",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert grant.id is not None
|
||||||
|
assert grant.agent_id == "agent-1"
|
||||||
|
assert grant.resource_pattern == "/data/*"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_grant_with_duration(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test granting a temporary permission."""
|
||||||
|
grant = await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="*",
|
||||||
|
resource_type=ResourceType.API,
|
||||||
|
level=PermissionLevel.EXECUTE,
|
||||||
|
duration_seconds=3600, # 1 hour
|
||||||
|
)
|
||||||
|
|
||||||
|
assert grant.expires_at is not None
|
||||||
|
assert grant.is_expired() is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_revoke_by_id(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test revoking a grant by ID."""
|
||||||
|
grant = await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
success = await permission_manager.revoke(grant.id)
|
||||||
|
assert success is True
|
||||||
|
|
||||||
|
# Verify grant is removed
|
||||||
|
grants = await permission_manager.list_grants(agent_id="agent-1")
|
||||||
|
assert len(grants) == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_revoke_nonexistent(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test revoking a non-existent grant."""
|
||||||
|
success = await permission_manager.revoke("nonexistent-id")
|
||||||
|
assert success is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_revoke_all_for_agent(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test revoking all permissions for an agent."""
|
||||||
|
# Grant multiple permissions
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/data/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/api/*",
|
||||||
|
resource_type=ResourceType.API,
|
||||||
|
level=PermissionLevel.EXECUTE,
|
||||||
|
)
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="agent-2",
|
||||||
|
resource_pattern="*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
revoked = await permission_manager.revoke_all("agent-1")
|
||||||
|
assert revoked == 2
|
||||||
|
|
||||||
|
# Verify agent-1 grants are gone
|
||||||
|
grants = await permission_manager.list_grants(agent_id="agent-1")
|
||||||
|
assert len(grants) == 0
|
||||||
|
|
||||||
|
# Verify agent-2 grant remains
|
||||||
|
grants = await permission_manager.list_grants(agent_id="agent-2")
|
||||||
|
assert len(grants) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_revoke_all_no_grants(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test revoking all when no grants exist."""
|
||||||
|
revoked = await permission_manager.revoke_all("nonexistent-agent")
|
||||||
|
assert revoked == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_granted(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test checking a granted permission."""
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/data/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
allowed = await permission_manager.check(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource="/data/file.txt",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
required_level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert allowed is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_denied_default_deny(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test checking denied with default_deny=True."""
|
||||||
|
# No grants, should be denied
|
||||||
|
allowed = await permission_manager.check(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource="/data/file.txt",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
required_level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert allowed is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_uses_default_permissions(
|
||||||
|
self,
|
||||||
|
permissive_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test that default permissions apply when default_deny=False."""
|
||||||
|
# No explicit grants, but FILE default is READ
|
||||||
|
allowed = await permissive_manager.check(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource="/data/file.txt",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
required_level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert allowed is True
|
||||||
|
|
||||||
|
# But WRITE should fail
|
||||||
|
allowed = await permissive_manager.check(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource="/data/file.txt",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
required_level=PermissionLevel.WRITE,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert allowed is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_shell_denied_by_default(
|
||||||
|
self,
|
||||||
|
permissive_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test SHELL is denied by default (NONE level)."""
|
||||||
|
allowed = await permissive_manager.check(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource="rm -rf /",
|
||||||
|
resource_type=ResourceType.SHELL,
|
||||||
|
required_level=PermissionLevel.EXECUTE,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert allowed is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_expired_grant_ignored(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test that expired grants are ignored in checks."""
|
||||||
|
# Create an already-expired grant
|
||||||
|
grant = await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/data/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
duration_seconds=1, # Very short
|
||||||
|
)
|
||||||
|
|
||||||
|
# Manually expire it
|
||||||
|
grant.expires_at = datetime.utcnow() - timedelta(seconds=10)
|
||||||
|
|
||||||
|
allowed = await permission_manager.check(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource="/data/file.txt",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
required_level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert allowed is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_insufficient_level(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test check fails when grant level is insufficient."""
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/data/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try to get WRITE access with only READ grant
|
||||||
|
allowed = await permission_manager.check(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource="/data/file.txt",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
required_level=PermissionLevel.WRITE,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert allowed is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_action_file_read(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
action_metadata: ActionMetadata,
|
||||||
|
) -> None:
|
||||||
|
"""Test check_action for file read."""
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="test-agent",
|
||||||
|
resource_pattern="/data/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
action = ActionRequest(
|
||||||
|
action_type=ActionType.FILE_READ,
|
||||||
|
resource="/data/file.txt",
|
||||||
|
metadata=action_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
allowed = await permission_manager.check_action(action)
|
||||||
|
assert allowed is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_action_file_write(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
action_metadata: ActionMetadata,
|
||||||
|
) -> None:
|
||||||
|
"""Test check_action for file write."""
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="test-agent",
|
||||||
|
resource_pattern="/data/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.WRITE,
|
||||||
|
)
|
||||||
|
|
||||||
|
action = ActionRequest(
|
||||||
|
action_type=ActionType.FILE_WRITE,
|
||||||
|
resource="/data/file.txt",
|
||||||
|
metadata=action_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
allowed = await permission_manager.check_action(action)
|
||||||
|
assert allowed is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_action_uses_tool_name_as_resource(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
action_metadata: ActionMetadata,
|
||||||
|
) -> None:
|
||||||
|
"""Test check_action uses tool_name when resource is None."""
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="test-agent",
|
||||||
|
resource_pattern="search_*",
|
||||||
|
resource_type=ResourceType.CUSTOM,
|
||||||
|
level=PermissionLevel.EXECUTE,
|
||||||
|
)
|
||||||
|
|
||||||
|
action = ActionRequest(
|
||||||
|
action_type=ActionType.TOOL_CALL,
|
||||||
|
tool_name="search_documents",
|
||||||
|
resource=None,
|
||||||
|
metadata=action_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
allowed = await permission_manager.check_action(action)
|
||||||
|
assert allowed is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_require_permission_granted(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test require_permission doesn't raise when granted."""
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/data/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should not raise
|
||||||
|
await permission_manager.require_permission(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource="/data/file.txt",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
required_level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_require_permission_denied(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test require_permission raises when denied."""
|
||||||
|
with pytest.raises(PermissionDeniedError) as exc_info:
|
||||||
|
await permission_manager.require_permission(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource="/secret/file.txt",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
required_level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "/secret/file.txt" in str(exc_info.value)
|
||||||
|
assert exc_info.value.agent_id == "agent-1"
|
||||||
|
assert exc_info.value.required_permission == "read"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_grants_all(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test listing all grants."""
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/data/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="agent-2",
|
||||||
|
resource_pattern="/api/*",
|
||||||
|
resource_type=ResourceType.API,
|
||||||
|
level=PermissionLevel.EXECUTE,
|
||||||
|
)
|
||||||
|
|
||||||
|
grants = await permission_manager.list_grants()
|
||||||
|
assert len(grants) == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_grants_by_agent(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test listing grants filtered by agent."""
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/data/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="agent-2",
|
||||||
|
resource_pattern="/api/*",
|
||||||
|
resource_type=ResourceType.API,
|
||||||
|
level=PermissionLevel.EXECUTE,
|
||||||
|
)
|
||||||
|
|
||||||
|
grants = await permission_manager.list_grants(agent_id="agent-1")
|
||||||
|
assert len(grants) == 1
|
||||||
|
assert grants[0].agent_id == "agent-1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_grants_by_resource_type(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test listing grants filtered by resource type."""
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/data/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/api/*",
|
||||||
|
resource_type=ResourceType.API,
|
||||||
|
level=PermissionLevel.EXECUTE,
|
||||||
|
)
|
||||||
|
|
||||||
|
grants = await permission_manager.list_grants(resource_type=ResourceType.FILE)
|
||||||
|
assert len(grants) == 1
|
||||||
|
assert grants[0].resource_type == ResourceType.FILE
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_grants_excludes_expired(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test that list_grants excludes expired grants."""
|
||||||
|
# Create expired grant
|
||||||
|
grant = await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/old/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
duration_seconds=1,
|
||||||
|
)
|
||||||
|
grant.expires_at = datetime.utcnow() - timedelta(seconds=10)
|
||||||
|
|
||||||
|
# Create valid grant
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/new/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
grants = await permission_manager.list_grants()
|
||||||
|
assert len(grants) == 1
|
||||||
|
assert grants[0].resource_pattern == "/new/*"
|
||||||
|
|
||||||
|
def test_set_default_permission(
|
||||||
|
self,
|
||||||
|
) -> None:
|
||||||
|
"""Test setting default permission level."""
|
||||||
|
manager = PermissionManager(default_deny=False)
|
||||||
|
|
||||||
|
# Default for SHELL is NONE
|
||||||
|
assert manager._default_permissions[ResourceType.SHELL] == PermissionLevel.NONE
|
||||||
|
|
||||||
|
# Change it
|
||||||
|
manager.set_default_permission(ResourceType.SHELL, PermissionLevel.EXECUTE)
|
||||||
|
assert (
|
||||||
|
manager._default_permissions[ResourceType.SHELL] == PermissionLevel.EXECUTE
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_default_permission_affects_checks(
|
||||||
|
self,
|
||||||
|
permissive_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test that changing default permissions affects checks."""
|
||||||
|
# Initially SHELL is NONE
|
||||||
|
allowed = await permissive_manager.check(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource="ls",
|
||||||
|
resource_type=ResourceType.SHELL,
|
||||||
|
required_level=PermissionLevel.EXECUTE,
|
||||||
|
)
|
||||||
|
assert allowed is False
|
||||||
|
|
||||||
|
# Change default
|
||||||
|
permissive_manager.set_default_permission(
|
||||||
|
ResourceType.SHELL, PermissionLevel.EXECUTE
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now should be allowed
|
||||||
|
allowed = await permissive_manager.check(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource="ls",
|
||||||
|
resource_type=ResourceType.SHELL,
|
||||||
|
required_level=PermissionLevel.EXECUTE,
|
||||||
|
)
|
||||||
|
assert allowed is True
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Edge Cases
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestPermissionEdgeCases:
    """Edge cases that could reveal hidden bugs."""

    @pytest.mark.asyncio
    async def test_multiple_matching_grants(self, permission_manager: PermissionManager) -> None:
        """With overlapping grants, any sufficient match allows the check."""
        # Broad READ over all of /data/...
        await permission_manager.grant(
            agent_id="agent-1",
            resource_pattern="/data/*",
            resource_type=ResourceType.FILE,
            level=PermissionLevel.READ,
        )
        # ...plus WRITE on a narrower subtree.
        await permission_manager.grant(
            agent_id="agent-1",
            resource_pattern="/data/writable/*",
            resource_type=ResourceType.FILE,
            level=PermissionLevel.WRITE,
        )

        result = await permission_manager.check(
            agent_id="agent-1",
            resource="/data/writable/file.txt",
            resource_type=ResourceType.FILE,
            required_level=PermissionLevel.WRITE,
        )
        assert result is True

    @pytest.mark.asyncio
    async def test_wildcard_all_pattern(self, permission_manager: PermissionManager) -> None:
        """A bare '*' pattern matches any resource path."""
        await permission_manager.grant(
            agent_id="agent-1",
            resource_pattern="*",
            resource_type=ResourceType.FILE,
            level=PermissionLevel.ADMIN,
        )

        result = await permission_manager.check(
            agent_id="agent-1",
            resource="/any/path/anywhere/file.txt",
            resource_type=ResourceType.FILE,
            required_level=PermissionLevel.DELETE,
        )
        # fnmatch's * matches everything including /
        assert result is True

    @pytest.mark.asyncio
    async def test_question_mark_wildcard(self, permission_manager: PermissionManager) -> None:
        """'?' in a pattern matches exactly one character."""
        await permission_manager.grant(
            agent_id="agent-1",
            resource_pattern="file?.txt",
            resource_type=ResourceType.FILE,
            level=PermissionLevel.READ,
        )

        single_char = await permission_manager.check(
            agent_id="agent-1",
            resource="file1.txt",
            resource_type=ResourceType.FILE,
            required_level=PermissionLevel.READ,
        )
        assert single_char is True

        double_char = await permission_manager.check(
            agent_id="agent-1",
            resource="file10.txt",  # Two characters, won't match
            resource_type=ResourceType.FILE,
            required_level=PermissionLevel.READ,
        )
        assert double_char is False

    @pytest.mark.asyncio
    async def test_concurrent_grant_revoke(self, permission_manager: PermissionManager) -> None:
        """Granting then revoking many permissions leaves no residue."""

        async def grant_many():
            issued = []
            for idx in range(10):
                issued.append(
                    await permission_manager.grant(
                        agent_id="agent-1",
                        resource_pattern=f"/path{idx}/*",
                        resource_type=ResourceType.FILE,
                        level=PermissionLevel.READ,
                    )
                )
            return issued

        async def revoke_many(issued):
            for item in issued:
                await permission_manager.revoke(item.id)

        issued = await grant_many()
        await revoke_many(issued)

        leftover = await permission_manager.list_grants(agent_id="agent-1")
        assert len(leftover) == 0

    @pytest.mark.asyncio
    async def test_check_action_with_no_resource_or_tool(
        self,
        permission_manager: PermissionManager,
        action_metadata: ActionMetadata,
    ) -> None:
        """check_action falls back to '*' when resource and tool_name are both None."""
        await permission_manager.grant(
            agent_id="test-agent",
            resource_pattern="*",
            resource_type=ResourceType.LLM,
            level=PermissionLevel.EXECUTE,
        )

        request = ActionRequest(
            action_type=ActionType.LLM_CALL,
            resource=None,
            tool_name=None,
            metadata=action_metadata,
        )
        assert await permission_manager.check_action(request) is True

    @pytest.mark.asyncio
    async def test_cleanup_expired_called_on_check(
        self, permission_manager: PermissionManager
    ) -> None:
        """Running a check prunes expired grants from internal storage."""
        stale = await permission_manager.grant(
            agent_id="agent-1",
            resource_pattern="/old/*",
            resource_type=ResourceType.FILE,
            level=PermissionLevel.READ,
            duration_seconds=1,
        )
        stale.expires_at = datetime.utcnow() - timedelta(seconds=10)

        await permission_manager.grant(
            agent_id="agent-1",
            resource_pattern="/new/*",
            resource_type=ResourceType.FILE,
            level=PermissionLevel.READ,
        )

        # Any check should trigger cleanup of the stale grant.
        await permission_manager.check(
            agent_id="agent-1",
            resource="/new/file.txt",
            resource_type=ResourceType.FILE,
            required_level=PermissionLevel.READ,
        )

        # Inspect internal state under the manager's own lock.
        async with permission_manager._lock:
            assert len(permission_manager._grants) == 1
            assert permission_manager._grants[0].resource_pattern == "/new/*"

    @pytest.mark.asyncio
    async def test_check_wrong_agent_id(self, permission_manager: PermissionManager) -> None:
        """A grant for one agent never authorizes a different agent."""
        await permission_manager.grant(
            agent_id="agent-1",
            resource_pattern="/data/*",
            resource_type=ResourceType.FILE,
            level=PermissionLevel.READ,
        )

        result = await permission_manager.check(
            agent_id="agent-2",
            resource="/data/file.txt",
            resource_type=ResourceType.FILE,
            required_level=PermissionLevel.READ,
        )
        assert result is False
|
||||||
823
backend/tests/services/safety/test_rollback.py
Normal file
823
backend/tests/services/safety/test_rollback.py
Normal file
@@ -0,0 +1,823 @@
|
|||||||
|
"""Tests for Rollback Manager.
|
||||||
|
|
||||||
|
Tests cover:
|
||||||
|
- FileCheckpoint: state storage
|
||||||
|
- RollbackManager: checkpoint, rollback, cleanup
|
||||||
|
- TransactionContext: auto-rollback, commit, manual rollback
|
||||||
|
- Edge cases: non-existent files, partial failures, expiration
|
||||||
|
"""
|
||||||
|
|
||||||
|
import tempfile
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from app.services.safety.exceptions import RollbackError
|
||||||
|
from app.services.safety.models import (
|
||||||
|
ActionMetadata,
|
||||||
|
ActionRequest,
|
||||||
|
ActionType,
|
||||||
|
CheckpointType,
|
||||||
|
)
|
||||||
|
from app.services.safety.rollback.manager import (
|
||||||
|
FileCheckpoint,
|
||||||
|
RollbackManager,
|
||||||
|
TransactionContext,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Fixtures
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def action_metadata() -> ActionMetadata:
    """Standard agent/project/session metadata shared by the tests."""
    return ActionMetadata(
        agent_id="test-agent",
        project_id="test-project",
        session_id="test-session",
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def action_request(action_metadata: ActionMetadata) -> ActionRequest:
    """A destructive FILE_WRITE action request used throughout the suite."""
    return ActionRequest(
        id="action-123",
        action_type=ActionType.FILE_WRITE,
        tool_name="file_write",
        resource="/tmp/test_file.txt",  # noqa: S108
        metadata=action_metadata,
        is_destructive=True,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def rollback_manager() -> RollbackManager:
    """RollbackManager wired to a throwaway checkpoint directory.

    The safety-config accessor is patched so the manager never touches
    real configuration during tests.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        with patch("app.services.safety.rollback.manager.get_safety_config") as mock:
            mock.return_value = MagicMock(
                checkpoint_dir=tmpdir,
                checkpoint_retention_hours=24,
            )
            yield RollbackManager(checkpoint_dir=tmpdir, retention_hours=24)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def temp_dir() -> Path:
    """Yield a temporary directory (as a Path) for file operations."""
    with tempfile.TemporaryDirectory() as tmpdir:
        yield Path(tmpdir)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# FileCheckpoint Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestFileCheckpoint:
    """Tests for the FileCheckpoint class."""

    def test_file_checkpoint_creation(self) -> None:
        """All constructor fields are stored and a creation time is set."""
        snapshot = FileCheckpoint(
            checkpoint_id="cp-123",
            file_path="/path/to/file.txt",
            original_content=b"original content",
            existed=True,
        )

        assert snapshot.checkpoint_id == "cp-123"
        assert snapshot.file_path == "/path/to/file.txt"
        assert snapshot.original_content == b"original content"
        assert snapshot.existed is True
        assert snapshot.created_at is not None

    def test_file_checkpoint_nonexistent_file(self) -> None:
        """A checkpoint of a missing file carries no content and existed=False."""
        snapshot = FileCheckpoint(
            checkpoint_id="cp-123",
            file_path="/path/to/new_file.txt",
            original_content=None,
            existed=False,
        )

        assert snapshot.original_content is None
        assert snapshot.existed is False
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# RollbackManager Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestRollbackManager:
|
||||||
|
"""Tests for the RollbackManager class."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_checkpoint(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
) -> None:
|
||||||
|
"""Test creating a checkpoint."""
|
||||||
|
checkpoint = await rollback_manager.create_checkpoint(
|
||||||
|
action=action_request,
|
||||||
|
checkpoint_type=CheckpointType.FILE,
|
||||||
|
description="Test checkpoint",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert checkpoint.id is not None
|
||||||
|
assert checkpoint.action_id == action_request.id
|
||||||
|
assert checkpoint.checkpoint_type == CheckpointType.FILE
|
||||||
|
assert checkpoint.description == "Test checkpoint"
|
||||||
|
assert checkpoint.expires_at is not None
|
||||||
|
assert checkpoint.is_valid is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_checkpoint_default_description(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
) -> None:
|
||||||
|
"""Test checkpoint with default description."""
|
||||||
|
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
|
||||||
|
assert "file_write" in checkpoint.description
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_checkpoint_file_exists(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test checkpointing an existing file."""
|
||||||
|
# Create a file
|
||||||
|
test_file = temp_dir / "test.txt"
|
||||||
|
test_file.write_text("original content")
|
||||||
|
|
||||||
|
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))
|
||||||
|
|
||||||
|
# Verify checkpoint was stored
|
||||||
|
async with rollback_manager._lock:
|
||||||
|
file_checkpoints = rollback_manager._file_checkpoints.get(checkpoint.id, [])
|
||||||
|
assert len(file_checkpoints) == 1
|
||||||
|
assert file_checkpoints[0].existed is True
|
||||||
|
assert file_checkpoints[0].original_content == b"original content"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_checkpoint_file_not_exists(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test checkpointing a non-existent file."""
|
||||||
|
test_file = temp_dir / "new_file.txt"
|
||||||
|
assert not test_file.exists()
|
||||||
|
|
||||||
|
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))
|
||||||
|
|
||||||
|
# Verify checkpoint was stored
|
||||||
|
async with rollback_manager._lock:
|
||||||
|
file_checkpoints = rollback_manager._file_checkpoints.get(checkpoint.id, [])
|
||||||
|
assert len(file_checkpoints) == 1
|
||||||
|
assert file_checkpoints[0].existed is False
|
||||||
|
assert file_checkpoints[0].original_content is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_checkpoint_files_multiple(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test checkpointing multiple files."""
|
||||||
|
# Create files
|
||||||
|
file1 = temp_dir / "file1.txt"
|
||||||
|
file2 = temp_dir / "file2.txt"
|
||||||
|
file1.write_text("content 1")
|
||||||
|
file2.write_text("content 2")
|
||||||
|
|
||||||
|
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
await rollback_manager.checkpoint_files(
|
||||||
|
checkpoint.id,
|
||||||
|
[str(file1), str(file2)],
|
||||||
|
)
|
||||||
|
|
||||||
|
async with rollback_manager._lock:
|
||||||
|
file_checkpoints = rollback_manager._file_checkpoints.get(checkpoint.id, [])
|
||||||
|
assert len(file_checkpoints) == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rollback_restore_modified_file(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test rollback restores modified file content."""
|
||||||
|
test_file = temp_dir / "test.txt"
|
||||||
|
test_file.write_text("original content")
|
||||||
|
|
||||||
|
# Create checkpoint
|
||||||
|
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))
|
||||||
|
|
||||||
|
# Modify file
|
||||||
|
test_file.write_text("modified content")
|
||||||
|
assert test_file.read_text() == "modified content"
|
||||||
|
|
||||||
|
# Rollback
|
||||||
|
result = await rollback_manager.rollback(checkpoint.id)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert len(result.actions_rolled_back) == 1
|
||||||
|
assert test_file.read_text() == "original content"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rollback_delete_new_file(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test rollback deletes file that didn't exist before."""
|
||||||
|
test_file = temp_dir / "new_file.txt"
|
||||||
|
assert not test_file.exists()
|
||||||
|
|
||||||
|
# Create checkpoint before file exists
|
||||||
|
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))
|
||||||
|
|
||||||
|
# Create the file
|
||||||
|
test_file.write_text("new content")
|
||||||
|
assert test_file.exists()
|
||||||
|
|
||||||
|
# Rollback
|
||||||
|
result = await rollback_manager.rollback(checkpoint.id)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert not test_file.exists()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rollback_not_found(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test rollback with non-existent checkpoint."""
|
||||||
|
with pytest.raises(RollbackError) as exc_info:
|
||||||
|
await rollback_manager.rollback("nonexistent-id")
|
||||||
|
|
||||||
|
assert "not found" in str(exc_info.value)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rollback_invalid_checkpoint(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test rollback with invalidated checkpoint."""
|
||||||
|
test_file = temp_dir / "test.txt"
|
||||||
|
test_file.write_text("original")
|
||||||
|
|
||||||
|
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))
|
||||||
|
|
||||||
|
# Rollback once (invalidates checkpoint)
|
||||||
|
await rollback_manager.rollback(checkpoint.id)
|
||||||
|
|
||||||
|
# Try to rollback again
|
||||||
|
with pytest.raises(RollbackError) as exc_info:
|
||||||
|
await rollback_manager.rollback(checkpoint.id)
|
||||||
|
|
||||||
|
assert "no longer valid" in str(exc_info.value)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_discard_checkpoint(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
) -> None:
|
||||||
|
"""Test discarding a checkpoint."""
|
||||||
|
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
|
||||||
|
result = await rollback_manager.discard_checkpoint(checkpoint.id)
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
# Verify it's gone
|
||||||
|
cp = await rollback_manager.get_checkpoint(checkpoint.id)
|
||||||
|
assert cp is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_discard_checkpoint_nonexistent(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test discarding a non-existent checkpoint."""
|
||||||
|
result = await rollback_manager.discard_checkpoint("nonexistent-id")
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_checkpoint(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
) -> None:
|
||||||
|
"""Test getting a checkpoint by ID."""
|
||||||
|
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
|
||||||
|
retrieved = await rollback_manager.get_checkpoint(checkpoint.id)
|
||||||
|
assert retrieved is not None
|
||||||
|
assert retrieved.id == checkpoint.id
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_checkpoint_nonexistent(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test getting a non-existent checkpoint."""
|
||||||
|
retrieved = await rollback_manager.get_checkpoint("nonexistent-id")
|
||||||
|
assert retrieved is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_checkpoints(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
) -> None:
|
||||||
|
"""Test listing checkpoints."""
|
||||||
|
await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
|
||||||
|
checkpoints = await rollback_manager.list_checkpoints()
|
||||||
|
assert len(checkpoints) == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_checkpoints_by_action(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_metadata: ActionMetadata,
|
||||||
|
) -> None:
|
||||||
|
"""Test listing checkpoints filtered by action."""
|
||||||
|
action1 = ActionRequest(
|
||||||
|
id="action-1",
|
||||||
|
action_type=ActionType.FILE_WRITE,
|
||||||
|
metadata=action_metadata,
|
||||||
|
)
|
||||||
|
action2 = ActionRequest(
|
||||||
|
id="action-2",
|
||||||
|
action_type=ActionType.FILE_WRITE,
|
||||||
|
metadata=action_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
await rollback_manager.create_checkpoint(action=action1)
|
||||||
|
await rollback_manager.create_checkpoint(action=action2)
|
||||||
|
|
||||||
|
checkpoints = await rollback_manager.list_checkpoints(action_id="action-1")
|
||||||
|
assert len(checkpoints) == 1
|
||||||
|
assert checkpoints[0].action_id == "action-1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_checkpoints_excludes_expired(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
) -> None:
|
||||||
|
"""Test list_checkpoints excludes expired by default."""
|
||||||
|
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
|
||||||
|
# Manually expire it
|
||||||
|
async with rollback_manager._lock:
|
||||||
|
rollback_manager._checkpoints[checkpoint.id].expires_at = (
|
||||||
|
datetime.utcnow() - timedelta(hours=1)
|
||||||
|
)
|
||||||
|
|
||||||
|
checkpoints = await rollback_manager.list_checkpoints()
|
||||||
|
assert len(checkpoints) == 0
|
||||||
|
|
||||||
|
# With include_expired=True
|
||||||
|
checkpoints = await rollback_manager.list_checkpoints(include_expired=True)
|
||||||
|
assert len(checkpoints) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cleanup_expired(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test cleaning up expired checkpoints."""
|
||||||
|
# Create checkpoints
|
||||||
|
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
test_file = temp_dir / "test.txt"
|
||||||
|
test_file.write_text("content")
|
||||||
|
await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))
|
||||||
|
|
||||||
|
# Expire it
|
||||||
|
async with rollback_manager._lock:
|
||||||
|
rollback_manager._checkpoints[checkpoint.id].expires_at = (
|
||||||
|
datetime.utcnow() - timedelta(hours=1)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
count = await rollback_manager.cleanup_expired()
|
||||||
|
assert count == 1
|
||||||
|
|
||||||
|
# Verify it's gone
|
||||||
|
async with rollback_manager._lock:
|
||||||
|
assert checkpoint.id not in rollback_manager._checkpoints
|
||||||
|
assert checkpoint.id not in rollback_manager._file_checkpoints
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# TransactionContext Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestTransactionContext:
|
||||||
|
"""Tests for the TransactionContext class."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context_creates_checkpoint(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
) -> None:
|
||||||
|
"""Test that entering context creates a checkpoint."""
|
||||||
|
async with TransactionContext(rollback_manager, action_request) as tx:
|
||||||
|
assert tx.checkpoint_id is not None
|
||||||
|
|
||||||
|
# Verify checkpoint exists
|
||||||
|
cp = await rollback_manager.get_checkpoint(tx.checkpoint_id)
|
||||||
|
assert cp is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context_checkpoint_file(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test checkpointing files through context."""
|
||||||
|
test_file = temp_dir / "test.txt"
|
||||||
|
test_file.write_text("original")
|
||||||
|
|
||||||
|
async with TransactionContext(rollback_manager, action_request) as tx:
|
||||||
|
await tx.checkpoint_file(str(test_file))
|
||||||
|
|
||||||
|
# Modify file
|
||||||
|
test_file.write_text("modified")
|
||||||
|
|
||||||
|
# Manual rollback
|
||||||
|
result = await tx.rollback()
|
||||||
|
assert result is not None
|
||||||
|
assert result.success is True
|
||||||
|
|
||||||
|
assert test_file.read_text() == "original"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context_checkpoint_files(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test checkpointing multiple files through context."""
|
||||||
|
file1 = temp_dir / "file1.txt"
|
||||||
|
file2 = temp_dir / "file2.txt"
|
||||||
|
file1.write_text("content 1")
|
||||||
|
file2.write_text("content 2")
|
||||||
|
|
||||||
|
async with TransactionContext(rollback_manager, action_request) as tx:
|
||||||
|
await tx.checkpoint_files([str(file1), str(file2)])
|
||||||
|
|
||||||
|
cp_id = tx.checkpoint_id
|
||||||
|
async with rollback_manager._lock:
|
||||||
|
file_cps = rollback_manager._file_checkpoints.get(cp_id, [])
|
||||||
|
assert len(file_cps) == 2
|
||||||
|
|
||||||
|
tx.commit()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context_auto_rollback_on_exception(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test auto-rollback when exception occurs."""
|
||||||
|
test_file = temp_dir / "test.txt"
|
||||||
|
test_file.write_text("original")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
async with TransactionContext(rollback_manager, action_request) as tx:
|
||||||
|
await tx.checkpoint_file(str(test_file))
|
||||||
|
test_file.write_text("modified")
|
||||||
|
raise ValueError("Simulated error")
|
||||||
|
|
||||||
|
# Should have been rolled back
|
||||||
|
assert test_file.read_text() == "original"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context_commit_prevents_rollback(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test that commit prevents auto-rollback."""
|
||||||
|
test_file = temp_dir / "test.txt"
|
||||||
|
test_file.write_text("original")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
async with TransactionContext(rollback_manager, action_request) as tx:
|
||||||
|
await tx.checkpoint_file(str(test_file))
|
||||||
|
test_file.write_text("modified")
|
||||||
|
tx.commit()
|
||||||
|
raise ValueError("Simulated error after commit")
|
||||||
|
|
||||||
|
# Should NOT have been rolled back
|
||||||
|
assert test_file.read_text() == "modified"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context_discards_checkpoint_on_commit(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
) -> None:
|
||||||
|
"""Test that checkpoint is discarded after successful commit."""
|
||||||
|
checkpoint_id = None
|
||||||
|
|
||||||
|
async with TransactionContext(rollback_manager, action_request) as tx:
|
||||||
|
checkpoint_id = tx.checkpoint_id
|
||||||
|
tx.commit()
|
||||||
|
|
||||||
|
# Checkpoint should be discarded
|
||||||
|
cp = await rollback_manager.get_checkpoint(checkpoint_id)
|
||||||
|
assert cp is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context_no_auto_rollback_when_disabled(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test that auto_rollback=False disables auto-rollback."""
|
||||||
|
test_file = temp_dir / "test.txt"
|
||||||
|
test_file.write_text("original")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
async with TransactionContext(
|
||||||
|
rollback_manager,
|
||||||
|
action_request,
|
||||||
|
auto_rollback=False,
|
||||||
|
) as tx:
|
||||||
|
await tx.checkpoint_file(str(test_file))
|
||||||
|
test_file.write_text("modified")
|
||||||
|
raise ValueError("Simulated error")
|
||||||
|
|
||||||
|
# Should NOT have been rolled back
|
||||||
|
assert test_file.read_text() == "modified"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context_manual_rollback(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test manual rollback within context."""
|
||||||
|
test_file = temp_dir / "test.txt"
|
||||||
|
test_file.write_text("original")
|
||||||
|
|
||||||
|
async with TransactionContext(rollback_manager, action_request) as tx:
|
||||||
|
await tx.checkpoint_file(str(test_file))
|
||||||
|
test_file.write_text("modified")
|
||||||
|
|
||||||
|
# Manual rollback
|
||||||
|
result = await tx.rollback()
|
||||||
|
assert result is not None
|
||||||
|
assert result.success is True
|
||||||
|
|
||||||
|
assert test_file.read_text() == "original"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context_rollback_without_checkpoint(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
) -> None:
|
||||||
|
"""Test rollback when checkpoint is None."""
|
||||||
|
tx = TransactionContext(rollback_manager, action_request)
|
||||||
|
# Don't enter context, so _checkpoint is None
|
||||||
|
result = await tx.rollback()
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context_checkpoint_file_without_checkpoint(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test checkpoint_file when checkpoint is None (no-op)."""
|
||||||
|
tx = TransactionContext(rollback_manager, action_request)
|
||||||
|
test_file = temp_dir / "test.txt"
|
||||||
|
test_file.write_text("content")
|
||||||
|
|
||||||
|
# Should not raise - just a no-op
|
||||||
|
await tx.checkpoint_file(str(test_file))
|
||||||
|
await tx.checkpoint_files([str(test_file)])
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Edge Cases
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestRollbackEdgeCases:
|
||||||
|
"""Edge cases that could reveal hidden bugs."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_checkpoint_file_for_unknown_checkpoint(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test checkpointing file for non-existent checkpoint."""
|
||||||
|
test_file = temp_dir / "test.txt"
|
||||||
|
test_file.write_text("content")
|
||||||
|
|
||||||
|
# Should create the list if it doesn't exist
|
||||||
|
await rollback_manager.checkpoint_file("unknown-checkpoint", str(test_file))
|
||||||
|
|
||||||
|
async with rollback_manager._lock:
|
||||||
|
assert "unknown-checkpoint" in rollback_manager._file_checkpoints
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rollback_with_partial_failure(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test rollback when some files fail to restore."""
|
||||||
|
file1 = temp_dir / "file1.txt"
|
||||||
|
file1.write_text("original 1")
|
||||||
|
|
||||||
|
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
await rollback_manager.checkpoint_file(checkpoint.id, str(file1))
|
||||||
|
|
||||||
|
# Add a file checkpoint with a path that will fail
|
||||||
|
async with rollback_manager._lock:
|
||||||
|
# Create a checkpoint for a file in a non-writable location
|
||||||
|
bad_fc = FileCheckpoint(
|
||||||
|
checkpoint_id=checkpoint.id,
|
||||||
|
file_path="/nonexistent/path/file.txt",
|
||||||
|
original_content=b"content",
|
||||||
|
existed=True,
|
||||||
|
)
|
||||||
|
rollback_manager._file_checkpoints[checkpoint.id].append(bad_fc)
|
||||||
|
|
||||||
|
# Rollback - partial failure expected
|
||||||
|
result = await rollback_manager.rollback(checkpoint.id)
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert len(result.actions_rolled_back) == 1
|
||||||
|
assert len(result.failed_actions) == 1
|
||||||
|
assert "Failed to rollback" in result.error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rollback_file_creates_parent_dirs(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test that rollback creates parent directories if needed."""
|
||||||
|
nested_file = temp_dir / "subdir" / "nested" / "file.txt"
|
||||||
|
nested_file.parent.mkdir(parents=True)
|
||||||
|
nested_file.write_text("original")
|
||||||
|
|
||||||
|
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
await rollback_manager.checkpoint_file(checkpoint.id, str(nested_file))
|
||||||
|
|
||||||
|
# Delete the entire directory structure
|
||||||
|
nested_file.unlink()
|
||||||
|
(temp_dir / "subdir" / "nested").rmdir()
|
||||||
|
(temp_dir / "subdir").rmdir()
|
||||||
|
|
||||||
|
# Rollback should recreate
|
||||||
|
result = await rollback_manager.rollback(checkpoint.id)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert nested_file.exists()
|
||||||
|
assert nested_file.read_text() == "original"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rollback_file_already_correct(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test rollback when file already has correct content."""
|
||||||
|
test_file = temp_dir / "test.txt"
|
||||||
|
test_file.write_text("original")
|
||||||
|
|
||||||
|
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))
|
||||||
|
|
||||||
|
# Don't modify file - rollback should still succeed
|
||||||
|
result = await rollback_manager.rollback(checkpoint.id)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert test_file.read_text() == "original"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_checkpoint_with_none_expires_at(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
) -> None:
|
||||||
|
"""Test list_checkpoints handles None expires_at."""
|
||||||
|
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
|
||||||
|
# Set expires_at to None
|
||||||
|
async with rollback_manager._lock:
|
||||||
|
rollback_manager._checkpoints[checkpoint.id].expires_at = None
|
||||||
|
|
||||||
|
# Should still be listed
|
||||||
|
checkpoints = await rollback_manager.list_checkpoints()
|
||||||
|
assert len(checkpoints) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auto_rollback_failure_logged(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test that auto-rollback failure is logged, not raised."""
|
||||||
|
test_file = temp_dir / "test.txt"
|
||||||
|
test_file.write_text("original")
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
rollback_manager, "rollback", side_effect=Exception("Rollback failed!")
|
||||||
|
):
|
||||||
|
with patch("app.services.safety.rollback.manager.logger") as mock_logger:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
async with TransactionContext(
|
||||||
|
rollback_manager, action_request
|
||||||
|
) as tx:
|
||||||
|
await tx.checkpoint_file(str(test_file))
|
||||||
|
test_file.write_text("modified")
|
||||||
|
raise ValueError("Original error")
|
||||||
|
|
||||||
|
# Rollback error should be logged
|
||||||
|
mock_logger.error.assert_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_multiple_checkpoints_same_action(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
) -> None:
|
||||||
|
"""Test creating multiple checkpoints for the same action."""
|
||||||
|
cp1 = await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
cp2 = await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
|
||||||
|
assert cp1.id != cp2.id
|
||||||
|
|
||||||
|
checkpoints = await rollback_manager.list_checkpoints(
|
||||||
|
action_id=action_request.id
|
||||||
|
)
|
||||||
|
assert len(checkpoints) == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cleanup_expired_with_no_expired(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
) -> None:
|
||||||
|
"""Test cleanup when no checkpoints are expired."""
|
||||||
|
await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
|
||||||
|
count = await rollback_manager.cleanup_expired()
|
||||||
|
assert count == 0
|
||||||
|
|
||||||
|
# Checkpoint should still exist
|
||||||
|
checkpoints = await rollback_manager.list_checkpoints()
|
||||||
|
assert len(checkpoints) == 1
|
||||||
@@ -363,6 +363,365 @@ class TestValidationBatch:
|
|||||||
assert results[1].decision == SafetyDecision.DENY
|
assert results[1].decision == SafetyDecision.DENY
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidationCache:
|
||||||
|
"""Tests for ValidationCache class."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_get_miss(self) -> None:
|
||||||
|
"""Test cache miss."""
|
||||||
|
from app.services.safety.validation.validator import ValidationCache
|
||||||
|
|
||||||
|
cache = ValidationCache(max_size=10, ttl_seconds=60)
|
||||||
|
result = await cache.get("nonexistent")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_get_hit(self) -> None:
|
||||||
|
"""Test cache hit."""
|
||||||
|
from app.services.safety.models import ValidationResult
|
||||||
|
from app.services.safety.validation.validator import ValidationCache
|
||||||
|
|
||||||
|
cache = ValidationCache(max_size=10, ttl_seconds=60)
|
||||||
|
vr = ValidationResult(
|
||||||
|
action_id="action-1",
|
||||||
|
decision=SafetyDecision.ALLOW,
|
||||||
|
applied_rules=[],
|
||||||
|
reasons=["test"],
|
||||||
|
)
|
||||||
|
await cache.set("key1", vr)
|
||||||
|
|
||||||
|
result = await cache.get("key1")
|
||||||
|
assert result is not None
|
||||||
|
assert result.action_id == "action-1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_ttl_expiry(self) -> None:
|
||||||
|
"""Test cache TTL expiry."""
|
||||||
|
import time
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from app.services.safety.models import ValidationResult
|
||||||
|
from app.services.safety.validation.validator import ValidationCache
|
||||||
|
|
||||||
|
cache = ValidationCache(max_size=10, ttl_seconds=1)
|
||||||
|
vr = ValidationResult(
|
||||||
|
action_id="action-1",
|
||||||
|
decision=SafetyDecision.ALLOW,
|
||||||
|
applied_rules=[],
|
||||||
|
reasons=["test"],
|
||||||
|
)
|
||||||
|
await cache.set("key1", vr)
|
||||||
|
|
||||||
|
# Advance time past TTL
|
||||||
|
with patch("time.time", return_value=time.time() + 2):
|
||||||
|
result = await cache.get("key1")
|
||||||
|
assert result is None # Should be expired
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_eviction_on_full(self) -> None:
|
||||||
|
"""Test cache eviction when full."""
|
||||||
|
from app.services.safety.models import ValidationResult
|
||||||
|
from app.services.safety.validation.validator import ValidationCache
|
||||||
|
|
||||||
|
cache = ValidationCache(max_size=2, ttl_seconds=60)
|
||||||
|
|
||||||
|
vr1 = ValidationResult(action_id="a1", decision=SafetyDecision.ALLOW)
|
||||||
|
vr2 = ValidationResult(action_id="a2", decision=SafetyDecision.ALLOW)
|
||||||
|
vr3 = ValidationResult(action_id="a3", decision=SafetyDecision.ALLOW)
|
||||||
|
|
||||||
|
await cache.set("key1", vr1)
|
||||||
|
await cache.set("key2", vr2)
|
||||||
|
await cache.set("key3", vr3) # Should evict key1
|
||||||
|
|
||||||
|
# key1 should be evicted
|
||||||
|
assert await cache.get("key1") is None
|
||||||
|
assert await cache.get("key2") is not None
|
||||||
|
assert await cache.get("key3") is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_update_existing_key(self) -> None:
|
||||||
|
"""Test updating existing key in cache."""
|
||||||
|
from app.services.safety.models import ValidationResult
|
||||||
|
from app.services.safety.validation.validator import ValidationCache
|
||||||
|
|
||||||
|
cache = ValidationCache(max_size=10, ttl_seconds=60)
|
||||||
|
|
||||||
|
vr1 = ValidationResult(action_id="a1", decision=SafetyDecision.ALLOW)
|
||||||
|
vr2 = ValidationResult(action_id="a1-updated", decision=SafetyDecision.DENY)
|
||||||
|
|
||||||
|
await cache.set("key1", vr1)
|
||||||
|
await cache.set("key1", vr2) # Should update, not add
|
||||||
|
|
||||||
|
result = await cache.get("key1")
|
||||||
|
assert result is not None
|
||||||
|
assert result.action_id == "a1" # Still old value since we move_to_end
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_clear(self) -> None:
|
||||||
|
"""Test clearing cache."""
|
||||||
|
from app.services.safety.models import ValidationResult
|
||||||
|
from app.services.safety.validation.validator import ValidationCache
|
||||||
|
|
||||||
|
cache = ValidationCache(max_size=10, ttl_seconds=60)
|
||||||
|
|
||||||
|
vr = ValidationResult(action_id="a1", decision=SafetyDecision.ALLOW)
|
||||||
|
await cache.set("key1", vr)
|
||||||
|
await cache.set("key2", vr)
|
||||||
|
|
||||||
|
await cache.clear()
|
||||||
|
|
||||||
|
assert await cache.get("key1") is None
|
||||||
|
assert await cache.get("key2") is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidatorCaching:
|
||||||
|
"""Tests for validator caching functionality."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_hit(self) -> None:
|
||||||
|
"""Test that cache is used for repeated validations."""
|
||||||
|
validator = ActionValidator(cache_enabled=True, cache_ttl=60)
|
||||||
|
|
||||||
|
metadata = ActionMetadata(agent_id="test-agent", session_id="session-1")
|
||||||
|
action = ActionRequest(
|
||||||
|
action_type=ActionType.FILE_READ,
|
||||||
|
tool_name="file_read",
|
||||||
|
resource="/tmp/test.txt", # noqa: S108
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
# First call populates cache
|
||||||
|
result1 = await validator.validate(action)
|
||||||
|
# Second call should use cache
|
||||||
|
result2 = await validator.validate(action)
|
||||||
|
|
||||||
|
assert result1.decision == result2.decision
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_clear_cache(self) -> None:
|
||||||
|
"""Test clearing the validation cache."""
|
||||||
|
validator = ActionValidator(cache_enabled=True)
|
||||||
|
|
||||||
|
metadata = ActionMetadata(agent_id="test-agent", session_id="session-1")
|
||||||
|
action = ActionRequest(
|
||||||
|
action_type=ActionType.FILE_READ,
|
||||||
|
tool_name="file_read",
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
await validator.validate(action)
|
||||||
|
await validator.clear_cache()
|
||||||
|
|
||||||
|
# Cache should be empty now (no error)
|
||||||
|
result = await validator.validate(action)
|
||||||
|
assert result.decision == SafetyDecision.ALLOW
|
||||||
|
|
||||||
|
|
||||||
|
class TestRuleMatching:
|
||||||
|
"""Tests for rule matching edge cases."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_action_type_mismatch(self) -> None:
|
||||||
|
"""Test that rule doesn't match when action type doesn't match."""
|
||||||
|
validator = ActionValidator(cache_enabled=False)
|
||||||
|
validator.add_rule(
|
||||||
|
ValidationRule(
|
||||||
|
name="file_only",
|
||||||
|
action_types=[ActionType.FILE_READ],
|
||||||
|
decision=SafetyDecision.DENY,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = ActionMetadata(agent_id="test-agent")
|
||||||
|
action = ActionRequest(
|
||||||
|
action_type=ActionType.SHELL_COMMAND, # Different type
|
||||||
|
tool_name="shell_exec",
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await validator.validate(action)
|
||||||
|
assert result.decision == SafetyDecision.ALLOW # Rule didn't match
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_pattern_no_tool_name(self) -> None:
|
||||||
|
"""Test rule with tool pattern when action has no tool_name."""
|
||||||
|
validator = ActionValidator(cache_enabled=False)
|
||||||
|
validator.add_rule(
|
||||||
|
create_deny_rule(
|
||||||
|
name="deny_files",
|
||||||
|
tool_patterns=["file_*"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = ActionMetadata(agent_id="test-agent")
|
||||||
|
action = ActionRequest(
|
||||||
|
action_type=ActionType.FILE_READ,
|
||||||
|
tool_name=None, # No tool name
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await validator.validate(action)
|
||||||
|
assert result.decision == SafetyDecision.ALLOW # Rule didn't match
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resource_pattern_no_resource(self) -> None:
|
||||||
|
"""Test rule with resource pattern when action has no resource."""
|
||||||
|
validator = ActionValidator(cache_enabled=False)
|
||||||
|
validator.add_rule(
|
||||||
|
create_deny_rule(
|
||||||
|
name="deny_secrets",
|
||||||
|
resource_patterns=["/secret/*"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = ActionMetadata(agent_id="test-agent")
|
||||||
|
action = ActionRequest(
|
||||||
|
action_type=ActionType.FILE_READ,
|
||||||
|
tool_name="file_read",
|
||||||
|
resource=None, # No resource
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await validator.validate(action)
|
||||||
|
assert result.decision == SafetyDecision.ALLOW # Rule didn't match
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resource_pattern_no_match(self) -> None:
|
||||||
|
"""Test rule with resource pattern that doesn't match."""
|
||||||
|
validator = ActionValidator(cache_enabled=False)
|
||||||
|
validator.add_rule(
|
||||||
|
create_deny_rule(
|
||||||
|
name="deny_secrets",
|
||||||
|
resource_patterns=["/secret/*"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = ActionMetadata(agent_id="test-agent")
|
||||||
|
action = ActionRequest(
|
||||||
|
action_type=ActionType.FILE_READ,
|
||||||
|
tool_name="file_read",
|
||||||
|
resource="/public/file.txt", # Doesn't match
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await validator.validate(action)
|
||||||
|
assert result.decision == SafetyDecision.ALLOW # Pattern didn't match
|
||||||
|
|
||||||
|
|
||||||
|
class TestPolicyLoading:
|
||||||
|
"""Tests for policy loading edge cases."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_load_rules_from_policy_with_validation_rules(self) -> None:
|
||||||
|
"""Test loading policy with explicit validation rules."""
|
||||||
|
validator = ActionValidator(cache_enabled=False)
|
||||||
|
|
||||||
|
rule = ValidationRule(
|
||||||
|
name="policy_rule",
|
||||||
|
tool_patterns=["test_*"],
|
||||||
|
decision=SafetyDecision.DENY,
|
||||||
|
reason="From policy",
|
||||||
|
)
|
||||||
|
policy = SafetyPolicy(
|
||||||
|
name="test",
|
||||||
|
validation_rules=[rule],
|
||||||
|
require_approval_for=[], # Clear defaults
|
||||||
|
denied_tools=[], # Clear defaults
|
||||||
|
)
|
||||||
|
|
||||||
|
validator.load_rules_from_policy(policy)
|
||||||
|
|
||||||
|
assert len(validator._rules) == 1
|
||||||
|
assert validator._rules[0].name == "policy_rule"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_load_approval_all_pattern(self) -> None:
|
||||||
|
"""Test loading policy with * approval pattern (all actions)."""
|
||||||
|
validator = ActionValidator(cache_enabled=False)
|
||||||
|
|
||||||
|
policy = SafetyPolicy(
|
||||||
|
name="test",
|
||||||
|
require_approval_for=["*"], # All actions require approval
|
||||||
|
denied_tools=[], # Clear defaults
|
||||||
|
)
|
||||||
|
|
||||||
|
validator.load_rules_from_policy(policy)
|
||||||
|
|
||||||
|
approval_rules = [
|
||||||
|
r for r in validator._rules if r.decision == SafetyDecision.REQUIRE_APPROVAL
|
||||||
|
]
|
||||||
|
assert len(approval_rules) == 1
|
||||||
|
assert approval_rules[0].name == "require_approval_all"
|
||||||
|
assert approval_rules[0].action_types == list(ActionType)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_validate_with_policy_loads_rules(self) -> None:
|
||||||
|
"""Test that validate() loads rules from policy if none exist."""
|
||||||
|
validator = ActionValidator(cache_enabled=False)
|
||||||
|
|
||||||
|
policy = SafetyPolicy(
|
||||||
|
name="test",
|
||||||
|
denied_tools=["dangerous_*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = ActionMetadata(agent_id="test-agent")
|
||||||
|
action = ActionRequest(
|
||||||
|
action_type=ActionType.SHELL_COMMAND,
|
||||||
|
tool_name="dangerous_exec",
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate with policy - should load rules
|
||||||
|
result = await validator.validate(action, policy=policy)
|
||||||
|
|
||||||
|
assert result.decision == SafetyDecision.DENY
|
||||||
|
|
||||||
|
|
||||||
|
class TestCacheKeyGeneration:
|
||||||
|
"""Tests for cache key generation."""
|
||||||
|
|
||||||
|
def test_get_cache_key(self) -> None:
|
||||||
|
"""Test cache key generation."""
|
||||||
|
validator = ActionValidator(cache_enabled=True)
|
||||||
|
|
||||||
|
metadata = ActionMetadata(
|
||||||
|
agent_id="test-agent",
|
||||||
|
autonomy_level=AutonomyLevel.MILESTONE,
|
||||||
|
)
|
||||||
|
action = ActionRequest(
|
||||||
|
action_type=ActionType.FILE_READ,
|
||||||
|
tool_name="file_read",
|
||||||
|
resource="/tmp/test.txt", # noqa: S108
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
key = validator._get_cache_key(action)
|
||||||
|
|
||||||
|
assert "file_read" in key
|
||||||
|
assert "file_read" in key
|
||||||
|
assert "/tmp/test.txt" in key # noqa: S108
|
||||||
|
assert "test-agent" in key
|
||||||
|
assert "milestone" in key
|
||||||
|
|
||||||
|
def test_get_cache_key_no_resource(self) -> None:
|
||||||
|
"""Test cache key generation without resource."""
|
||||||
|
validator = ActionValidator(cache_enabled=True)
|
||||||
|
|
||||||
|
metadata = ActionMetadata(agent_id="agent-1")
|
||||||
|
action = ActionRequest(
|
||||||
|
action_type=ActionType.SHELL_COMMAND,
|
||||||
|
tool_name="shell_exec",
|
||||||
|
resource=None,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
key = validator._get_cache_key(action)
|
||||||
|
|
||||||
|
# Should not error with None resource
|
||||||
|
assert "shell" in key
|
||||||
|
assert "agent-1" in key
|
||||||
|
|
||||||
|
|
||||||
class TestHelperFunctions:
|
class TestHelperFunctions:
|
||||||
"""Tests for rule creation helper functions."""
|
"""Tests for rule creation helper functions."""
|
||||||
|
|
||||||
|
|||||||
@@ -48,6 +48,80 @@ services:
|
|||||||
- app-network
|
- app-network
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
||||||
|
# ==========================================================================
|
||||||
|
# MCP Servers - Model Context Protocol servers for AI agent capabilities
|
||||||
|
# ==========================================================================
|
||||||
|
|
||||||
|
mcp-llm-gateway:
|
||||||
|
# REPLACE THIS with your actual image from your container registry
|
||||||
|
image: YOUR_REGISTRY/YOUR_PROJECT_MCP_LLM_GATEWAY:latest
|
||||||
|
env_file:
|
||||||
|
- .env
|
||||||
|
environment:
|
||||||
|
- LLM_GATEWAY_HOST=0.0.0.0
|
||||||
|
- LLM_GATEWAY_PORT=8001
|
||||||
|
- REDIS_URL=redis://redis:6379/1
|
||||||
|
- ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY}
|
||||||
|
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
||||||
|
- ENVIRONMENT=production
|
||||||
|
depends_on:
|
||||||
|
redis:
|
||||||
|
condition: service_healthy
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:8001/health').raise_for_status()"]
|
||||||
|
interval: 30s
|
||||||
|
timeout: 10s
|
||||||
|
retries: 3
|
||||||
|
start_period: 10s
|
||||||
|
networks:
|
||||||
|
- app-network
|
||||||
|
restart: unless-stopped
|
||||||
|
deploy:
|
||||||
|
resources:
|
||||||
|
limits:
|
||||||
|
cpus: '2.0'
|
||||||
|
memory: 2G
|
||||||
|
reservations:
|
||||||
|
cpus: '0.5'
|
||||||
|
memory: 512M
|
||||||
|
|
||||||
|
mcp-knowledge-base:
|
||||||
|
# REPLACE THIS with your actual image from your container registry
|
||||||
|
image: YOUR_REGISTRY/YOUR_PROJECT_MCP_KNOWLEDGE_BASE:latest
|
||||||
|
env_file:
|
||||||
|
- .env
|
||||||
|
environment:
|
||||||
|
# KB_ prefix required by pydantic-settings config
|
||||||
|
- KB_HOST=0.0.0.0
|
||||||
|
- KB_PORT=8002
|
||||||
|
- KB_DATABASE_URL=postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@db:5432/${POSTGRES_DB}
|
||||||
|
- KB_REDIS_URL=redis://redis:6379/2
|
||||||
|
- KB_LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
|
||||||
|
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
||||||
|
- ENVIRONMENT=production
|
||||||
|
depends_on:
|
||||||
|
db:
|
||||||
|
condition: service_healthy
|
||||||
|
redis:
|
||||||
|
condition: service_healthy
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:8002/health').raise_for_status()"]
|
||||||
|
interval: 30s
|
||||||
|
timeout: 10s
|
||||||
|
retries: 3
|
||||||
|
start_period: 10s
|
||||||
|
networks:
|
||||||
|
- app-network
|
||||||
|
restart: unless-stopped
|
||||||
|
deploy:
|
||||||
|
resources:
|
||||||
|
limits:
|
||||||
|
cpus: '1.0'
|
||||||
|
memory: 1G
|
||||||
|
reservations:
|
||||||
|
cpus: '0.25'
|
||||||
|
memory: 256M
|
||||||
|
|
||||||
backend:
|
backend:
|
||||||
# REPLACE THIS with your actual image from your container registry
|
# REPLACE THIS with your actual image from your container registry
|
||||||
# Examples:
|
# Examples:
|
||||||
@@ -64,11 +138,18 @@ services:
|
|||||||
- DEBUG=false
|
- DEBUG=false
|
||||||
- BACKEND_CORS_ORIGINS=${BACKEND_CORS_ORIGINS}
|
- BACKEND_CORS_ORIGINS=${BACKEND_CORS_ORIGINS}
|
||||||
- REDIS_URL=redis://redis:6379/0
|
- REDIS_URL=redis://redis:6379/0
|
||||||
|
# MCP Server URLs
|
||||||
|
- LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
|
||||||
|
- KNOWLEDGE_BASE_URL=http://mcp-knowledge-base:8002
|
||||||
depends_on:
|
depends_on:
|
||||||
db:
|
db:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
redis:
|
redis:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
|
mcp-llm-gateway:
|
||||||
|
condition: service_healthy
|
||||||
|
mcp-knowledge-base:
|
||||||
|
condition: service_healthy
|
||||||
networks:
|
networks:
|
||||||
- app-network
|
- app-network
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
@@ -92,11 +173,18 @@ services:
|
|||||||
- DATABASE_URL=${DATABASE_URL}
|
- DATABASE_URL=${DATABASE_URL}
|
||||||
- REDIS_URL=redis://redis:6379/0
|
- REDIS_URL=redis://redis:6379/0
|
||||||
- CELERY_QUEUE=agent
|
- CELERY_QUEUE=agent
|
||||||
|
# MCP Server URLs (agents need access to MCP)
|
||||||
|
- LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
|
||||||
|
- KNOWLEDGE_BASE_URL=http://mcp-knowledge-base:8002
|
||||||
depends_on:
|
depends_on:
|
||||||
db:
|
db:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
redis:
|
redis:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
|
mcp-llm-gateway:
|
||||||
|
condition: service_healthy
|
||||||
|
mcp-knowledge-base:
|
||||||
|
condition: service_healthy
|
||||||
networks:
|
networks:
|
||||||
- app-network
|
- app-network
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|||||||
@@ -32,6 +32,70 @@ services:
|
|||||||
networks:
|
networks:
|
||||||
- app-network
|
- app-network
|
||||||
|
|
||||||
|
# ==========================================================================
|
||||||
|
# MCP Servers - Model Context Protocol servers for AI agent capabilities
|
||||||
|
# ==========================================================================
|
||||||
|
|
||||||
|
mcp-llm-gateway:
|
||||||
|
build:
|
||||||
|
context: ./mcp-servers/llm-gateway
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
ports:
|
||||||
|
- "8001:8001"
|
||||||
|
env_file:
|
||||||
|
- .env
|
||||||
|
environment:
|
||||||
|
- LLM_GATEWAY_HOST=0.0.0.0
|
||||||
|
- LLM_GATEWAY_PORT=8001
|
||||||
|
- REDIS_URL=redis://redis:6379/1
|
||||||
|
- ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY}
|
||||||
|
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
||||||
|
- ENVIRONMENT=development
|
||||||
|
depends_on:
|
||||||
|
redis:
|
||||||
|
condition: service_healthy
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:8001/health').raise_for_status()"]
|
||||||
|
interval: 30s
|
||||||
|
timeout: 10s
|
||||||
|
retries: 3
|
||||||
|
start_period: 10s
|
||||||
|
networks:
|
||||||
|
- app-network
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
mcp-knowledge-base:
|
||||||
|
build:
|
||||||
|
context: ./mcp-servers/knowledge-base
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
ports:
|
||||||
|
- "8002:8002"
|
||||||
|
env_file:
|
||||||
|
- .env
|
||||||
|
environment:
|
||||||
|
# KB_ prefix required by pydantic-settings config
|
||||||
|
- KB_HOST=0.0.0.0
|
||||||
|
- KB_PORT=8002
|
||||||
|
- KB_DATABASE_URL=postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@db:5432/${POSTGRES_DB}
|
||||||
|
- KB_REDIS_URL=redis://redis:6379/2
|
||||||
|
- KB_LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
|
||||||
|
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
||||||
|
- ENVIRONMENT=development
|
||||||
|
depends_on:
|
||||||
|
db:
|
||||||
|
condition: service_healthy
|
||||||
|
redis:
|
||||||
|
condition: service_healthy
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:8002/health').raise_for_status()"]
|
||||||
|
interval: 30s
|
||||||
|
timeout: 10s
|
||||||
|
retries: 3
|
||||||
|
start_period: 10s
|
||||||
|
networks:
|
||||||
|
- app-network
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
backend:
|
backend:
|
||||||
build:
|
build:
|
||||||
context: ./backend
|
context: ./backend
|
||||||
@@ -52,11 +116,18 @@ services:
|
|||||||
- DEBUG=true
|
- DEBUG=true
|
||||||
- BACKEND_CORS_ORIGINS=${BACKEND_CORS_ORIGINS}
|
- BACKEND_CORS_ORIGINS=${BACKEND_CORS_ORIGINS}
|
||||||
- REDIS_URL=redis://redis:6379/0
|
- REDIS_URL=redis://redis:6379/0
|
||||||
|
# MCP Server URLs
|
||||||
|
- LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
|
||||||
|
- KNOWLEDGE_BASE_URL=http://mcp-knowledge-base:8002
|
||||||
depends_on:
|
depends_on:
|
||||||
db:
|
db:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
redis:
|
redis:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
|
mcp-llm-gateway:
|
||||||
|
condition: service_healthy
|
||||||
|
mcp-knowledge-base:
|
||||||
|
condition: service_healthy
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
||||||
interval: 10s
|
interval: 10s
|
||||||
@@ -81,11 +152,18 @@ services:
|
|||||||
- DATABASE_URL=${DATABASE_URL}
|
- DATABASE_URL=${DATABASE_URL}
|
||||||
- REDIS_URL=redis://redis:6379/0
|
- REDIS_URL=redis://redis:6379/0
|
||||||
- CELERY_QUEUE=agent
|
- CELERY_QUEUE=agent
|
||||||
|
# MCP Server URLs (agents need access to MCP)
|
||||||
|
- LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
|
||||||
|
- KNOWLEDGE_BASE_URL=http://mcp-knowledge-base:8002
|
||||||
depends_on:
|
depends_on:
|
||||||
db:
|
db:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
redis:
|
redis:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
|
mcp-llm-gateway:
|
||||||
|
condition: service_healthy
|
||||||
|
mcp-knowledge-base:
|
||||||
|
condition: service_healthy
|
||||||
networks:
|
networks:
|
||||||
- app-network
|
- app-network
|
||||||
command: ["celery", "-A", "app.celery_app", "worker", "-Q", "agent", "-l", "info", "-c", "4"]
|
command: ["celery", "-A", "app.celery_app", "worker", "-Q", "agent", "-l", "info", "-c", "4"]
|
||||||
|
|||||||
@@ -32,6 +32,82 @@ services:
|
|||||||
- app-network
|
- app-network
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
||||||
|
# ==========================================================================
|
||||||
|
# MCP Servers - Model Context Protocol servers for AI agent capabilities
|
||||||
|
# ==========================================================================
|
||||||
|
|
||||||
|
mcp-llm-gateway:
|
||||||
|
build:
|
||||||
|
context: ./mcp-servers/llm-gateway
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
env_file:
|
||||||
|
- .env
|
||||||
|
environment:
|
||||||
|
- LLM_GATEWAY_HOST=0.0.0.0
|
||||||
|
- LLM_GATEWAY_PORT=8001
|
||||||
|
- REDIS_URL=redis://redis:6379/1
|
||||||
|
- ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY}
|
||||||
|
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
||||||
|
- ENVIRONMENT=production
|
||||||
|
depends_on:
|
||||||
|
redis:
|
||||||
|
condition: service_healthy
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:8001/health').raise_for_status()"]
|
||||||
|
interval: 30s
|
||||||
|
timeout: 10s
|
||||||
|
retries: 3
|
||||||
|
start_period: 10s
|
||||||
|
networks:
|
||||||
|
- app-network
|
||||||
|
restart: unless-stopped
|
||||||
|
deploy:
|
||||||
|
resources:
|
||||||
|
limits:
|
||||||
|
cpus: '2.0'
|
||||||
|
memory: 2G
|
||||||
|
reservations:
|
||||||
|
cpus: '0.5'
|
||||||
|
memory: 512M
|
||||||
|
|
||||||
|
mcp-knowledge-base:
|
||||||
|
build:
|
||||||
|
context: ./mcp-servers/knowledge-base
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
env_file:
|
||||||
|
- .env
|
||||||
|
environment:
|
||||||
|
# KB_ prefix required by pydantic-settings config
|
||||||
|
- KB_HOST=0.0.0.0
|
||||||
|
- KB_PORT=8002
|
||||||
|
- KB_DATABASE_URL=postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@db:5432/${POSTGRES_DB}
|
||||||
|
- KB_REDIS_URL=redis://redis:6379/2
|
||||||
|
- KB_LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
|
||||||
|
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
||||||
|
- ENVIRONMENT=production
|
||||||
|
depends_on:
|
||||||
|
db:
|
||||||
|
condition: service_healthy
|
||||||
|
redis:
|
||||||
|
condition: service_healthy
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:8002/health').raise_for_status()"]
|
||||||
|
interval: 30s
|
||||||
|
timeout: 10s
|
||||||
|
retries: 3
|
||||||
|
start_period: 10s
|
||||||
|
networks:
|
||||||
|
- app-network
|
||||||
|
restart: unless-stopped
|
||||||
|
deploy:
|
||||||
|
resources:
|
||||||
|
limits:
|
||||||
|
cpus: '1.0'
|
||||||
|
memory: 1G
|
||||||
|
reservations:
|
||||||
|
cpus: '0.25'
|
||||||
|
memory: 256M
|
||||||
|
|
||||||
backend:
|
backend:
|
||||||
build:
|
build:
|
||||||
context: ./backend
|
context: ./backend
|
||||||
@@ -48,11 +124,18 @@ services:
|
|||||||
- DEBUG=false
|
- DEBUG=false
|
||||||
- BACKEND_CORS_ORIGINS=${BACKEND_CORS_ORIGINS}
|
- BACKEND_CORS_ORIGINS=${BACKEND_CORS_ORIGINS}
|
||||||
- REDIS_URL=redis://redis:6379/0
|
- REDIS_URL=redis://redis:6379/0
|
||||||
|
# MCP Server URLs
|
||||||
|
- LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
|
||||||
|
- KNOWLEDGE_BASE_URL=http://mcp-knowledge-base:8002
|
||||||
depends_on:
|
depends_on:
|
||||||
db:
|
db:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
redis:
|
redis:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
|
mcp-llm-gateway:
|
||||||
|
condition: service_healthy
|
||||||
|
mcp-knowledge-base:
|
||||||
|
condition: service_healthy
|
||||||
networks:
|
networks:
|
||||||
- app-network
|
- app-network
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
@@ -75,11 +158,18 @@ services:
|
|||||||
- DATABASE_URL=${DATABASE_URL}
|
- DATABASE_URL=${DATABASE_URL}
|
||||||
- REDIS_URL=redis://redis:6379/0
|
- REDIS_URL=redis://redis:6379/0
|
||||||
- CELERY_QUEUE=agent
|
- CELERY_QUEUE=agent
|
||||||
|
# MCP Server URLs (agents need access to MCP)
|
||||||
|
- LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
|
||||||
|
- KNOWLEDGE_BASE_URL=http://mcp-knowledge-base:8002
|
||||||
depends_on:
|
depends_on:
|
||||||
db:
|
db:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
redis:
|
redis:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
|
mcp-llm-gateway:
|
||||||
|
condition: service_healthy
|
||||||
|
mcp-knowledge-base:
|
||||||
|
condition: service_healthy
|
||||||
networks:
|
networks:
|
||||||
- app-network
|
- app-network
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|||||||
@@ -205,6 +205,69 @@ test(frontend): add unit tests for ProjectDashboard
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
## Pre-Commit Hooks
|
||||||
|
|
||||||
|
The repository includes pre-commit hooks that enforce validation before commits on protected branches.
|
||||||
|
|
||||||
|
### Setup
|
||||||
|
|
||||||
|
Enable the hooks by configuring git to use the `.githooks` directory:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git config core.hooksPath .githooks
|
||||||
|
```
|
||||||
|
|
||||||
|
This only needs to be done once per clone.
|
||||||
|
|
||||||
|
### What the Hooks Do
|
||||||
|
|
||||||
|
When committing to **protected branches** (`main`, `dev`):
|
||||||
|
|
||||||
|
| Condition | Action |
|
||||||
|
|-----------|--------|
|
||||||
|
| Backend files changed | Runs `make validate` in `/backend` |
|
||||||
|
| Frontend files changed | Runs `npm run validate` in `/frontend` |
|
||||||
|
| No relevant changes | Skips validation |
|
||||||
|
|
||||||
|
If validation fails, the commit is blocked with an error message.
|
||||||
|
|
||||||
|
When committing to **feature branches**:
|
||||||
|
- Validation is skipped (allows WIP commits)
|
||||||
|
- A message reminds you to run validation manually if needed
|
||||||
|
|
||||||
|
### Why Protected Branches Only?
|
||||||
|
|
||||||
|
The hooks only enforce validation on `main` and `dev` for good reasons:
|
||||||
|
|
||||||
|
1. **Feature branches are for iteration** - WIP commits, experimentation, and rapid prototyping shouldn't be blocked
|
||||||
|
2. **Flexibility during development** - You can commit broken code to your feature branch while debugging
|
||||||
|
3. **PRs catch issues** - The merge process ensures validation passes before reaching protected branches
|
||||||
|
4. **Manual control** - You can always run `make validate` or `npm run validate` yourself
|
||||||
|
|
||||||
|
### Manual Validation
|
||||||
|
|
||||||
|
Even on feature branches, you should validate before creating a PR:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Backend
|
||||||
|
cd backend && make validate
|
||||||
|
|
||||||
|
# Frontend
|
||||||
|
cd frontend && npm run validate
|
||||||
|
```
|
||||||
|
|
||||||
|
### Bypassing Hooks (Emergency Only)
|
||||||
|
|
||||||
|
In rare cases where you need to bypass the hook:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git commit --no-verify -m "message"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Use sparingly** - this defeats the purpose of the hooks.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## Documentation Updates
|
## Documentation Updates
|
||||||
|
|
||||||
- Keep `docs/architecture/IMPLEMENTATION_ROADMAP.md` updated
|
- Keep `docs/architecture/IMPLEMENTATION_ROADMAP.md` updated
|
||||||
@@ -314,8 +377,11 @@ Do NOT use parallel agents when:
|
|||||||
| Action | Command/Location |
|
| Action | Command/Location |
|
||||||
|--------|-----------------|
|
|--------|-----------------|
|
||||||
| Create branch | `git checkout -b feature/<issue>-<desc>` |
|
| Create branch | `git checkout -b feature/<issue>-<desc>` |
|
||||||
|
| Enable pre-commit hooks | `git config core.hooksPath .githooks` |
|
||||||
| Run backend tests | `IS_TEST=True uv run pytest` |
|
| Run backend tests | `IS_TEST=True uv run pytest` |
|
||||||
| Run frontend tests | `npm test` |
|
| Run frontend tests | `npm test` |
|
||||||
|
| Backend validation | `cd backend && make validate` |
|
||||||
|
| Frontend validation | `cd frontend && npm run validate` |
|
||||||
| Check types (backend) | `uv run mypy src/` |
|
| Check types (backend) | `uv run mypy src/` |
|
||||||
| Check types (frontend) | `npm run type-check` |
|
| Check types (frontend) | `npm run type-check` |
|
||||||
| Lint (backend) | `uv run ruff check src/` |
|
| Lint (backend) | `uv run ruff check src/` |
|
||||||
|
|||||||
@@ -386,10 +386,24 @@ describe('ActivityFeed', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
it('shows event count in group header', () => {
|
it('shows event count in group header', () => {
|
||||||
render(<ActivityFeed {...defaultProps} />);
|
// Create fresh "today" events to avoid timezone/day boundary issues
|
||||||
|
const todayEvents: ProjectEvent[] = [
|
||||||
|
createMockEvent({
|
||||||
|
id: 'today-event-1',
|
||||||
|
type: EventType.APPROVAL_REQUESTED,
|
||||||
|
timestamp: new Date().toISOString(),
|
||||||
|
}),
|
||||||
|
createMockEvent({
|
||||||
|
id: 'today-event-2',
|
||||||
|
type: EventType.AGENT_MESSAGE,
|
||||||
|
timestamp: new Date().toISOString(),
|
||||||
|
}),
|
||||||
|
];
|
||||||
|
|
||||||
|
render(<ActivityFeed {...defaultProps} events={todayEvents} />);
|
||||||
|
|
||||||
const todayGroup = screen.getByTestId('event-group-today');
|
const todayGroup = screen.getByTestId('event-group-today');
|
||||||
// Today has 2 events in our mock data
|
// Today has 2 events
|
||||||
expect(within(todayGroup).getByText('2')).toBeInTheDocument();
|
expect(within(todayGroup).getByText('2')).toBeInTheDocument();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
79
mcp-servers/knowledge-base/Makefile
Normal file
79
mcp-servers/knowledge-base/Makefile
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
.PHONY: help install install-dev lint lint-fix format type-check test test-cov validate clean run
|
||||||
|
|
||||||
|
# Default target
|
||||||
|
help:
|
||||||
|
@echo "Knowledge Base MCP Server - Development Commands"
|
||||||
|
@echo ""
|
||||||
|
@echo "Setup:"
|
||||||
|
@echo " make install - Install production dependencies"
|
||||||
|
@echo " make install-dev - Install development dependencies"
|
||||||
|
@echo ""
|
||||||
|
@echo "Quality Checks:"
|
||||||
|
@echo " make lint - Run Ruff linter"
|
||||||
|
@echo " make lint-fix - Run Ruff linter with auto-fix"
|
||||||
|
@echo " make format - Format code with Ruff"
|
||||||
|
@echo " make type-check - Run mypy type checker"
|
||||||
|
@echo ""
|
||||||
|
@echo "Testing:"
|
||||||
|
@echo " make test - Run pytest"
|
||||||
|
@echo " make test-cov - Run pytest with coverage"
|
||||||
|
@echo ""
|
||||||
|
@echo "All-in-one:"
|
||||||
|
@echo " make validate - Run lint, type-check, and tests"
|
||||||
|
@echo ""
|
||||||
|
@echo "Running:"
|
||||||
|
@echo " make run - Run the server locally"
|
||||||
|
@echo ""
|
||||||
|
@echo "Cleanup:"
|
||||||
|
@echo " make clean - Remove cache and build artifacts"
|
||||||
|
|
||||||
|
# Setup
|
||||||
|
install:
|
||||||
|
@echo "Installing production dependencies..."
|
||||||
|
@uv pip install -e .
|
||||||
|
|
||||||
|
install-dev:
|
||||||
|
@echo "Installing development dependencies..."
|
||||||
|
@uv pip install -e ".[dev]"
|
||||||
|
|
||||||
|
# Quality checks
|
||||||
|
lint:
|
||||||
|
@echo "Running Ruff linter..."
|
||||||
|
@uv run ruff check .
|
||||||
|
|
||||||
|
lint-fix:
|
||||||
|
@echo "Running Ruff linter with auto-fix..."
|
||||||
|
@uv run ruff check --fix .
|
||||||
|
|
||||||
|
format:
|
||||||
|
@echo "Formatting code..."
|
||||||
|
@uv run ruff format .
|
||||||
|
|
||||||
|
type-check:
|
||||||
|
@echo "Running mypy..."
|
||||||
|
@uv run mypy . --ignore-missing-imports
|
||||||
|
|
||||||
|
# Testing
|
||||||
|
test:
|
||||||
|
@echo "Running tests..."
|
||||||
|
@uv run pytest tests/ -v
|
||||||
|
|
||||||
|
test-cov:
|
||||||
|
@echo "Running tests with coverage..."
|
||||||
|
@uv run pytest tests/ -v --cov=. --cov-report=term-missing --cov-report=html
|
||||||
|
|
||||||
|
# All-in-one validation
|
||||||
|
validate: lint type-check test
|
||||||
|
@echo "All validations passed!"
|
||||||
|
|
||||||
|
# Running
|
||||||
|
run:
|
||||||
|
@echo "Starting Knowledge Base server..."
|
||||||
|
@uv run python server.py
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
clean:
|
||||||
|
@echo "Cleaning up..."
|
||||||
|
@rm -rf __pycache__ .pytest_cache .mypy_cache .ruff_cache .coverage htmlcov
|
||||||
|
@find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
|
||||||
|
@find . -type f -name "*.pyc" -delete 2>/dev/null || true
|
||||||
@@ -328,7 +328,7 @@ class CollectionManager:
|
|||||||
"source_path": chunk.source_path or source_path,
|
"source_path": chunk.source_path or source_path,
|
||||||
"start_line": chunk.start_line,
|
"start_line": chunk.start_line,
|
||||||
"end_line": chunk.end_line,
|
"end_line": chunk.end_line,
|
||||||
"file_type": (chunk.file_type or file_type).value if (chunk.file_type or file_type) else None,
|
"file_type": effective_file_type.value if (effective_file_type := chunk.file_type or file_type) else None,
|
||||||
}
|
}
|
||||||
embeddings_data.append((
|
embeddings_data.append((
|
||||||
chunk.content,
|
chunk.content,
|
||||||
|
|||||||
@@ -284,9 +284,8 @@ class DatabaseManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with self.acquire() as conn:
|
async with self.acquire() as conn, conn.transaction():
|
||||||
# Wrap in transaction for all-or-nothing batch semantics
|
# Wrap in transaction for all-or-nothing batch semantics
|
||||||
async with conn.transaction():
|
|
||||||
for project_id, collection, content, embedding, chunk_type, metadata in embeddings:
|
for project_id, collection, content, embedding, chunk_type, metadata in embeddings:
|
||||||
content_hash = self.compute_content_hash(content)
|
content_hash = self.compute_content_hash(content)
|
||||||
source_path = metadata.get("source_path")
|
source_path = metadata.get("source_path")
|
||||||
@@ -566,9 +565,8 @@ class DatabaseManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with self.acquire() as conn:
|
async with self.acquire() as conn, conn.transaction():
|
||||||
# Use transaction for atomic replace
|
# Use transaction for atomic replace
|
||||||
async with conn.transaction():
|
|
||||||
# First, delete existing embeddings for this source
|
# First, delete existing embeddings for this source
|
||||||
delete_result = await conn.execute(
|
delete_result = await conn.execute(
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -193,7 +193,7 @@ async def health_check() -> dict[str, Any]:
|
|||||||
# Check Redis cache (non-critical - degraded without it)
|
# Check Redis cache (non-critical - degraded without it)
|
||||||
try:
|
try:
|
||||||
if _embeddings and _embeddings._redis:
|
if _embeddings and _embeddings._redis:
|
||||||
await _embeddings._redis.ping()
|
await _embeddings._redis.ping() # type: ignore[misc]
|
||||||
status["dependencies"]["redis"] = "connected"
|
status["dependencies"]["redis"] = "connected"
|
||||||
else:
|
else:
|
||||||
status["dependencies"]["redis"] = "not initialized"
|
status["dependencies"]["redis"] = "not initialized"
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
"""Tests for server module and MCP tools."""
|
"""Tests for server module and MCP tools."""
|
||||||
|
|
||||||
import json
|
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|||||||
@@ -1,39 +1,25 @@
|
|||||||
# Syndarix LLM Gateway MCP Server
|
# Syndarix LLM Gateway MCP Server
|
||||||
# Multi-stage build for minimal image size
|
FROM python:3.12-slim
|
||||||
|
|
||||||
# Build stage
|
WORKDIR /app
|
||||||
FROM python:3.12-slim AS builder
|
|
||||||
|
# Install system dependencies (needed for tiktoken regex compilation)
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
build-essential \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# Install uv for fast package management
|
# Install uv for fast package management
|
||||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
|
COPY --from=ghcr.io/astral-sh/uv:latest /uv /bin/uv
|
||||||
|
|
||||||
WORKDIR /app
|
# Copy project files
|
||||||
|
|
||||||
# Copy dependency files
|
|
||||||
COPY pyproject.toml ./
|
COPY pyproject.toml ./
|
||||||
|
COPY *.py ./
|
||||||
|
|
||||||
# Create virtual environment and install dependencies
|
# Install dependencies to system Python
|
||||||
RUN uv venv /app/.venv
|
RUN uv pip install --system --no-cache .
|
||||||
ENV PATH="/app/.venv/bin:$PATH"
|
|
||||||
RUN uv pip install -e .
|
|
||||||
|
|
||||||
# Runtime stage
|
|
||||||
FROM python:3.12-slim AS runtime
|
|
||||||
|
|
||||||
# Create non-root user for security
|
# Create non-root user for security
|
||||||
RUN groupadd --gid 1000 appgroup && \
|
RUN useradd --create-home --shell /bin/bash appuser
|
||||||
useradd --uid 1000 --gid appgroup --shell /bin/bash --create-home appuser
|
|
||||||
|
|
||||||
WORKDIR /app
|
|
||||||
|
|
||||||
# Copy virtual environment from builder
|
|
||||||
COPY --from=builder /app/.venv /app/.venv
|
|
||||||
ENV PATH="/app/.venv/bin:$PATH"
|
|
||||||
|
|
||||||
# Copy application code
|
|
||||||
COPY --chown=appuser:appgroup . .
|
|
||||||
|
|
||||||
# Switch to non-root user
|
|
||||||
USER appuser
|
USER appuser
|
||||||
|
|
||||||
# Environment variables
|
# Environment variables
|
||||||
@@ -47,7 +33,7 @@ EXPOSE 8001
|
|||||||
|
|
||||||
# Health check
|
# Health check
|
||||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||||
CMD python -c "import httpx; httpx.get('http://localhost:8001/health').raise_for_status()" || exit 1
|
CMD python -c "import httpx; httpx.get('http://localhost:8001/health').raise_for_status()"
|
||||||
|
|
||||||
# Run the server
|
# Run the server
|
||||||
CMD ["python", "server.py"]
|
CMD ["python", "server.py"]
|
||||||
|
|||||||
79
mcp-servers/llm-gateway/Makefile
Normal file
79
mcp-servers/llm-gateway/Makefile
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
.PHONY: help install install-dev lint lint-fix format type-check test test-cov validate clean run
|
||||||
|
|
||||||
|
# Default target
|
||||||
|
help:
|
||||||
|
@echo "LLM Gateway MCP Server - Development Commands"
|
||||||
|
@echo ""
|
||||||
|
@echo "Setup:"
|
||||||
|
@echo " make install - Install production dependencies"
|
||||||
|
@echo " make install-dev - Install development dependencies"
|
||||||
|
@echo ""
|
||||||
|
@echo "Quality Checks:"
|
||||||
|
@echo " make lint - Run Ruff linter"
|
||||||
|
@echo " make lint-fix - Run Ruff linter with auto-fix"
|
||||||
|
@echo " make format - Format code with Ruff"
|
||||||
|
@echo " make type-check - Run mypy type checker"
|
||||||
|
@echo ""
|
||||||
|
@echo "Testing:"
|
||||||
|
@echo " make test - Run pytest"
|
||||||
|
@echo " make test-cov - Run pytest with coverage"
|
||||||
|
@echo ""
|
||||||
|
@echo "All-in-one:"
|
||||||
|
@echo " make validate - Run lint, type-check, and tests"
|
||||||
|
@echo ""
|
||||||
|
@echo "Running:"
|
||||||
|
@echo " make run - Run the server locally"
|
||||||
|
@echo ""
|
||||||
|
@echo "Cleanup:"
|
||||||
|
@echo " make clean - Remove cache and build artifacts"
|
||||||
|
|
||||||
|
# Setup
|
||||||
|
install:
|
||||||
|
@echo "Installing production dependencies..."
|
||||||
|
@uv pip install -e .
|
||||||
|
|
||||||
|
install-dev:
|
||||||
|
@echo "Installing development dependencies..."
|
||||||
|
@uv pip install -e ".[dev]"
|
||||||
|
|
||||||
|
# Quality checks
|
||||||
|
lint:
|
||||||
|
@echo "Running Ruff linter..."
|
||||||
|
@uv run ruff check .
|
||||||
|
|
||||||
|
lint-fix:
|
||||||
|
@echo "Running Ruff linter with auto-fix..."
|
||||||
|
@uv run ruff check --fix .
|
||||||
|
|
||||||
|
format:
|
||||||
|
@echo "Formatting code..."
|
||||||
|
@uv run ruff format .
|
||||||
|
|
||||||
|
type-check:
|
||||||
|
@echo "Running mypy..."
|
||||||
|
@uv run mypy . --ignore-missing-imports
|
||||||
|
|
||||||
|
# Testing
|
||||||
|
test:
|
||||||
|
@echo "Running tests..."
|
||||||
|
@uv run pytest tests/ -v
|
||||||
|
|
||||||
|
test-cov:
|
||||||
|
@echo "Running tests with coverage..."
|
||||||
|
@uv run pytest tests/ -v --cov=. --cov-report=term-missing --cov-report=html
|
||||||
|
|
||||||
|
# All-in-one validation
|
||||||
|
validate: lint type-check test
|
||||||
|
@echo "All validations passed!"
|
||||||
|
|
||||||
|
# Running
|
||||||
|
run:
|
||||||
|
@echo "Starting LLM Gateway server..."
|
||||||
|
@uv run python server.py
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
clean:
|
||||||
|
@echo "Cleaning up..."
|
||||||
|
@rm -rf __pycache__ .pytest_cache .mypy_cache .ruff_cache .coverage htmlcov
|
||||||
|
@find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
|
||||||
|
@find . -type f -name "*.pyc" -delete 2>/dev/null || true
|
||||||
@@ -110,9 +110,8 @@ class CircuitBreaker:
|
|||||||
"""
|
"""
|
||||||
if self._state == CircuitState.OPEN:
|
if self._state == CircuitState.OPEN:
|
||||||
time_in_open = time.time() - self._stats.state_changed_at
|
time_in_open = time.time() - self._stats.state_changed_at
|
||||||
if time_in_open >= self.recovery_timeout:
|
# Double-check state after time calculation (for thread safety)
|
||||||
# Only transition if still in OPEN state (double-check)
|
if time_in_open >= self.recovery_timeout and self._state == CircuitState.OPEN:
|
||||||
if self._state == CircuitState.OPEN:
|
|
||||||
self._transition_to(CircuitState.HALF_OPEN)
|
self._transition_to(CircuitState.HALF_OPEN)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Circuit {self.name} transitioned to HALF_OPEN "
|
f"Circuit {self.name} transitioned to HALF_OPEN "
|
||||||
|
|||||||
Reference in New Issue
Block a user