14 Commits

Author SHA1 Message Date
Felipe Cardoso
4b149b8a52 feat(tests): add unit tests for Context Management API routes
- Added detailed unit tests for `/context` endpoints, covering health checks, context assembly, token counting, budget retrieval, and cache invalidation.
- Included edge cases, error handling, and input validation for context-related operations.
- Improved test coverage for the Context Management module with mocked dependencies and integration scenarios.
2026-01-05 01:02:49 +01:00
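The shape of these mocked route tests can be sketched without FastAPI at all — the engine dependency is replaced with an `AsyncMock`, so no MCP server, Redis, or network is needed. This is an illustrative reduction of the `/context/health` handler logic, not the repository's actual test code:

```python
import asyncio
from unittest.mock import AsyncMock


async def check_health(engine) -> dict:
    """Stand-in for the /context/health handler's core logic: read engine
    stats and report cache status. (Sketch only -- the real route returns a
    HealthResponse model and also reports MCP connectivity.)"""
    stats = await engine.get_stats()
    return {
        "status": "healthy",
        "cache_enabled": stats.get("settings", {}).get("cache_enabled", False),
    }


def test_health_reports_cache_flag() -> None:
    # Mocked dependency: deterministic, fast, no running stack required.
    engine = AsyncMock()
    engine.get_stats.return_value = {"settings": {"cache_enabled": True}}
    result = asyncio.run(check_health(engine))
    assert result == {"status": "healthy", "cache_enabled": True}
```

The same mock-the-engine pattern extends to assembly, token counting, and cache-invalidation routes.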
Felipe Cardoso
ad0c06851d feat(tests): add comprehensive E2E tests for MCP and Agent workflows
- Introduced end-to-end tests for MCP workflows, including server discovery, authentication, context engine operations, error handling, and input validation.
- Added full lifecycle tests for agent workflows, covering type management, instance spawning, status transitions, and admin-only operations.
- Enhanced test coverage for real-world MCP and Agent scenarios across PostgreSQL and async environments.
2026-01-05 01:02:41 +01:00
Felipe Cardoso
49359b1416 feat(api): add Context Management API and routes
- Introduced a new `context` module and its endpoints for Context Management.
- Added `/context` route to the API router for assembling LLM context, token counting, budget management, and cache invalidation.
- Implemented health checks, context assembly, token counting, and caching operations in the Context Management Engine.
- Included schemas for request/response models and tightened error handling for context-related operations.
2026-01-05 01:02:33 +01:00
Felipe Cardoso
911d950c15 feat(tests): add comprehensive integration tests for MCP stack
- Introduced integration tests covering backend, LLM Gateway, Knowledge Base, and Context Engine.
- Included health checks, tool listing, token counting, and end-to-end MCP flows.
- Added `RUN_INTEGRATION_TESTS` environment flag to enable selective test execution.
- Included a quick health check script to verify service availability before running tests.
2026-01-05 01:02:22 +01:00
Felipe Cardoso
b2a3ac60e0 feat: add integration testing target to Makefile
- Introduced `test-integration` command for MCP integration tests.
- Expanded help section with details about running integration tests.
- Improved Makefile's testing capabilities for enhanced developer workflows.
2026-01-05 01:02:16 +01:00
Felipe Cardoso
dea092e1bb feat: extend Makefile with testing and validation commands, expand help section
- Added new targets for testing (`test`, `test-backend`, `test-mcp`, `test-frontend`, etc.) and validation (`validate`, `validate-all`).
- Enhanced help section to reflect updates, including detailed descriptions for testing, validation, and new MCP-specific commands.
- Improved developer workflow by centralizing testing and linting processes in the Makefile.
2026-01-05 01:02:09 +01:00
Felipe Cardoso
4154dd5268 feat: enhance database transactions, add Makefiles, and improve Docker setup
- Refactored database batch operations to ensure transaction atomicity and simplify nested structure.
- Added `Makefile` for `knowledge-base` and `llm-gateway` modules to streamline development workflows.
- Simplified `Dockerfile` for `llm-gateway` by removing multi-stage builds and optimizing dependencies.
- Improved code readability in `collection_manager` and `failover` modules with refined logic.
- Minor fixes in `test_server` and Redis health check handling for better diagnostics.
2026-01-05 00:49:19 +01:00
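The batch-atomicity property from the first bullet can be illustrated with a stdlib sketch. The repository uses async SQLAlchemy, but the guarantee is the same: either every row in the batch commits, or none do.

```python
import sqlite3


def batch_insert(conn: sqlite3.Connection, names: list[tuple[str]]) -> None:
    """Insert a batch atomically.

    The connection context manager wraps the statements in one transaction:
    commit on success, rollback if any row fails, so a partial batch is
    never left behind. (Illustrative sqlite3 sketch, not the repo's code.)
    """
    with conn:  # BEGIN ... COMMIT / ROLLBACK
        conn.executemany("INSERT INTO items (name) VALUES (?)", names)
```

A failed batch leaves the table exactly as it was before the call — the behavior the refactor is meant to guarantee.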
Felipe Cardoso
db12937495 feat: integrate MCP servers into Docker Compose files for development and deployment
- Added `mcp-llm-gateway` and `mcp-knowledge-base` services to `docker-compose.dev.yml`, `docker-compose.deploy.yml`, and `docker-compose.yml` for AI agent capabilities.
- Configured health checks, environment variables, and dependencies for MCP services.
- Included updated resource limits and deployment settings for production environments.
- Connected backend and agent services to the MCP servers.
2026-01-05 00:49:10 +01:00
Felipe Cardoso
81e1456631 test(activity): fix flaky test by generating fresh events for today group
- Resolves timezone and day boundary issues by creating fresh "today" events in the test case.
2026-01-05 00:30:36 +01:00
Felipe Cardoso
58e78d8700 docs(workflow): add pre-commit hooks documentation
Document the pre-commit hook setup, behavior, and rationale for
protecting only main/dev branches while allowing flexibility on
feature branches.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-04 19:49:45 +01:00
Felipe Cardoso
5e80139afa chore: add pre-commit hook for protected branch validation
Adds a git hook that:
- Blocks commits to main/dev if validation fails
- Runs `make validate` for backend changes
- Runs `npm run validate` for frontend changes
- Skips validation for feature branches (can run manually)

To enable: git config core.hooksPath .githooks

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-04 19:42:53 +01:00
Felipe Cardoso
60ebeaa582 test(safety): add comprehensive tests for safety framework modules
Add tests to improve backend coverage from 85% to 93%:

- test_audit.py: 60 tests for AuditLogger (20% -> 99%)
  - Hash chain integrity, sanitization, retention, handlers
  - Fixed bug: hash chain modification after event creation
  - Fixed bug: verification not using correct prev_hash

- test_hitl.py: Tests for HITL manager (0% -> 100%)
- test_permissions.py: Tests for permissions manager (0% -> 99%)
- test_rollback.py: Tests for rollback manager (0% -> 100%)
- test_metrics.py: Tests for metrics collector (0% -> 100%)
- test_mcp_integration.py: Tests for MCP safety wrapper (0% -> 100%)
- test_validation.py: Additional cache and edge case tests (76% -> 100%)
- test_scoring.py: Lock cleanup and edge case tests (78% -> 91%)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-04 19:41:54 +01:00
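The hash-chain property these audit tests exercise — and the prev_hash verification bug they caught — can be sketched in a few lines. This is an illustrative model, not the repository's `AuditLogger`:

```python
import hashlib
import json

GENESIS = "0" * 64


def chain_hash(body: dict, prev_hash: str) -> str:
    """Hash an event together with its predecessor's hash, forming a
    tamper-evident chain."""
    payload = json.dumps(body, sort_keys=True) + prev_hash
    return hashlib.sha256(payload.encode()).hexdigest()


def append_event(log: list[dict], body: dict) -> None:
    prev = log[-1]["hash"] if log else GENESIS
    log.append({**body, "prev_hash": prev, "hash": chain_hash(body, prev)})


def verify_chain(log: list[dict]) -> bool:
    # The bug class the tests caught: verification must walk from the
    # genesis value forward and compare against each event's stored
    # prev_hash link, not a recomputed predecessor.
    prev = GENESIS
    for event in log:
        body = {k: v for k, v in event.items() if k not in ("hash", "prev_hash")}
        if event["prev_hash"] != prev or event["hash"] != chain_hash(body, prev):
            return False
        prev = event["hash"]
    return True
```

Mutating any earlier event breaks every subsequent link, which is what makes the audit log tamper-evident.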
Felipe Cardoso
758052dcff feat(context): improve budget validation and XML safety in ranking and Claude adapter
- Added stricter budget validation in ContextRanker with explicit error handling for invalid configurations.
- Introduced `_get_valid_token_count()` helper to validate and safeguard token counts.
- Enhanced XML escaping in Claude adapter to prevent injection risks from scores and unhandled content.
2026-01-04 16:02:18 +01:00
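A plausible shape for the `_get_valid_token_count()` guard named above — a hypothetical reconstruction, since the commit names the helper but its diff is not shown here:

```python
def get_valid_token_count(value: object, fallback: int = 0) -> int:
    """Return value only if it is a usable token count, else fallback.

    Rejects None, negatives, bools, and non-integers so downstream budget
    math never sees a nonsensical count. (Hypothetical sketch of the
    private helper mentioned in the commit message.)
    """
    if isinstance(value, int) and not isinstance(value, bool) and value >= 0:
        return value
    return fallback
```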
Felipe Cardoso
1628eacf2b feat(context): enhance timeout handling, tenant isolation, and budget management
- Added timeout enforcement for token counting, scoring, and compression with detailed error handling.
- Introduced tenant isolation in context caching using project and agent identifiers.
- Enhanced budget management with stricter checks for critical context overspending and buffer limitations.
- Optimized per-context locking with cleanup to prevent memory leaks in concurrent environments.
- Updated default assembly timeout settings for improved performance and reliability.
- Improved XML escaping in Claude adapter for safety against injection attacks.
- Standardized token estimation using model-specific ratios.
2026-01-04 15:52:50 +01:00
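Two of the bullets above can be made concrete with a short sketch: tenant-scoped cache keys and ratio-based token estimation. The key layout and the chars-per-token values are assumptions for illustration, not the repository's actual scheme:

```python
import hashlib


def context_cache_key(project_id: str, agent_id: str, query: str, model: str) -> str:
    """Tenant-isolated cache key: project and agent identifiers are part of
    the key, so one tenant's assembled context can never be served to
    another tenant with the same query."""
    digest = hashlib.sha256(f"{model}:{query}".encode()).hexdigest()[:16]
    return f"ctx:{project_id}:{agent_id}:{digest}"


# Rough chars-per-token ratios per model family, for cheap estimation when a
# real tokenizer is unavailable (values are illustrative).
CHARS_PER_TOKEN = {"claude": 3.5, "default": 4.0}


def estimate_tokens(text: str, model: str) -> int:
    ratio = CHARS_PER_TOKEN["claude" if model.startswith("claude") else "default"]
    return max(1, round(len(text) / ratio))
```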
43 changed files with 10099 additions and 314 deletions

.githooks/pre-commit Executable file

@@ -0,0 +1,61 @@
#!/bin/bash
# Pre-commit hook to enforce validation before commits on protected branches
# Install: git config core.hooksPath .githooks

set -e

# Get the current branch name
BRANCH=$(git rev-parse --abbrev-ref HEAD)

# Protected branches that require validation
PROTECTED_BRANCHES="main dev"

# Check if we're on a protected branch
is_protected() {
    for branch in $PROTECTED_BRANCHES; do
        if [ "$BRANCH" = "$branch" ]; then
            return 0
        fi
    done
    return 1
}

if is_protected; then
    echo "🔒 Committing to protected branch '$BRANCH' - running validation..."

    # Check if we have backend changes
    if git diff --cached --name-only | grep -q "^backend/"; then
        echo "📦 Backend changes detected - running make validate..."
        cd backend
        if ! make validate; then
            echo ""
            echo "❌ Backend validation failed!"
            echo "   Please fix the issues and try again."
            echo "   Run 'cd backend && make validate' to see errors."
            exit 1
        fi
        cd ..
        echo "✅ Backend validation passed!"
    fi

    # Check if we have frontend changes
    if git diff --cached --name-only | grep -q "^frontend/"; then
        echo "🎨 Frontend changes detected - running npm run validate..."
        cd frontend
        if ! npm run validate 2>/dev/null; then
            echo ""
            echo "❌ Frontend validation failed!"
            echo "   Please fix the issues and try again."
            echo "   Run 'cd frontend && npm run validate' to see errors."
            exit 1
        fi
        cd ..
        echo "✅ Frontend validation passed!"
    fi

    echo "🎉 All validations passed! Proceeding with commit..."
else
    echo "📝 Committing to feature branch '$BRANCH' - skipping validation (run manually if needed)"
fi

exit 0


@@ -1,18 +1,31 @@
.PHONY: help dev dev-full prod down logs logs-dev clean clean-slate drop-db reset-db push-images deploy
.PHONY: test test-backend test-mcp test-frontend test-all test-cov test-integration validate validate-all

VERSION ?= latest
REGISTRY ?= ghcr.io/cardosofelipe/pragma-stack

# Default target
help:
	@echo "FastAPI + Next.js Full-Stack Template"
	@echo "Syndarix - AI-Powered Software Consulting Agency"
	@echo ""
	@echo "Development:"
	@echo " make dev - Start backend + db (frontend runs separately)"
	@echo " make dev - Start backend + db + MCP servers (frontend runs separately)"
	@echo " make dev-full - Start all services including frontend"
	@echo " make down - Stop all services"
	@echo " make logs-dev - Follow dev container logs"
	@echo ""
	@echo "Testing:"
	@echo " make test - Run all tests (backend + MCP servers)"
	@echo " make test-backend - Run backend tests only"
	@echo " make test-mcp - Run MCP server tests only"
	@echo " make test-frontend - Run frontend tests only"
	@echo " make test-cov - Run all tests with coverage reports"
	@echo " make test-integration - Run MCP integration tests (requires running stack)"
	@echo ""
	@echo "Validation:"
	@echo " make validate - Validate backend + MCP servers (lint, type-check, test)"
	@echo " make validate-all - Validate everything including frontend"
	@echo ""
	@echo "Database:"
	@echo " make drop-db - Drop and recreate empty database"
	@echo " make reset-db - Drop database and apply all migrations"
@@ -28,8 +41,10 @@ help:
@echo " make clean-slate - Stop containers AND delete volumes (DATA LOSS!)"
@echo ""
@echo "Subdirectory commands:"
@echo " cd backend && make help - Backend-specific commands"
@echo " cd frontend && npm run - Frontend-specific commands"
@echo " cd backend && make help - Backend-specific commands"
@echo " cd mcp-servers/llm-gateway && make - LLM Gateway commands"
@echo " cd mcp-servers/knowledge-base && make - Knowledge Base commands"
@echo " cd frontend && npm run - Frontend-specific commands"
# ============================================================================
# Development
@@ -99,3 +114,72 @@ clean:
# WARNING! THIS REMOVES CONTAINERS AND VOLUMES AS WELL - DO NOT USE THIS UNLESS YOU WANT TO START OVER WITH DATA AND ALL
clean-slate:
	docker compose -f docker-compose.dev.yml down -v --remove-orphans

# ============================================================================
# Testing
# ============================================================================
test: test-backend test-mcp
	@echo ""
	@echo "All tests passed!"

test-backend:
	@echo "Running backend tests..."
	@cd backend && IS_TEST=True uv run pytest tests/ -v

test-mcp:
	@echo "Running MCP server tests..."
	@echo ""
	@echo "=== LLM Gateway ==="
	@cd mcp-servers/llm-gateway && uv run pytest tests/ -v
	@echo ""
	@echo "=== Knowledge Base ==="
	@cd mcp-servers/knowledge-base && uv run pytest tests/ -v

test-frontend:
	@echo "Running frontend tests..."
	@cd frontend && npm test

test-all: test test-frontend
	@echo ""
	@echo "All tests (backend + MCP + frontend) passed!"

test-cov:
	@echo "Running all tests with coverage..."
	@echo ""
	@echo "=== Backend Coverage ==="
	@cd backend && IS_TEST=True uv run pytest tests/ -v --cov=app --cov-report=term-missing
	@echo ""
	@echo "=== LLM Gateway Coverage ==="
	@cd mcp-servers/llm-gateway && uv run pytest tests/ -v --cov=. --cov-report=term-missing
	@echo ""
	@echo "=== Knowledge Base Coverage ==="
	@cd mcp-servers/knowledge-base && uv run pytest tests/ -v --cov=. --cov-report=term-missing

test-integration:
	@echo "Running MCP integration tests..."
	@echo "Note: Requires running stack (make dev first)"
	@cd backend && RUN_INTEGRATION_TESTS=true IS_TEST=True uv run pytest tests/integration/ -v

# ============================================================================
# Validation (lint + type-check + test)
# ============================================================================
validate:
	@echo "Validating backend..."
	@cd backend && make validate
	@echo ""
	@echo "Validating LLM Gateway..."
	@cd mcp-servers/llm-gateway && make validate
	@echo ""
	@echo "Validating Knowledge Base..."
	@cd mcp-servers/knowledge-base && make validate
	@echo ""
	@echo "All validations passed!"

validate-all: validate
	@echo ""
	@echo "Validating frontend..."
	@cd frontend && npm run validate
	@echo ""
	@echo "Full validation passed!"


@@ -1,4 +1,4 @@
.PHONY: help lint lint-fix format format-check type-check test test-cov validate clean install-dev sync check-docker install-e2e test-e2e test-e2e-schema test-all
.PHONY: help lint lint-fix format format-check type-check test test-cov validate clean install-dev sync check-docker install-e2e test-e2e test-e2e-schema test-all test-integration

# Default target
help:
@@ -22,6 +22,7 @@ help:
	@echo " make test-cov - Run pytest with coverage report"
	@echo " make test-e2e - Run E2E tests (PostgreSQL, requires Docker)"
	@echo " make test-e2e-schema - Run Schemathesis API schema tests"
	@echo " make test-integration - Run MCP integration tests (requires running stack)"
	@echo " make test-all - Run all tests (unit + E2E)"
	@echo " make check-docker - Check if Docker is available"
	@echo ""
@@ -82,6 +83,15 @@ test-cov:
	@IS_TEST=True PYTHONPATH=. uv run pytest --cov=app --cov-report=term-missing --cov-report=html -n 16
	@echo "📊 Coverage report generated in htmlcov/index.html"

# ============================================================================
# Integration Testing (requires running stack: make dev)
# ============================================================================
test-integration:
	@echo "🧪 Running MCP integration tests..."
	@echo "Note: Requires running stack (make dev from project root)"
	@RUN_INTEGRATION_TESTS=true IS_TEST=True PYTHONPATH=. uv run pytest tests/integration/ -v

# ============================================================================
# E2E Testing (requires Docker)
# ============================================================================


@@ -5,6 +5,7 @@ from app.api.routes import (
    agent_types,
    agents,
    auth,
    context,
    events,
    issues,
    mcp,
@@ -35,6 +36,9 @@ api_router.include_router(events.router, tags=["Events"])

# MCP (Model Context Protocol) router
api_router.include_router(mcp.router, prefix="/mcp", tags=["MCP"])

# Context Management Engine router
api_router.include_router(context.router, prefix="/context", tags=["Context"])

# Syndarix domain routers
api_router.include_router(projects.router, prefix="/projects", tags=["Projects"])
api_router.include_router(


@@ -0,0 +1,411 @@
"""
Context Management API Endpoints.
Provides REST endpoints for context assembly and optimization
for LLM requests using the ContextEngine.
"""
import logging
from typing import Annotated, Any
from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel, Field
from app.api.dependencies.permissions import require_superuser
from app.models.user import User
from app.services.context import (
AssemblyTimeoutError,
BudgetExceededError,
ContextEngine,
ContextSettings,
create_context_engine,
get_context_settings,
)
from app.services.mcp import MCPClientManager, get_mcp_client
logger = logging.getLogger(__name__)
router = APIRouter()
# ============================================================================
# Singleton Engine Management
# ============================================================================
_context_engine: ContextEngine | None = None
def _get_or_create_engine(
mcp: MCPClientManager,
settings: ContextSettings | None = None,
) -> ContextEngine:
"""Get or create the singleton ContextEngine."""
global _context_engine
if _context_engine is None:
_context_engine = create_context_engine(
mcp_manager=mcp,
redis=None, # Optional: add Redis caching later
settings=settings or get_context_settings(),
)
logger.info("ContextEngine initialized")
else:
# Ensure MCP manager is up to date
_context_engine.set_mcp_manager(mcp)
return _context_engine
async def get_context_engine(
mcp: MCPClientManager = Depends(get_mcp_client),
) -> ContextEngine:
"""FastAPI dependency to get the ContextEngine."""
return _get_or_create_engine(mcp)
# ============================================================================
# Request/Response Schemas
# ============================================================================
class ConversationTurn(BaseModel):
"""A single conversation turn."""
role: str = Field(..., description="Role: 'user' or 'assistant'")
content: str = Field(..., description="Message content")
class ToolResult(BaseModel):
"""A tool execution result."""
tool_name: str = Field(..., description="Name of the tool")
content: str | dict[str, Any] = Field(..., description="Tool result content")
status: str = Field(default="success", description="Execution status")
class AssembleContextRequest(BaseModel):
"""Request to assemble context for an LLM request."""
project_id: str = Field(..., description="Project identifier")
agent_id: str = Field(..., description="Agent identifier")
query: str = Field(..., description="User's query or current request")
model: str = Field(
default="claude-3-sonnet",
description="Target model name",
)
max_tokens: int | None = Field(
None,
description="Maximum context tokens (uses model default if None)",
)
system_prompt: str | None = Field(
None,
description="System prompt/instructions",
)
task_description: str | None = Field(
None,
description="Current task description",
)
knowledge_query: str | None = Field(
None,
description="Query for knowledge base search",
)
knowledge_limit: int = Field(
default=10,
ge=1,
le=50,
description="Max number of knowledge results",
)
conversation_history: list[ConversationTurn] | None = Field(
None,
description="Previous conversation turns",
)
tool_results: list[ToolResult] | None = Field(
None,
description="Tool execution results to include",
)
compress: bool = Field(
default=True,
description="Whether to apply compression",
)
use_cache: bool = Field(
default=True,
description="Whether to use caching",
)
class AssembledContextResponse(BaseModel):
"""Response containing assembled context."""
content: str = Field(..., description="Assembled context content")
total_tokens: int = Field(..., description="Total token count")
context_count: int = Field(..., description="Number of context items included")
compressed: bool = Field(..., description="Whether compression was applied")
budget_used_percent: float = Field(
...,
description="Percentage of token budget used",
)
metadata: dict[str, Any] = Field(
default_factory=dict,
description="Additional metadata",
)
class TokenCountRequest(BaseModel):
"""Request to count tokens in content."""
content: str = Field(..., description="Content to count tokens in")
model: str | None = Field(
None,
description="Model for model-specific tokenization",
)
class TokenCountResponse(BaseModel):
"""Response containing token count."""
token_count: int = Field(..., description="Number of tokens")
model: str | None = Field(None, description="Model used for counting")
class BudgetInfoResponse(BaseModel):
"""Response containing budget information for a model."""
model: str = Field(..., description="Model name")
total_tokens: int = Field(..., description="Total token budget")
system_tokens: int = Field(..., description="Tokens reserved for system")
knowledge_tokens: int = Field(..., description="Tokens for knowledge")
conversation_tokens: int = Field(..., description="Tokens for conversation")
tool_tokens: int = Field(..., description="Tokens for tool results")
response_reserve: int = Field(..., description="Tokens reserved for response")
class ContextEngineStatsResponse(BaseModel):
"""Response containing engine statistics."""
cache: dict[str, Any] = Field(..., description="Cache statistics")
settings: dict[str, Any] = Field(..., description="Current settings")
class HealthResponse(BaseModel):
"""Health check response."""
status: str = Field(..., description="Health status")
mcp_connected: bool = Field(..., description="Whether MCP is connected")
cache_enabled: bool = Field(..., description="Whether caching is enabled")
# ============================================================================
# Endpoints
# ============================================================================
@router.get(
"/health",
response_model=HealthResponse,
summary="Context Engine Health",
description="Check health status of the context engine.",
)
async def health_check(
engine: ContextEngine = Depends(get_context_engine),
) -> HealthResponse:
"""Check context engine health."""
stats = await engine.get_stats()
return HealthResponse(
status="healthy",
mcp_connected=engine._mcp is not None,
cache_enabled=stats.get("settings", {}).get("cache_enabled", False),
)
@router.post(
"/assemble",
response_model=AssembledContextResponse,
summary="Assemble Context",
description="Assemble optimized context for an LLM request.",
)
async def assemble_context(
request: AssembleContextRequest,
current_user: User = Depends(require_superuser),
engine: ContextEngine = Depends(get_context_engine),
) -> AssembledContextResponse:
"""
Assemble optimized context for an LLM request.
This endpoint gathers context from various sources, scores and ranks them,
compresses if needed, and formats for the target model.
"""
logger.info(
"Context assembly for project=%s agent=%s by user=%s",
request.project_id,
request.agent_id,
current_user.id,
)
# Convert conversation history to dict format
conversation_history = None
if request.conversation_history:
conversation_history = [
{"role": turn.role, "content": turn.content}
for turn in request.conversation_history
]
# Convert tool results to dict format
tool_results = None
if request.tool_results:
tool_results = [
{
"tool_name": tr.tool_name,
"content": tr.content,
"status": tr.status,
}
for tr in request.tool_results
]
try:
result = await engine.assemble_context(
project_id=request.project_id,
agent_id=request.agent_id,
query=request.query,
model=request.model,
max_tokens=request.max_tokens,
system_prompt=request.system_prompt,
task_description=request.task_description,
knowledge_query=request.knowledge_query,
knowledge_limit=request.knowledge_limit,
conversation_history=conversation_history,
tool_results=tool_results,
compress=request.compress,
use_cache=request.use_cache,
)
# Calculate budget usage percentage
budget = await engine.get_budget_for_model(request.model, request.max_tokens)
budget_used_percent = (result.total_tokens / budget.total) * 100
# Check if compression was applied (from metadata if available)
was_compressed = result.metadata.get("compressed_contexts", 0) > 0
return AssembledContextResponse(
content=result.content,
total_tokens=result.total_tokens,
context_count=result.context_count,
compressed=was_compressed,
budget_used_percent=round(budget_used_percent, 2),
metadata={
"model": request.model,
"query": request.query,
"knowledge_included": bool(request.knowledge_query),
"conversation_turns": len(request.conversation_history or []),
"excluded_count": result.excluded_count,
"assembly_time_ms": result.assembly_time_ms,
},
)
except AssemblyTimeoutError as e:
raise HTTPException(
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
detail=f"Context assembly timed out: {e}",
) from e
except BudgetExceededError as e:
raise HTTPException(
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
detail=f"Token budget exceeded: {e}",
) from e
except Exception as e:
logger.exception("Context assembly failed")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Context assembly failed: {e}",
) from e
@router.post(
"/count-tokens",
response_model=TokenCountResponse,
summary="Count Tokens",
description="Count tokens in content using the LLM Gateway.",
)
async def count_tokens(
request: TokenCountRequest,
engine: ContextEngine = Depends(get_context_engine),
) -> TokenCountResponse:
"""Count tokens in content."""
try:
count = await engine.count_tokens(
content=request.content,
model=request.model,
)
return TokenCountResponse(
token_count=count,
model=request.model,
)
except Exception as e:
logger.warning(f"Token counting failed: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Token counting failed: {e}",
) from e
@router.get(
"/budget/{model}",
response_model=BudgetInfoResponse,
summary="Get Token Budget",
description="Get token budget allocation for a specific model.",
)
async def get_budget(
model: str,
max_tokens: Annotated[int | None, Query(description="Custom max tokens")] = None,
engine: ContextEngine = Depends(get_context_engine),
) -> BudgetInfoResponse:
"""Get token budget information for a model."""
budget = await engine.get_budget_for_model(model, max_tokens)
return BudgetInfoResponse(
model=model,
total_tokens=budget.total,
system_tokens=budget.system,
knowledge_tokens=budget.knowledge,
conversation_tokens=budget.conversation,
tool_tokens=budget.tools,
response_reserve=budget.response_reserve,
)
@router.get(
"/stats",
response_model=ContextEngineStatsResponse,
summary="Engine Statistics",
description="Get context engine statistics and configuration.",
)
async def get_stats(
current_user: User = Depends(require_superuser),
engine: ContextEngine = Depends(get_context_engine),
) -> ContextEngineStatsResponse:
"""Get engine statistics."""
stats = await engine.get_stats()
return ContextEngineStatsResponse(
cache=stats.get("cache", {}),
settings=stats.get("settings", {}),
)
@router.post(
"/cache/invalidate",
status_code=status.HTTP_204_NO_CONTENT,
summary="Invalidate Cache (Admin Only)",
description="Invalidate context cache entries.",
)
async def invalidate_cache(
project_id: Annotated[
str | None, Query(description="Project to invalidate")
] = None,
pattern: Annotated[str | None, Query(description="Pattern to match")] = None,
current_user: User = Depends(require_superuser),
engine: ContextEngine = Depends(get_context_engine),
) -> None:
"""Invalidate cache entries."""
logger.info(
"Cache invalidation by user %s: project=%s pattern=%s",
current_user.id,
project_id,
pattern,
)
await engine.invalidate_cache(project_id=project_id, pattern=pattern)


@@ -90,16 +90,19 @@ class ClaudeAdapter(ModelAdapter):
elif context_type == ContextType.TOOL:
return self._format_tool(contexts)
return "\n".join(c.content for c in contexts)
# Fallback for any unhandled context types - still escape content
# to prevent XML injection if new types are added without updating adapter
return "\n".join(self._escape_xml_content(c.content) for c in contexts)
def _format_system(self, contexts: list[BaseContext]) -> str:
"""Format system contexts."""
content = "\n\n".join(c.content for c in contexts)
# System prompts are typically admin-controlled, but escape for safety
content = "\n\n".join(self._escape_xml_content(c.content) for c in contexts)
return f"<system_instructions>\n{content}\n</system_instructions>"
def _format_task(self, contexts: list[BaseContext]) -> str:
"""Format task contexts."""
content = "\n\n".join(c.content for c in contexts)
content = "\n\n".join(self._escape_xml_content(c.content) for c in contexts)
return f"<current_task>\n{content}\n</current_task>"
def _format_knowledge(self, contexts: list[BaseContext]) -> str:
@@ -107,16 +110,22 @@ class ClaudeAdapter(ModelAdapter):
Format knowledge contexts as structured documents.
Each knowledge context becomes a document with source attribution.
All content is XML-escaped to prevent injection attacks.
"""
parts = ["<reference_documents>"]
for ctx in contexts:
source = self._escape_xml(ctx.source)
content = ctx.content
# Escape content to prevent XML injection
content = self._escape_xml_content(ctx.content)
score = ctx.metadata.get("score", ctx.metadata.get("relevance_score", ""))
if score:
parts.append(f'<document source="{source}" relevance="{score}">')
# Escape score to prevent XML injection via metadata
escaped_score = self._escape_xml(str(score))
parts.append(
f'<document source="{source}" relevance="{escaped_score}">'
)
else:
parts.append(f'<document source="{source}">')
@@ -131,13 +140,16 @@ class ClaudeAdapter(ModelAdapter):
Format conversation contexts as message history.
Uses role-based message tags for clear turn delineation.
All content is XML-escaped to prevent prompt injection.
"""
parts = ["<conversation_history>"]
for ctx in contexts:
role = ctx.metadata.get("role", "user")
role = self._escape_xml(ctx.metadata.get("role", "user"))
# Escape content to prevent prompt injection via fake XML tags
content = self._escape_xml_content(ctx.content)
parts.append(f'<message role="{role}">')
parts.append(ctx.content)
parts.append(content)
parts.append("</message>")
parts.append("</conversation_history>")
@@ -148,19 +160,23 @@ class ClaudeAdapter(ModelAdapter):
Format tool contexts as tool results.
Each tool result is wrapped with the tool name.
All content is XML-escaped to prevent injection.
"""
parts = ["<tool_results>"]
for ctx in contexts:
tool_name = ctx.metadata.get("tool_name", "unknown")
tool_name = self._escape_xml(ctx.metadata.get("tool_name", "unknown"))
status = ctx.metadata.get("status", "")
if status:
parts.append(f'<tool_result name="{tool_name}" status="{status}">')
parts.append(
f'<tool_result name="{tool_name}" status="{self._escape_xml(status)}">'
)
else:
parts.append(f'<tool_result name="{tool_name}">')
parts.append(ctx.content)
# Escape content to prevent injection
parts.append(self._escape_xml_content(ctx.content))
parts.append("</tool_result>")
parts.append("</tool_results>")
@@ -176,3 +192,21 @@ class ClaudeAdapter(ModelAdapter):
            .replace('"', "&quot;")
            .replace("'", "&apos;")
        )

    @staticmethod
    def _escape_xml_content(text: str) -> str:
        """
        Escape XML special characters in element content.

        This prevents XML injection attacks where malicious content
        could break out of XML tags or inject fake tags for prompt injection.

        Only escapes &, <, > since quotes don't need escaping in content.

        Args:
            text: Content text to escape

        Returns:
            XML-safe content string
        """
        return text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")


@@ -12,6 +12,7 @@ from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any
from ..adapters import get_adapter
from ..budget import BudgetAllocator, TokenBudget, TokenCalculator
from ..compression.truncation import ContextCompressor
from ..config import ContextSettings, get_context_settings
@@ -156,20 +157,42 @@ class ContextPipeline:
else:
budget = self._allocator.create_budget_for_model(model)
# 1. Count tokens for all contexts
await self._ensure_token_counts(contexts, model)
# 1. Count tokens for all contexts (with timeout enforcement)
try:
await asyncio.wait_for(
self._ensure_token_counts(contexts, model),
timeout=self._remaining_timeout(start, timeout),
)
except TimeoutError:
elapsed_ms = (time.perf_counter() - start) * 1000
raise AssemblyTimeoutError(
message="Context assembly timed out during token counting",
elapsed_ms=elapsed_ms,
timeout_ms=timeout,
)
# Check timeout
# Check timeout (handles edge case where operation finished just at limit)
self._check_timeout(start, timeout, "token counting")
# 2. Score and rank contexts
# 2. Score and rank contexts (with timeout enforcement)
scoring_start = time.perf_counter()
ranking_result = await self._ranker.rank(
contexts=contexts,
query=query,
budget=budget,
model=model,
)
try:
ranking_result = await asyncio.wait_for(
self._ranker.rank(
contexts=contexts,
query=query,
budget=budget,
model=model,
),
timeout=self._remaining_timeout(start, timeout),
)
except TimeoutError:
elapsed_ms = (time.perf_counter() - start) * 1000
raise AssemblyTimeoutError(
message="Context assembly timed out during scoring/ranking",
elapsed_ms=elapsed_ms,
timeout_ms=timeout,
)
metrics.scoring_time_ms = (time.perf_counter() - scoring_start) * 1000
selected_contexts = ranking_result.selected_contexts
@@ -179,12 +202,23 @@ class ContextPipeline:
# Check timeout
self._check_timeout(start, timeout, "scoring")
# 3. Compress if needed and enabled
# 3. Compress if needed and enabled (with timeout enforcement)
if compress and self._needs_compression(selected_contexts, budget):
compression_start = time.perf_counter()
selected_contexts = await self._compressor.compress_contexts(
selected_contexts, budget, model
)
try:
selected_contexts = await asyncio.wait_for(
self._compressor.compress_contexts(
selected_contexts, budget, model
),
timeout=self._remaining_timeout(start, timeout),
)
except TimeoutError:
elapsed_ms = (time.perf_counter() - start) * 1000
raise AssemblyTimeoutError(
message="Context assembly timed out during compression",
elapsed_ms=elapsed_ms,
timeout_ms=timeout,
)
metrics.compression_time_ms = (
time.perf_counter() - compression_start
) * 1000
@@ -280,129 +314,18 @@ class ContextPipeline:
"""
Format contexts for the target model.
Groups contexts by type and applies model-specific formatting.
Uses model-specific adapters (ClaudeAdapter, OpenAIAdapter, etc.)
to format contexts optimally for each model family.
Args:
contexts: Contexts to format
model: Target model name
Returns:
Formatted context string
"""
# Group by type
by_type: dict[ContextType, list[BaseContext]] = {}
for context in contexts:
ct = context.get_type()
if ct not in by_type:
by_type[ct] = []
by_type[ct].append(context)
# Order types: System -> Task -> Knowledge -> Conversation -> Tool
type_order = [
ContextType.SYSTEM,
ContextType.TASK,
ContextType.KNOWLEDGE,
ContextType.CONVERSATION,
ContextType.TOOL,
]
parts: list[str] = []
for ct in type_order:
if ct in by_type:
formatted = self._format_type(by_type[ct], ct, model)
if formatted:
parts.append(formatted)
return "\n\n".join(parts)
def _format_type(
self,
contexts: list[BaseContext],
context_type: ContextType,
model: str,
) -> str:
"""Format contexts of a specific type."""
if not contexts:
return ""
# Check if model prefers XML tags (Claude)
use_xml = "claude" in model.lower()
if context_type == ContextType.SYSTEM:
return self._format_system(contexts, use_xml)
elif context_type == ContextType.TASK:
return self._format_task(contexts, use_xml)
elif context_type == ContextType.KNOWLEDGE:
return self._format_knowledge(contexts, use_xml)
elif context_type == ContextType.CONVERSATION:
return self._format_conversation(contexts, use_xml)
elif context_type == ContextType.TOOL:
return self._format_tool(contexts, use_xml)
return "\n".join(c.content for c in contexts)
def _format_system(self, contexts: list[BaseContext], use_xml: bool) -> str:
"""Format system contexts."""
content = "\n\n".join(c.content for c in contexts)
if use_xml:
return f"<system_instructions>\n{content}\n</system_instructions>"
return content
def _format_task(self, contexts: list[BaseContext], use_xml: bool) -> str:
"""Format task contexts."""
content = "\n\n".join(c.content for c in contexts)
if use_xml:
return f"<current_task>\n{content}\n</current_task>"
return f"## Current Task\n\n{content}"
def _format_knowledge(self, contexts: list[BaseContext], use_xml: bool) -> str:
"""Format knowledge contexts."""
if use_xml:
parts = ["<reference_documents>"]
for ctx in contexts:
parts.append(f'<document source="{ctx.source}">')
parts.append(ctx.content)
parts.append("</document>")
parts.append("</reference_documents>")
return "\n".join(parts)
else:
parts = ["## Reference Documents\n"]
for ctx in contexts:
parts.append(f"### Source: {ctx.source}\n")
parts.append(ctx.content)
parts.append("")
return "\n".join(parts)
def _format_conversation(self, contexts: list[BaseContext], use_xml: bool) -> str:
"""Format conversation contexts."""
if use_xml:
parts = ["<conversation_history>"]
for ctx in contexts:
role = ctx.metadata.get("role", "user")
parts.append(f'<message role="{role}">')
parts.append(ctx.content)
parts.append("</message>")
parts.append("</conversation_history>")
return "\n".join(parts)
else:
parts = []
for ctx in contexts:
role = ctx.metadata.get("role", "user")
parts.append(f"**{role.upper()}**: {ctx.content}")
return "\n\n".join(parts)
def _format_tool(self, contexts: list[BaseContext], use_xml: bool) -> str:
"""Format tool contexts."""
if use_xml:
parts = ["<tool_results>"]
for ctx in contexts:
tool_name = ctx.metadata.get("tool_name", "unknown")
parts.append(f'<tool_result name="{tool_name}">')
parts.append(ctx.content)
parts.append("</tool_result>")
parts.append("</tool_results>")
return "\n".join(parts)
else:
parts = ["## Recent Tool Results\n"]
for ctx in contexts:
tool_name = ctx.metadata.get("tool_name", "unknown")
parts.append(f"### Tool: {tool_name}\n")
parts.append(f"```\n{ctx.content}\n```")
parts.append("")
return "\n".join(parts)
adapter = get_adapter(model)
return adapter.format(contexts)
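The per-type formatting methods above are replaced by a single dispatch through `get_adapter`. A minimal sketch of such a registry, assuming the adapter names from the diff and simplifying contexts to plain strings:

```python
class ModelAdapter:
    """Generic fallback: plain double-newline joining."""
    def format(self, contexts: list[str]) -> str:
        return "\n\n".join(contexts)

class ClaudeAdapter(ModelAdapter):
    """Claude-family formatting: wrap everything in XML tags."""
    def format(self, contexts: list[str]) -> str:
        body = "\n".join(contexts)
        return f"<context>\n{body}\n</context>"

_ADAPTERS: dict[str, type[ModelAdapter]] = {"claude": ClaudeAdapter}

def get_adapter(model: str) -> ModelAdapter:
    # Match on model-name substring, e.g. "claude-3-opus" -> ClaudeAdapter
    for prefix, cls in _ADAPTERS.items():
        if prefix in model.lower():
            return cls()
    return ModelAdapter()

print(type(get_adapter("claude-3-opus")).__name__)  # ClaudeAdapter
print(type(get_adapter("gpt-4")).__name__)          # ModelAdapter
```

This keeps model-specific formatting rules out of the pipeline itself; adding a new model family means registering one adapter rather than editing the pipeline.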
def _check_timeout(
self,
@@ -412,9 +335,28 @@ class ContextPipeline:
) -> None:
"""Check if timeout exceeded and raise if so."""
elapsed_ms = (time.perf_counter() - start) * 1000
if elapsed_ms > timeout_ms:
if elapsed_ms >= timeout_ms:
raise AssemblyTimeoutError(
message=f"Context assembly timed out during {phase}",
elapsed_ms=elapsed_ms,
timeout_ms=timeout_ms,
)
def _remaining_timeout(self, start: float, timeout_ms: int) -> float:
"""
Calculate remaining timeout in seconds for asyncio.wait_for.
Returns at least a small positive value to avoid immediate timeout
edge cases with wait_for.
Args:
start: Start time from time.perf_counter()
timeout_ms: Total timeout in milliseconds
Returns:
Remaining timeout in seconds (minimum 0.001)
"""
elapsed_ms = (time.perf_counter() - start) * 1000
remaining_ms = timeout_ms - elapsed_ms
# Return at least 1ms to avoid zero/negative timeout edge cases
return max(remaining_ms / 1000.0, 0.001)
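The pattern of threading one overall deadline through several `asyncio.wait_for` calls can be shown standalone (`main` is a hypothetical caller, not from the diff):

```python
import asyncio
import time

def remaining_timeout(start: float, timeout_ms: int) -> float:
    """Remaining budget in seconds, clamped to 1 ms for asyncio.wait_for."""
    elapsed_ms = (time.perf_counter() - start) * 1000
    return max((timeout_ms - elapsed_ms) / 1000.0, 0.001)

async def main() -> str:
    start = time.perf_counter()
    try:
        # A 1-second phase against a 50 ms overall deadline must time out
        await asyncio.wait_for(asyncio.sleep(1.0),
                               timeout=remaining_timeout(start, 50))
        return "finished"
    except (TimeoutError, asyncio.TimeoutError):  # aliases on 3.11+, distinct before
        return "timed out"

print(asyncio.run(main()))  # timed out
```

Each successive phase gets only what is left of the shared budget, so slow early phases shrink the window for later ones instead of letting total time grow unbounded.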

View File

@@ -293,14 +293,18 @@ class BudgetAllocator:
if isinstance(context_type, ContextType):
context_type = context_type.value
# Calculate adjustment (limited by buffer)
# Calculate adjustment (limited by buffer for increases, by current allocation for decreases)
if adjustment > 0:
# Taking from buffer
# Taking from buffer - limited by available buffer
actual_adjustment = min(adjustment, budget.buffer)
budget.buffer -= actual_adjustment
else:
# Returning to buffer
actual_adjustment = adjustment
# Returning to buffer - limited by current allocation of target type
current_allocation = budget.get_allocation(context_type)
# Can't return more than current allocation
actual_adjustment = max(adjustment, -current_allocation)
# Add returned tokens back to buffer (adjustment is negative, so subtract)
budget.buffer -= actual_adjustment
# Apply to target type
if context_type == "system":
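The clamping rules in this hunk, reduced to a pure function (a sketch; the real allocator mutates `budget.buffer` and the per-type allocation in place):

```python
def adjust(buffer: int, allocation: int, adjustment: int) -> tuple[int, int]:
    """Return (new_buffer, new_allocation) after a clamped adjustment."""
    if adjustment > 0:
        actual = min(adjustment, buffer)       # can't take more than the buffer holds
    else:
        actual = max(adjustment, -allocation)  # can't return more than was allocated
    return buffer - actual, allocation + actual

print(adjust(buffer=100, allocation=50, adjustment=300))   # (0, 150)
print(adjust(buffer=100, allocation=50, adjustment=-200))  # (150, 0)
```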

View File

@@ -95,19 +95,28 @@ class ContextCache:
contexts: list[BaseContext],
query: str,
model: str,
project_id: str | None = None,
agent_id: str | None = None,
) -> str:
"""
Compute a fingerprint for a context assembly request.
The fingerprint is based on:
- Project and agent IDs (for tenant isolation)
- Context content hash and metadata (not full content for performance)
- Query string
- Target model
SECURITY: project_id and agent_id MUST be included to prevent
cross-tenant cache pollution. Without these, one tenant could
receive cached contexts from another tenant with the same query.
Args:
contexts: List of contexts
query: Query string
model: Model name
project_id: Project ID for tenant isolation
agent_id: Agent ID for tenant isolation
Returns:
32-character hex fingerprint
@@ -128,6 +137,9 @@ class ContextCache:
)
data = {
# CRITICAL: Include tenant identifiers for cache isolation
"project_id": project_id or "",
"agent_id": agent_id or "",
"contexts": context_data,
"query": query,
"model": model,

View File

@@ -19,6 +19,40 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def _estimate_tokens(text: str, model: str | None = None) -> int:
"""
Estimate token count using model-specific character ratios.
Module-level function for reuse across classes. Uses the same ratios
as TokenCalculator for consistency.
Args:
text: Text to estimate tokens for
model: Optional model name for model-specific ratios
Returns:
Estimated token count (minimum 1)
"""
# Model-specific character ratios (chars per token)
model_ratios = {
"claude": 3.5,
"gpt-4": 4.0,
"gpt-3.5": 4.0,
"gemini": 4.0,
}
default_ratio = 4.0
ratio = default_ratio
if model:
model_lower = model.lower()
for model_prefix, model_ratio in model_ratios.items():
if model_prefix in model_lower:
ratio = model_ratio
break
return max(1, int(len(text) / ratio))
@dataclass
class TruncationResult:
"""Result of truncation operation."""
@@ -284,8 +318,8 @@ class TruncationStrategy:
if self._calculator is not None:
return await self._calculator.count_tokens(text, model)
# Fallback estimation
return max(1, len(text) // 4)
# Fallback estimation with model-specific ratios
return _estimate_tokens(text, model)
class ContextCompressor:
@@ -415,4 +449,5 @@ class ContextCompressor:
"""Count tokens using calculator or estimation."""
if self._calculator is not None:
return await self._calculator.count_tokens(text, model)
return max(1, len(text) // 4)
# Use model-specific estimation for consistency
return _estimate_tokens(text, model)

View File

@@ -149,10 +149,11 @@ class ContextSettings(BaseSettings):
# Performance settings
max_assembly_time_ms: int = Field(
default=100,
default=2000,
ge=10,
le=5000,
description="Maximum time for context assembly in milliseconds",
le=30000,
description="Maximum time for context assembly in milliseconds. "
"Should be high enough to accommodate MCP calls for knowledge retrieval.",
)
parallel_scoring: bool = Field(
default=True,

View File

@@ -212,7 +212,10 @@ class ContextEngine:
# Check cache if enabled
fingerprint: str | None = None
if use_cache and self._cache.is_enabled:
fingerprint = self._cache.compute_fingerprint(contexts, query, model)
# Include project_id and agent_id for tenant isolation
fingerprint = self._cache.compute_fingerprint(
contexts, query, model, project_id=project_id, agent_id=agent_id
)
cached = await self._cache.get_assembled(fingerprint)
if cached:
logger.debug(f"Cache hit for context assembly: {fingerprint}")

View File

@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any
from ..budget import TokenBudget, TokenCalculator
from ..config import ContextSettings, get_context_settings
from ..exceptions import BudgetExceededError
from ..scoring.composite import CompositeScorer, ScoredContext
from ..types import BaseContext, ContextPriority
@@ -127,9 +128,25 @@ class ContextRanker:
excluded: list[ScoredContext] = []
total_tokens = 0
# Calculate the usable budget (total minus reserved portions)
usable_budget = budget.total - budget.response_reserve - budget.buffer
# Guard against invalid budget configuration
if usable_budget <= 0:
raise BudgetExceededError(
message=(
f"Invalid budget configuration: no usable tokens available. "
f"total={budget.total}, response_reserve={budget.response_reserve}, "
f"buffer={budget.buffer}"
),
allocated=budget.total,
requested=0,
context_type="CONFIGURATION_ERROR",
)
# First, try to fit required contexts
for sc in required:
token_count = sc.context.token_count or 0
token_count = self._get_valid_token_count(sc.context)
context_type = sc.context.get_type()
if budget.can_fit(context_type, token_count):
@@ -137,7 +154,20 @@ class ContextRanker:
selected.append(sc)
total_tokens += token_count
else:
# Force-fit CRITICAL contexts if needed
# Force-fit CRITICAL contexts if needed, but check total budget first
if total_tokens + token_count > usable_budget:
# Even CRITICAL contexts cannot exceed total model context window
raise BudgetExceededError(
message=(
f"CRITICAL contexts exceed total budget. "
f"Context '{sc.context.source}' ({token_count} tokens) "
f"would exceed usable budget of {usable_budget} tokens."
),
allocated=usable_budget,
requested=total_tokens + token_count,
context_type="CRITICAL_OVERFLOW",
)
budget.allocate(context_type, token_count, force=True)
selected.append(sc)
total_tokens += token_count
@@ -148,7 +178,7 @@ class ContextRanker:
# Then, greedily add optional contexts
for sc in optional:
token_count = sc.context.token_count or 0
token_count = self._get_valid_token_count(sc.context)
context_type = sc.context.get_type()
if budget.can_fit(context_type, token_count):
@@ -215,13 +245,43 @@ class ContextRanker:
total_tokens = 0
for sc in scored_contexts:
token_count = sc.context.token_count or 0
token_count = self._get_valid_token_count(sc.context)
if total_tokens + token_count <= max_tokens:
selected.append(sc.context)
total_tokens += token_count
return selected
def _get_valid_token_count(self, context: BaseContext) -> int:
"""
Get validated token count from a context.
Ensures token_count is set (not None) and non-negative to prevent
budget bypass attacks where:
- None would be treated as 0 (allowing huge contexts to slip through)
- Negative values would corrupt budget tracking
Args:
context: Context to get token count from
Returns:
Valid non-negative token count
Raises:
ValueError: If token_count is None or negative
"""
if context.token_count is None:
raise ValueError(
f"Context '{context.source}' has no token count. "
"Ensure _ensure_token_counts() is called before ranking."
)
if context.token_count < 0:
raise ValueError(
f"Context '{context.source}' has invalid negative token count: "
f"{context.token_count}"
)
return context.token_count
async def _ensure_token_counts(
self,
contexts: list[BaseContext],
@@ -266,6 +326,7 @@ class ContextRanker:
if type_name not in by_type:
by_type[type_name] = {"count": 0, "tokens": 0}
by_type[type_name]["count"] += 1
# Use validated token count (already validated during ranking)
by_type[type_name]["tokens"] += sc.context.token_count or 0
return by_type

View File

@@ -6,9 +6,9 @@ Combines multiple scoring strategies with configurable weights.
import asyncio
import logging
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from weakref import WeakValueDictionary
from ..config import ContextSettings, get_context_settings
from ..types import BaseContext
@@ -91,11 +91,11 @@ class CompositeScorer:
self._priority_scorer = PriorityScorer(weight=self._priority_weight)
# Per-context locks to prevent race conditions during parallel scoring
# Uses WeakValueDictionary so locks are garbage collected when not in use
self._context_locks: WeakValueDictionary[str, asyncio.Lock] = (
WeakValueDictionary()
)
# Uses dict with (lock, last_used_time) tuples for cleanup
self._context_locks: dict[str, tuple[asyncio.Lock, float]] = {}
self._locks_lock = asyncio.Lock() # Lock to protect _context_locks access
self._max_locks = 1000 # Maximum locks to keep (prevent memory growth)
self._lock_ttl = 60.0 # Seconds before a lock can be cleaned up
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
"""Set MCP manager for semantic scoring."""
@@ -141,7 +141,8 @@ class CompositeScorer:
Get or create a lock for a specific context.
Thread-safe access to per-context locks prevents race conditions
when the same context is scored concurrently.
when the same context is scored concurrently. Includes automatic
cleanup of old locks to prevent memory growth.
Args:
context_id: The context ID to get a lock for
@@ -149,25 +150,78 @@ class CompositeScorer:
Returns:
asyncio.Lock for the context
"""
now = time.time()
# Fast path: check if lock exists without acquiring main lock
if context_id in self._context_locks:
lock = self._context_locks.get(context_id)
if lock is not None:
# NOTE: We only READ here - no writes to avoid race conditions
# with cleanup. The timestamp will be updated in the slow path
# if the lock is still valid.
lock_entry = self._context_locks.get(context_id)
if lock_entry is not None:
lock, _ = lock_entry
# Return the lock but defer timestamp update to avoid race
# The lock is still valid; timestamp update is best-effort
return lock
# Slow path: create lock or update timestamp while holding main lock
async with self._locks_lock:
# Double-check after acquiring lock - entry may have been
# created by another coroutine or deleted by cleanup
lock_entry = self._context_locks.get(context_id)
if lock_entry is not None:
lock, _ = lock_entry
# Safe to update timestamp here since we hold the lock
self._context_locks[context_id] = (lock, now)
return lock
# Slow path: create lock while holding main lock
async with self._locks_lock:
# Double-check after acquiring lock
if context_id in self._context_locks:
lock = self._context_locks.get(context_id)
if lock is not None:
return lock
# Cleanup old locks if we have too many
if len(self._context_locks) >= self._max_locks:
self._cleanup_old_locks(now)
# Create new lock
new_lock = asyncio.Lock()
self._context_locks[context_id] = new_lock
self._context_locks[context_id] = (new_lock, now)
return new_lock
def _cleanup_old_locks(self, now: float) -> None:
"""
Remove old locks that haven't been used recently.
Called while holding _locks_lock. Removes locks older than _lock_ttl,
but only if they're not currently held.
Args:
now: Current timestamp for age calculation
"""
cutoff = now - self._lock_ttl
to_remove = []
for context_id, (lock, last_used) in self._context_locks.items():
# Only remove if old AND not currently held
if last_used < cutoff and not lock.locked():
to_remove.append(context_id)
# Remove oldest 50% if still over limit after TTL filtering
if len(self._context_locks) - len(to_remove) >= self._max_locks:
# Sort by last used time and mark oldest for removal
sorted_entries = sorted(
self._context_locks.items(),
key=lambda x: x[1][1], # Sort by last_used time
)
# Remove oldest 50% that aren't locked
target_remove = len(self._context_locks) // 2
for context_id, (lock, _) in sorted_entries:
if len(to_remove) >= target_remove:
break
if context_id not in to_remove and not lock.locked():
to_remove.append(context_id)
for context_id in to_remove:
del self._context_locks[context_id]
if to_remove:
logger.debug(f"Cleaned up {len(to_remove)} context locks")
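The fast-path/slow-path scheme with TTL cleanup can be condensed into a standalone sketch (the `LockRegistry` name is illustrative; the diff keeps this logic inside `CompositeScorer`):

```python
import asyncio
import time

class LockRegistry:
    """Per-key asyncio locks with bounded size and TTL-based cleanup."""

    def __init__(self, max_locks: int = 1000, ttl: float = 60.0) -> None:
        self._locks: dict[str, tuple[asyncio.Lock, float]] = {}
        self._guard = asyncio.Lock()
        self._max_locks = max_locks
        self._ttl = ttl

    async def get(self, key: str) -> asyncio.Lock:
        now = time.time()
        entry = self._locks.get(key)          # fast path: read-only, no guard
        if entry is not None:
            return entry[0]
        async with self._guard:               # slow path: create under the guard
            entry = self._locks.get(key)      # double-check after acquiring
            if entry is not None:
                return entry[0]
            if len(self._locks) >= self._max_locks:
                self._cleanup(now)
            lock = asyncio.Lock()
            self._locks[key] = (lock, now)
            return lock

    def _cleanup(self, now: float) -> None:
        # Drop stale locks, but never one that a coroutine currently holds
        cutoff = now - self._ttl
        stale = [k for k, (lk, used) in self._locks.items()
                 if used < cutoff and not lk.locked()]
        for k in stale:
            del self._locks[k]

async def demo() -> bool:
    reg = LockRegistry()
    a = await reg.get("ctx-1")
    b = await reg.get("ctx-1")
    return a is b  # same key must yield the same lock instance

print(asyncio.run(demo()))  # True
```

Compared with the `WeakValueDictionary` it replaces, this trades automatic garbage collection for deterministic, size-bounded cleanup that never evicts a lock while it is held.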
async def score(
self,
context: BaseContext,

View File

@@ -24,6 +24,9 @@ from ..models import (
logger = logging.getLogger(__name__)
# Sentinel for distinguishing "no argument passed" from "explicitly passing None"
_UNSET = object()
class AuditLogger:
"""
@@ -142,8 +145,10 @@ class AuditLogger:
# Add hash chain for tamper detection
if self._enable_hash_chain:
event_hash = self._compute_hash(event)
sanitized_details["_hash"] = event_hash
sanitized_details["_prev_hash"] = self._last_hash
# Modify event.details directly (not sanitized_details)
# to ensure the hash is stored on the actual event
event.details["_hash"] = event_hash
event.details["_prev_hash"] = self._last_hash
self._last_hash = event_hash
self._buffer.append(event)
@@ -415,7 +420,8 @@ class AuditLogger:
)
if stored_hash:
computed = self._compute_hash(event)
# Pass prev_hash to compute hash with correct chain position
computed = self._compute_hash(event, prev_hash=prev_hash)
if computed != stored_hash:
issues.append(
f"Hash mismatch at event {event.id}: "
@@ -462,9 +468,23 @@ class AuditLogger:
return sanitized
def _compute_hash(self, event: AuditEvent) -> str:
"""Compute hash for an event (excluding hash fields)."""
data = {
def _compute_hash(
self, event: AuditEvent, prev_hash: str | None | object = _UNSET
) -> str:
"""Compute hash for an event (excluding hash fields).
Args:
event: The audit event to hash.
prev_hash: Optional previous hash to use instead of self._last_hash.
Pass this during verification to use the correct chain.
Use None explicitly to indicate no previous hash.
"""
# Use passed prev_hash if explicitly provided, otherwise use instance state
effective_prev: str | None = (
self._last_hash if prev_hash is _UNSET else prev_hash # type: ignore[assignment]
)
data: dict[str, str | dict[str, str] | None] = {
"id": event.id,
"event_type": event.event_type.value,
"timestamp": event.timestamp.isoformat(),
@@ -480,8 +500,8 @@ class AuditLogger:
"correlation_id": event.correlation_id,
}
if self._last_hash:
data["_prev_hash"] = self._last_hash
if effective_prev:
data["_prev_hash"] = effective_prev
serialized = json.dumps(data, sort_keys=True, default=str)
return hashlib.sha256(serialized.encode()).hexdigest()

View File

@@ -0,0 +1,466 @@
"""
Tests for Context Management API Routes.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import status
from fastapi.testclient import TestClient
from app.main import app
from app.models.user import User
from app.services.context import (
AssembledContext,
AssemblyTimeoutError,
BudgetExceededError,
ContextEngine,
TokenBudget,
)
from app.services.mcp import MCPClientManager
@pytest.fixture
def mock_mcp_client():
"""Create a mock MCP client manager."""
client = MagicMock(spec=MCPClientManager)
client.is_initialized = True
return client
@pytest.fixture
def mock_context_engine(mock_mcp_client):
"""Create a mock ContextEngine."""
engine = MagicMock(spec=ContextEngine)
engine._mcp = mock_mcp_client
return engine
@pytest.fixture
def mock_superuser():
"""Create a mock superuser."""
user = MagicMock(spec=User)
user.id = "00000000-0000-0000-0000-000000000001"
user.is_superuser = True
user.email = "admin@example.com"
return user
@pytest.fixture
def client(mock_mcp_client, mock_context_engine, mock_superuser):
"""Create a FastAPI test client with mocked dependencies."""
from app.api.dependencies.permissions import require_superuser
from app.api.routes.context import get_context_engine
from app.services.mcp import get_mcp_client
# Override dependencies
async def override_get_mcp_client():
return mock_mcp_client
async def override_get_context_engine():
return mock_context_engine
async def override_require_superuser():
return mock_superuser
app.dependency_overrides[get_mcp_client] = override_get_mcp_client
app.dependency_overrides[get_context_engine] = override_get_context_engine
app.dependency_overrides[require_superuser] = override_require_superuser
with patch("app.main.check_database_health", return_value=True):
yield TestClient(app)
# Clean up
app.dependency_overrides.clear()
class TestContextHealth:
"""Tests for GET /context/health endpoint."""
def test_health_check_success(self, client, mock_context_engine, mock_mcp_client):
"""Test context engine health check."""
mock_context_engine.get_stats = AsyncMock(
return_value={
"cache": {"hits": 10, "misses": 5},
"settings": {"cache_enabled": True},
}
)
response = client.get("/api/v1/context/health")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["status"] == "healthy"
assert "mcp_connected" in data
assert "cache_enabled" in data
class TestAssembleContext:
"""Tests for POST /context/assemble endpoint."""
def test_assemble_context_success(self, client, mock_context_engine):
"""Test successful context assembly."""
# Create mock assembled context
mock_result = MagicMock(spec=AssembledContext)
mock_result.content = "Assembled context content"
mock_result.total_tokens = 500
mock_result.context_count = 2
mock_result.excluded_count = 0
mock_result.assembly_time_ms = 50.5
mock_result.metadata = {}
mock_context_engine.assemble_context = AsyncMock(return_value=mock_result)
mock_context_engine.get_budget_for_model = AsyncMock(
return_value=TokenBudget(
total=4000,
system=500,
knowledge=1500,
conversation=1000,
tools=500,
response_reserve=500,
)
)
response = client.post(
"/api/v1/context/assemble",
json={
"project_id": "test-project",
"agent_id": "test-agent",
"query": "What is the auth flow?",
"model": "claude-3-sonnet",
"system_prompt": "You are a helpful assistant.",
},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["content"] == "Assembled context content"
assert data["total_tokens"] == 500
assert data["context_count"] == 2
assert data["compressed"] is False
assert "budget_used_percent" in data
def test_assemble_context_with_conversation(self, client, mock_context_engine):
"""Test context assembly with conversation history."""
mock_result = MagicMock(spec=AssembledContext)
mock_result.content = "Context with history"
mock_result.total_tokens = 800
mock_result.context_count = 1
mock_result.excluded_count = 0
mock_result.assembly_time_ms = 30.0
mock_result.metadata = {}
mock_context_engine.assemble_context = AsyncMock(return_value=mock_result)
mock_context_engine.get_budget_for_model = AsyncMock(
return_value=TokenBudget(
total=4000,
system=500,
knowledge=1500,
conversation=1000,
tools=500,
response_reserve=500,
)
)
response = client.post(
"/api/v1/context/assemble",
json={
"project_id": "test-project",
"agent_id": "test-agent",
"query": "Continue the discussion",
"conversation_history": [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
],
},
)
assert response.status_code == status.HTTP_200_OK
call_args = mock_context_engine.assemble_context.call_args
assert call_args.kwargs["conversation_history"] == [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
]
def test_assemble_context_with_tool_results(self, client, mock_context_engine):
"""Test context assembly with tool results."""
mock_result = MagicMock(spec=AssembledContext)
mock_result.content = "Context with tools"
mock_result.total_tokens = 600
mock_result.context_count = 1
mock_result.excluded_count = 0
mock_result.assembly_time_ms = 25.0
mock_result.metadata = {}
mock_context_engine.assemble_context = AsyncMock(return_value=mock_result)
mock_context_engine.get_budget_for_model = AsyncMock(
return_value=TokenBudget(
total=4000,
system=500,
knowledge=1500,
conversation=1000,
tools=500,
response_reserve=500,
)
)
response = client.post(
"/api/v1/context/assemble",
json={
"project_id": "test-project",
"agent_id": "test-agent",
"query": "What did the search find?",
"tool_results": [
{
"tool_name": "search_knowledge",
"content": {"results": ["item1", "item2"]},
"status": "success",
}
],
},
)
assert response.status_code == status.HTTP_200_OK
call_args = mock_context_engine.assemble_context.call_args
assert len(call_args.kwargs["tool_results"]) == 1
def test_assemble_context_timeout(self, client, mock_context_engine):
"""Test context assembly timeout error."""
mock_context_engine.assemble_context = AsyncMock(
side_effect=AssemblyTimeoutError("Assembly exceeded 5000ms limit")
)
response = client.post(
"/api/v1/context/assemble",
json={
"project_id": "test-project",
"agent_id": "test-agent",
"query": "test",
},
)
assert response.status_code == status.HTTP_504_GATEWAY_TIMEOUT
def test_assemble_context_budget_exceeded(self, client, mock_context_engine):
"""Test context assembly budget exceeded error."""
mock_context_engine.assemble_context = AsyncMock(
side_effect=BudgetExceededError("Token budget exceeded: 5000 > 4000")
)
response = client.post(
"/api/v1/context/assemble",
json={
"project_id": "test-project",
"agent_id": "test-agent",
"query": "test",
},
)
assert response.status_code == status.HTTP_413_REQUEST_ENTITY_TOO_LARGE
def test_assemble_context_validation_error(self, client):
"""Test context assembly with invalid request."""
response = client.post(
"/api/v1/context/assemble",
json={}, # Missing required fields
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
class TestCountTokens:
"""Tests for POST /context/count-tokens endpoint."""
def test_count_tokens_success(self, client, mock_context_engine):
"""Test successful token counting."""
mock_context_engine.count_tokens = AsyncMock(return_value=42)
response = client.post(
"/api/v1/context/count-tokens",
json={
"content": "This is some test content.",
"model": "claude-3-sonnet",
},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["token_count"] == 42
assert data["model"] == "claude-3-sonnet"
def test_count_tokens_without_model(self, client, mock_context_engine):
"""Test token counting without specifying model."""
mock_context_engine.count_tokens = AsyncMock(return_value=100)
response = client.post(
"/api/v1/context/count-tokens",
json={"content": "Some content to count."},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["token_count"] == 100
assert data["model"] is None
class TestGetBudget:
"""Tests for GET /context/budget/{model} endpoint."""
def test_get_budget_success(self, client, mock_context_engine):
"""Test getting token budget for a model."""
mock_context_engine.get_budget_for_model = AsyncMock(
return_value=TokenBudget(
total=100000,
system=10000,
knowledge=40000,
conversation=30000,
tools=10000,
response_reserve=10000,
)
)
response = client.get("/api/v1/context/budget/claude-3-opus")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["model"] == "claude-3-opus"
assert data["total_tokens"] == 100000
assert data["system_tokens"] == 10000
assert data["knowledge_tokens"] == 40000
def test_get_budget_with_max_tokens(self, client, mock_context_engine):
"""Test getting budget with custom max tokens."""
mock_context_engine.get_budget_for_model = AsyncMock(
return_value=TokenBudget(
total=2000,
system=200,
knowledge=800,
conversation=600,
tools=200,
response_reserve=200,
)
)
response = client.get("/api/v1/context/budget/gpt-4?max_tokens=2000")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["total_tokens"] == 2000
class TestGetStats:
"""Tests for GET /context/stats endpoint."""
def test_get_stats_success(self, client, mock_context_engine):
"""Test getting engine statistics."""
mock_context_engine.get_stats = AsyncMock(
return_value={
"cache": {
"hits": 100,
"misses": 25,
"hit_rate": 0.8,
},
"settings": {
"compression_threshold": 0.9,
"max_assembly_time_ms": 5000,
"cache_enabled": True,
},
}
)
response = client.get("/api/v1/context/stats")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["cache"]["hits"] == 100
assert data["settings"]["cache_enabled"] is True
class TestInvalidateCache:
"""Tests for POST /context/cache/invalidate endpoint."""
def test_invalidate_cache_by_project(self, client, mock_context_engine):
"""Test cache invalidation by project ID."""
mock_context_engine.invalidate_cache = AsyncMock(return_value=5)
response = client.post(
"/api/v1/context/cache/invalidate?project_id=test-project"
)
assert response.status_code == status.HTTP_204_NO_CONTENT
mock_context_engine.invalidate_cache.assert_called_once()
call_kwargs = mock_context_engine.invalidate_cache.call_args.kwargs
assert call_kwargs["project_id"] == "test-project"
def test_invalidate_cache_by_pattern(self, client, mock_context_engine):
"""Test cache invalidation by pattern."""
mock_context_engine.invalidate_cache = AsyncMock(return_value=10)
response = client.post("/api/v1/context/cache/invalidate?pattern=*auth*")
assert response.status_code == status.HTTP_204_NO_CONTENT
mock_context_engine.invalidate_cache.assert_called_once()
call_kwargs = mock_context_engine.invalidate_cache.call_args.kwargs
assert call_kwargs["pattern"] == "*auth*"
def test_invalidate_cache_all(self, client, mock_context_engine):
"""Test invalidating all cache entries."""
mock_context_engine.invalidate_cache = AsyncMock(return_value=100)
response = client.post("/api/v1/context/cache/invalidate")
assert response.status_code == status.HTTP_204_NO_CONTENT
class TestContextEndpointsEdgeCases:
"""Edge case tests for Context endpoints."""
def test_context_content_type(self, client, mock_context_engine):
"""Test that endpoints return JSON content type."""
mock_context_engine.get_stats = AsyncMock(
return_value={"cache": {}, "settings": {}}
)
response = client.get("/api/v1/context/health")
assert "application/json" in response.headers["content-type"]
def test_assemble_context_with_knowledge_query(self, client, mock_context_engine):
"""Test context assembly with knowledge base query."""
mock_result = MagicMock(spec=AssembledContext)
mock_result.content = "Context with knowledge"
mock_result.total_tokens = 1000
mock_result.context_count = 3
mock_result.excluded_count = 0
mock_result.assembly_time_ms = 100.0
mock_result.metadata = {
"compressed_contexts": 1
} # Indicates compression happened
mock_context_engine.assemble_context = AsyncMock(return_value=mock_result)
mock_context_engine.get_budget_for_model = AsyncMock(
return_value=TokenBudget(
total=4000,
system=500,
knowledge=1500,
conversation=1000,
tools=500,
response_reserve=500,
)
)
response = client.post(
"/api/v1/context/assemble",
json={
"project_id": "test-project",
"agent_id": "test-agent",
"query": "How does authentication work?",
"knowledge_query": "authentication flow implementation",
"knowledge_limit": 5,
},
)
assert response.status_code == status.HTTP_200_OK
call_kwargs = mock_context_engine.assemble_context.call_args.kwargs
assert call_kwargs["knowledge_query"] == "authentication flow implementation"
assert call_kwargs["knowledge_limit"] == 5
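These route tests all lean on the same mocking pattern: the engine's coroutine methods are replaced with `AsyncMock`, and the keyword arguments the handler forwarded are inspected through `call_args.kwargs`. The pattern in isolation, stdlib only (the `route_handler` here is a stand-in, not the real endpoint):

```python
import asyncio
from unittest.mock import AsyncMock

engine = AsyncMock()
engine.invalidate_cache = AsyncMock(return_value=5)

async def route_handler() -> int:
    # Stand-in for the endpoint body, which awaits the engine method.
    return await engine.invalidate_cache(project_id="test-project")

removed = asyncio.run(route_handler())
engine.invalidate_cache.assert_called_once()
print(removed, engine.invalidate_cache.call_args.kwargs["project_id"])  # 5 test-project
```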


@@ -0,0 +1,646 @@
"""
Agent E2E Workflow Tests.
Tests complete workflows for AI agents including:
- Agent type management (admin-only)
- Agent instance spawning and lifecycle
- Agent status transitions (pause/resume/terminate)
- Authorization and access control
Usage:
make test-e2e # Run all E2E tests
"""
from uuid import uuid4
import pytest
pytestmark = [
pytest.mark.e2e,
pytest.mark.postgres,
pytest.mark.asyncio,
]
class TestAgentTypesAdminWorkflows:
"""Test agent type management (admin-only operations)."""
async def test_create_agent_type_requires_superuser(self, e2e_client):
"""Test that creating agent types requires superuser privileges."""
# Register regular user
email = f"regular-{uuid4().hex[:8]}@example.com"
password = "RegularPass123!"
await e2e_client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": password,
"first_name": "Regular",
"last_name": "User",
},
)
login_resp = await e2e_client.post(
"/api/v1/auth/login",
json={"email": email, "password": password},
)
tokens = login_resp.json()
# Try to create agent type
response = await e2e_client.post(
"/api/v1/agent-types",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={
"name": "Test Agent",
"slug": f"test-agent-{uuid4().hex[:8]}",
"personality_prompt": "You are a helpful assistant.",
"primary_model": "claude-3-sonnet",
},
)
assert response.status_code == 403
async def test_superuser_can_create_agent_type(self, e2e_client, e2e_superuser):
"""Test that superuser can create and manage agent types."""
slug = f"test-type-{uuid4().hex[:8]}"
# Create agent type
create_resp = await e2e_client.post(
"/api/v1/agent-types",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={
"name": "Product Owner Agent",
"slug": slug,
"description": "A product owner agent for requirements gathering",
"expertise": ["requirements", "user_stories", "prioritization"],
"personality_prompt": "You are a product owner focused on delivering value.",
"primary_model": "claude-3-opus",
"fallback_models": ["claude-3-sonnet"],
"model_params": {"temperature": 0.7, "max_tokens": 4000},
"mcp_servers": ["knowledge-base"],
"is_active": True,
},
)
assert create_resp.status_code == 201, f"Failed: {create_resp.text}"
agent_type = create_resp.json()
assert agent_type["name"] == "Product Owner Agent"
assert agent_type["slug"] == slug
assert agent_type["primary_model"] == "claude-3-opus"
assert agent_type["is_active"] is True
assert "requirements" in agent_type["expertise"]
async def test_list_agent_types_public(self, e2e_client, e2e_superuser):
"""Test that any authenticated user can list agent types."""
# First create an agent type as superuser
slug = f"list-test-{uuid4().hex[:8]}"
await e2e_client.post(
"/api/v1/agent-types",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={
"name": f"List Test Agent {slug}",
"slug": slug,
"personality_prompt": "Test agent.",
"primary_model": "claude-3-sonnet",
},
)
# Register regular user
email = f"lister-{uuid4().hex[:8]}@example.com"
password = "ListerPass123!"
await e2e_client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": password,
"first_name": "List",
"last_name": "User",
},
)
login_resp = await e2e_client.post(
"/api/v1/auth/login",
json={"email": email, "password": password},
)
tokens = login_resp.json()
# List agent types as regular user
list_resp = await e2e_client.get(
"/api/v1/agent-types",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
assert list_resp.status_code == 200
data = list_resp.json()
assert "data" in data
assert "pagination" in data
assert data["pagination"]["total"] >= 1
async def test_get_agent_type_by_slug(self, e2e_client, e2e_superuser):
"""Test getting agent type by slug."""
slug = f"slug-test-{uuid4().hex[:8]}"
# Create agent type
await e2e_client.post(
"/api/v1/agent-types",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={
"name": f"Slug Test {slug}",
"slug": slug,
"personality_prompt": "Test agent.",
"primary_model": "claude-3-sonnet",
},
)
# Get by slug (route is /slug/{slug}, not /by-slug/{slug})
get_resp = await e2e_client.get(
f"/api/v1/agent-types/slug/{slug}",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
)
assert get_resp.status_code == 200
data = get_resp.json()
assert data["slug"] == slug
async def test_update_agent_type(self, e2e_client, e2e_superuser):
"""Test updating an agent type."""
slug = f"update-test-{uuid4().hex[:8]}"
# Create agent type
create_resp = await e2e_client.post(
"/api/v1/agent-types",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={
"name": "Original Name",
"slug": slug,
"personality_prompt": "Original prompt.",
"primary_model": "claude-3-sonnet",
},
)
agent_type_id = create_resp.json()["id"]
# Update agent type
update_resp = await e2e_client.patch(
f"/api/v1/agent-types/{agent_type_id}",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={
"name": "Updated Name",
"description": "Added description",
},
)
assert update_resp.status_code == 200
updated = update_resp.json()
assert updated["name"] == "Updated Name"
assert updated["description"] == "Added description"
assert updated["personality_prompt"] == "Original prompt." # Unchanged
class TestAgentInstanceWorkflows:
"""Test agent instance spawning and lifecycle."""
async def test_spawn_agent_workflow(self, e2e_client, e2e_superuser):
"""Test complete workflow: create type -> create project -> spawn agent."""
# 1. Create agent type as superuser
type_slug = f"spawn-test-type-{uuid4().hex[:8]}"
type_resp = await e2e_client.post(
"/api/v1/agent-types",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={
"name": "Spawn Test Agent",
"slug": type_slug,
"personality_prompt": "You are a helpful agent.",
"primary_model": "claude-3-sonnet",
},
)
assert type_resp.status_code == 201
agent_type = type_resp.json()
agent_type_id = agent_type["id"]
# 2. Create a project (superuser can create projects too)
project_slug = f"spawn-test-project-{uuid4().hex[:8]}"
project_resp = await e2e_client.post(
"/api/v1/projects",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={"name": "Spawn Test Project", "slug": project_slug},
)
assert project_resp.status_code == 201
project = project_resp.json()
project_id = project["id"]
# 3. Spawn agent instance
spawn_resp = await e2e_client.post(
f"/api/v1/projects/{project_id}/agents",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={
"agent_type_id": agent_type_id,
"project_id": project_id,
"name": "My PO Agent",
},
)
assert spawn_resp.status_code == 201, f"Failed: {spawn_resp.text}"
agent = spawn_resp.json()
assert agent["name"] == "My PO Agent"
assert agent["status"] == "idle"
assert agent["project_id"] == project_id
assert agent["agent_type_id"] == agent_type_id
async def test_list_project_agents(self, e2e_client, e2e_superuser):
"""Test listing agents in a project."""
# Setup: Create agent type and project
type_slug = f"list-agents-type-{uuid4().hex[:8]}"
type_resp = await e2e_client.post(
"/api/v1/agent-types",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={
"name": "List Agents Type",
"slug": type_slug,
"personality_prompt": "Test agent.",
"primary_model": "claude-3-sonnet",
},
)
agent_type_id = type_resp.json()["id"]
project_slug = f"list-agents-project-{uuid4().hex[:8]}"
project_resp = await e2e_client.post(
"/api/v1/projects",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={"name": "List Agents Project", "slug": project_slug},
)
project_id = project_resp.json()["id"]
# Spawn multiple agents
for i in range(3):
await e2e_client.post(
f"/api/v1/projects/{project_id}/agents",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={
"agent_type_id": agent_type_id,
"project_id": project_id,
"name": f"Agent {i + 1}",
},
)
# List agents
list_resp = await e2e_client.get(
f"/api/v1/projects/{project_id}/agents",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
)
assert list_resp.status_code == 200
data = list_resp.json()
assert data["pagination"]["total"] == 3
assert len(data["data"]) == 3
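The list endpoint is asserted to return a `{"data": [...], "pagination": {...}}` envelope. A minimal sketch of building that envelope (any pagination field beyond `total` is an assumption):

```python
def paginate(items: list, page: int = 1, page_size: int = 20) -> dict:
    """Wrap a page of items in the data/pagination envelope the tests assert on."""
    start = (page - 1) * page_size
    return {
        "data": items[start:start + page_size],
        "pagination": {"total": len(items), "page": page, "page_size": page_size},
    }

agents = [f"Agent {i + 1}" for i in range(3)]
resp = paginate(agents)
print(resp["pagination"]["total"], len(resp["data"]))  # 3 3
```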
class TestAgentLifecycle:
"""Test agent lifecycle operations (pause/resume/terminate)."""
async def test_agent_pause_and_resume(self, e2e_client, e2e_superuser):
"""Test pausing and resuming an agent."""
# Setup: Create agent type, project, and agent
type_slug = f"pause-test-type-{uuid4().hex[:8]}"
type_resp = await e2e_client.post(
"/api/v1/agent-types",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={
"name": "Pause Test Type",
"slug": type_slug,
"personality_prompt": "Test agent.",
"primary_model": "claude-3-sonnet",
},
)
agent_type_id = type_resp.json()["id"]
project_slug = f"pause-test-project-{uuid4().hex[:8]}"
project_resp = await e2e_client.post(
"/api/v1/projects",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={"name": "Pause Test Project", "slug": project_slug},
)
project_id = project_resp.json()["id"]
spawn_resp = await e2e_client.post(
f"/api/v1/projects/{project_id}/agents",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={
"agent_type_id": agent_type_id,
"project_id": project_id,
"name": "Pausable Agent",
},
)
agent_id = spawn_resp.json()["id"]
assert spawn_resp.json()["status"] == "idle"
# Pause agent
pause_resp = await e2e_client.post(
f"/api/v1/projects/{project_id}/agents/{agent_id}/pause",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
)
assert pause_resp.status_code == 200, f"Failed: {pause_resp.text}"
assert pause_resp.json()["status"] == "paused"
# Resume agent
resume_resp = await e2e_client.post(
f"/api/v1/projects/{project_id}/agents/{agent_id}/resume",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
)
assert resume_resp.status_code == 200, f"Failed: {resume_resp.text}"
assert resume_resp.json()["status"] == "idle"
async def test_agent_terminate(self, e2e_client, e2e_superuser):
"""Test terminating an agent."""
# Setup
type_slug = f"terminate-type-{uuid4().hex[:8]}"
type_resp = await e2e_client.post(
"/api/v1/agent-types",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={
"name": "Terminate Type",
"slug": type_slug,
"personality_prompt": "Test agent.",
"primary_model": "claude-3-sonnet",
},
)
agent_type_id = type_resp.json()["id"]
project_slug = f"terminate-project-{uuid4().hex[:8]}"
project_resp = await e2e_client.post(
"/api/v1/projects",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={"name": "Terminate Project", "slug": project_slug},
)
project_id = project_resp.json()["id"]
spawn_resp = await e2e_client.post(
f"/api/v1/projects/{project_id}/agents",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={
"agent_type_id": agent_type_id,
"project_id": project_id,
"name": "To Be Terminated",
},
)
agent_id = spawn_resp.json()["id"]
# Terminate agent (returns MessageResponse, not agent status)
terminate_resp = await e2e_client.delete(
f"/api/v1/projects/{project_id}/agents/{agent_id}",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
)
assert terminate_resp.status_code == 200
assert "message" in terminate_resp.json()
# Verify terminated agent cannot be resumed (returns 400 or 422)
resume_resp = await e2e_client.post(
f"/api/v1/projects/{project_id}/agents/{agent_id}/resume",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
)
assert resume_resp.status_code in [400, 422] # Invalid transition
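The lifecycle tests above imply a small state machine: `idle` and `paused` convert into each other, and termination is absorbing. The transition table below is inferred from the assertions, not taken from the implementation (a `running` state likely exists too but is not exercised here):

```python
# Inferred agent status transitions; an assumption based on the E2E assertions.
TRANSITIONS = {
    ("idle", "pause"): "paused",
    ("paused", "resume"): "idle",
    ("idle", "terminate"): "terminated",
    ("paused", "terminate"): "terminated",
}

def apply_action(status: str, action: str) -> str:
    try:
        return TRANSITIONS[(status, action)]
    except KeyError:
        # Mirrors the API rejecting invalid transitions with 400/422.
        raise ValueError(f"invalid transition: {action!r} from {status!r}")

print(apply_action("idle", "pause"))  # paused
```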
class TestAgentAccessControl:
"""Test agent access control and authorization."""
async def test_user_cannot_access_other_project_agents(
self, e2e_client, e2e_superuser
):
"""Test that users cannot access agents in projects they don't own."""
# Superuser creates agent type
type_slug = f"access-type-{uuid4().hex[:8]}"
type_resp = await e2e_client.post(
"/api/v1/agent-types",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={
"name": "Access Type",
"slug": type_slug,
"personality_prompt": "Test agent.",
"primary_model": "claude-3-sonnet",
},
)
agent_type_id = type_resp.json()["id"]
# Superuser creates project and spawns agent
project_slug = f"protected-project-{uuid4().hex[:8]}"
project_resp = await e2e_client.post(
"/api/v1/projects",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={"name": "Protected Project", "slug": project_slug},
)
project_id = project_resp.json()["id"]
spawn_resp = await e2e_client.post(
f"/api/v1/projects/{project_id}/agents",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={
"agent_type_id": agent_type_id,
"project_id": project_id,
"name": "Protected Agent",
},
)
agent_id = spawn_resp.json()["id"]
# Create a different user
email = f"other-user-{uuid4().hex[:8]}@example.com"
password = "OtherPass123!"
await e2e_client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": password,
"first_name": "Other",
"last_name": "User",
},
)
login_resp = await e2e_client.post(
"/api/v1/auth/login",
json={"email": email, "password": password},
)
other_tokens = login_resp.json()
# Other user tries to access the agent
get_resp = await e2e_client.get(
f"/api/v1/projects/{project_id}/agents/{agent_id}",
headers={"Authorization": f"Bearer {other_tokens['access_token']}"},
)
# Should be forbidden or not found
assert get_resp.status_code in [403, 404]
async def test_cannot_spawn_with_inactive_agent_type(
self, e2e_client, e2e_superuser
):
"""Test that agents cannot be spawned from inactive agent types."""
# Create agent type
type_slug = f"inactive-type-{uuid4().hex[:8]}"
type_resp = await e2e_client.post(
"/api/v1/agent-types",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={
"name": "Inactive Type",
"slug": type_slug,
"personality_prompt": "Test agent.",
"primary_model": "claude-3-sonnet",
"is_active": True,
},
)
agent_type_id = type_resp.json()["id"]
# Deactivate the agent type
await e2e_client.patch(
f"/api/v1/agent-types/{agent_type_id}",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={"is_active": False},
)
# Create project
project_slug = f"inactive-spawn-project-{uuid4().hex[:8]}"
project_resp = await e2e_client.post(
"/api/v1/projects",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={"name": "Inactive Spawn Project", "slug": project_slug},
)
project_id = project_resp.json()["id"]
# Try to spawn agent with inactive type
spawn_resp = await e2e_client.post(
f"/api/v1/projects/{project_id}/agents",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={
"agent_type_id": agent_type_id,
"project_id": project_id,
"name": "Should Fail",
},
)
# Spawning from an inactive agent type is rejected as a validation error (422)
assert spawn_resp.status_code == 422
class TestAgentMetrics:
"""Test agent metrics endpoint."""
async def test_get_agent_metrics(self, e2e_client, e2e_superuser):
"""Test retrieving agent metrics."""
# Setup
type_slug = f"metrics-type-{uuid4().hex[:8]}"
type_resp = await e2e_client.post(
"/api/v1/agent-types",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={
"name": "Metrics Type",
"slug": type_slug,
"personality_prompt": "Test agent.",
"primary_model": "claude-3-sonnet",
},
)
agent_type_id = type_resp.json()["id"]
project_slug = f"metrics-project-{uuid4().hex[:8]}"
project_resp = await e2e_client.post(
"/api/v1/projects",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={"name": "Metrics Project", "slug": project_slug},
)
project_id = project_resp.json()["id"]
spawn_resp = await e2e_client.post(
f"/api/v1/projects/{project_id}/agents",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={
"agent_type_id": agent_type_id,
"project_id": project_id,
"name": "Metrics Agent",
},
)
agent_id = spawn_resp.json()["id"]
# Get metrics
metrics_resp = await e2e_client.get(
f"/api/v1/projects/{project_id}/agents/{agent_id}/metrics",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
)
assert metrics_resp.status_code == 200
metrics = metrics_resp.json()
# Verify AgentInstanceMetrics structure
assert "total_instances" in metrics
assert "active_instances" in metrics
assert "idle_instances" in metrics
assert "total_tasks_completed" in metrics
assert "total_tokens_used" in metrics
assert "total_cost_incurred" in metrics


@@ -0,0 +1,460 @@
"""
MCP and Context Engine E2E Workflow Tests.
Tests complete workflows involving MCP servers and the Context Engine
against real PostgreSQL. These tests verify:
- MCP server listing and tool discovery
- Context engine operations
- Admin-only MCP operations with proper authentication
- Error handling for MCP operations
Usage:
make test-e2e # Run all E2E tests
"""
from uuid import uuid4
import pytest
pytestmark = [
pytest.mark.e2e,
pytest.mark.postgres,
pytest.mark.asyncio,
]
class TestMCPServerDiscovery:
"""Test MCP server listing and discovery workflows."""
async def test_list_mcp_servers(self, e2e_client):
"""Test listing MCP servers returns expected configuration."""
response = await e2e_client.get("/api/v1/mcp/servers")
assert response.status_code == 200, f"Failed: {response.text}"
data = response.json()
# Should have servers configured
assert "servers" in data
assert "total" in data
assert isinstance(data["servers"], list)
# Should have at least llm-gateway and knowledge-base
server_names = [s["name"] for s in data["servers"]]
assert "llm-gateway" in server_names
assert "knowledge-base" in server_names
async def test_list_all_mcp_tools(self, e2e_client):
"""Test listing all tools from all MCP servers."""
response = await e2e_client.get("/api/v1/mcp/tools")
assert response.status_code == 200, f"Failed: {response.text}"
data = response.json()
assert "tools" in data
assert "total" in data
assert isinstance(data["tools"], list)
async def test_mcp_health_check(self, e2e_client):
"""Test MCP health check returns server status."""
response = await e2e_client.get("/api/v1/mcp/health")
assert response.status_code == 200, f"Failed: {response.text}"
data = response.json()
assert "servers" in data
assert "healthy_count" in data
assert "unhealthy_count" in data
assert "total" in data
async def test_list_circuit_breakers(self, e2e_client):
"""Test listing circuit breaker status."""
response = await e2e_client.get("/api/v1/mcp/circuit-breakers")
assert response.status_code == 200, f"Failed: {response.text}"
data = response.json()
assert "circuit_breakers" in data
assert isinstance(data["circuit_breakers"], list)
class TestMCPServerTools:
"""Test MCP server tool listing."""
async def test_list_llm_gateway_tools(self, e2e_client):
"""Test listing tools from LLM Gateway server."""
response = await e2e_client.get("/api/v1/mcp/servers/llm-gateway/tools")
# May return 200 with tools, or 404/502 if the server is not connected
assert response.status_code in [200, 404, 502]
if response.status_code == 200:
data = response.json()
assert "tools" in data
assert "total" in data
async def test_list_knowledge_base_tools(self, e2e_client):
"""Test listing tools from Knowledge Base server."""
response = await e2e_client.get("/api/v1/mcp/servers/knowledge-base/tools")
# May return 200 with tools or 404/502 if server not connected
assert response.status_code in [200, 404, 502]
if response.status_code == 200:
data = response.json()
assert "tools" in data
assert "total" in data
async def test_invalid_server_returns_404(self, e2e_client):
"""Test that invalid server name returns 404."""
response = await e2e_client.get("/api/v1/mcp/servers/nonexistent-server/tools")
assert response.status_code == 404
class TestContextEngineWorkflows:
"""Test Context Engine operations."""
async def test_context_engine_health(self, e2e_client):
"""Test context engine health endpoint."""
response = await e2e_client.get("/api/v1/context/health")
assert response.status_code == 200, f"Failed: {response.text}"
data = response.json()
assert data["status"] == "healthy"
assert "mcp_connected" in data
assert "cache_enabled" in data
async def test_get_token_budget_claude_sonnet(self, e2e_client):
"""Test getting token budget for Claude 3 Sonnet."""
response = await e2e_client.get("/api/v1/context/budget/claude-3-sonnet")
assert response.status_code == 200, f"Failed: {response.text}"
data = response.json()
assert data["model"] == "claude-3-sonnet"
assert "total_tokens" in data
assert "system_tokens" in data
assert "knowledge_tokens" in data
assert "conversation_tokens" in data
assert "tool_tokens" in data
assert "response_reserve" in data
# Verify budget allocation makes sense
assert data["total_tokens"] > 0
total_allocated = (
data["system_tokens"]
+ data["knowledge_tokens"]
+ data["conversation_tokens"]
+ data["tool_tokens"]
+ data["response_reserve"]
)
assert total_allocated <= data["total_tokens"]
async def test_get_token_budget_with_custom_max(self, e2e_client):
"""Test getting token budget with custom max tokens."""
response = await e2e_client.get(
"/api/v1/context/budget/claude-3-sonnet",
params={"max_tokens": 50000},
)
assert response.status_code == 200, f"Failed: {response.text}"
data = response.json()
assert data["model"] == "claude-3-sonnet"
# Custom max should be respected or capped
assert data["total_tokens"] <= 50000
async def test_count_tokens(self, e2e_client):
"""Test token counting endpoint."""
response = await e2e_client.post(
"/api/v1/context/count-tokens",
json={
"content": "Hello, this is a test message for token counting.",
"model": "claude-3-sonnet",
},
)
assert response.status_code == 200, f"Failed: {response.text}"
data = response.json()
assert "token_count" in data
assert data["token_count"] > 0
assert data["model"] == "claude-3-sonnet"
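The budget tests assert one invariant: the per-component allocations never exceed the total. A dataclass sketch of that invariant, using the field names from the `TokenBudget` seen in the unit tests (the eager validation in `__post_init__` is an assumption about the implementation):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class TokenBudget:
    total: int
    system: int
    knowledge: int
    conversation: int
    tools: int
    response_reserve: int

    @property
    def allocated(self) -> int:
        # Sum of all component budgets; must fit within `total`.
        return (self.system + self.knowledge + self.conversation
                + self.tools + self.response_reserve)

    def __post_init__(self) -> None:
        if self.allocated > self.total:
            raise ValueError("allocated budget exceeds total tokens")

budget = TokenBudget(total=4000, system=500, knowledge=1500,
                     conversation=1000, tools=500, response_reserve=500)
print(budget.allocated)  # 4000
```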
class TestAdminMCPOperations:
"""Test admin-only MCP operations require authentication."""
async def test_tool_call_requires_auth(self, e2e_client):
"""Test that tool execution requires authentication."""
response = await e2e_client.post(
"/api/v1/mcp/call",
json={
"server": "llm-gateway",
"tool": "count_tokens",
"arguments": {"text": "test"},
},
)
# Should require authentication
assert response.status_code in [401, 403]
async def test_circuit_reset_requires_auth(self, e2e_client):
"""Test that circuit breaker reset requires authentication."""
response = await e2e_client.post(
"/api/v1/mcp/circuit-breakers/llm-gateway/reset"
)
assert response.status_code in [401, 403]
async def test_server_reconnect_requires_auth(self, e2e_client):
"""Test that server reconnect requires authentication."""
response = await e2e_client.post("/api/v1/mcp/servers/llm-gateway/reconnect")
assert response.status_code in [401, 403]
async def test_context_stats_requires_auth(self, e2e_client):
"""Test that context stats requires authentication."""
response = await e2e_client.get("/api/v1/context/stats")
assert response.status_code in [401, 403]
async def test_context_assemble_requires_auth(self, e2e_client):
"""Test that context assembly requires authentication."""
response = await e2e_client.post(
"/api/v1/context/assemble",
json={
"project_id": "test-project",
"agent_id": "test-agent",
"query": "test query",
"model": "claude-3-sonnet",
},
)
assert response.status_code in [401, 403]
async def test_cache_invalidate_requires_auth(self, e2e_client):
"""Test that cache invalidation requires authentication."""
response = await e2e_client.post("/api/v1/context/cache/invalidate")
assert response.status_code in [401, 403]
class TestAdminMCPWithAuthentication:
"""Test admin MCP operations with superuser authentication."""
async def test_superuser_can_get_context_stats(self, e2e_client, e2e_superuser):
"""Test that superuser can get context engine stats."""
response = await e2e_client.get(
"/api/v1/context/stats",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
)
assert response.status_code == 200, f"Failed: {response.text}"
data = response.json()
assert "cache" in data
assert "settings" in data
@pytest.mark.skip(
reason="Requires MCP servers (llm-gateway, knowledge-base) to be running"
)
async def test_superuser_can_assemble_context(self, e2e_client, e2e_superuser):
"""Test that superuser can assemble context."""
response = await e2e_client.post(
"/api/v1/context/assemble",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={
"project_id": f"test-project-{uuid4().hex[:8]}",
"agent_id": f"test-agent-{uuid4().hex[:8]}",
"query": "What is the status of the project?",
"model": "claude-3-sonnet",
"system_prompt": "You are a helpful assistant.",
"compress": True,
"use_cache": False,
},
)
assert response.status_code == 200, f"Failed: {response.text}"
data = response.json()
assert "content" in data
assert "total_tokens" in data
assert "context_count" in data
assert "budget_used_percent" in data
assert "metadata" in data
async def test_superuser_can_invalidate_cache(self, e2e_client, e2e_superuser):
"""Test that superuser can invalidate cache."""
response = await e2e_client.post(
"/api/v1/context/cache/invalidate",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
params={"project_id": "test-project"},
)
assert response.status_code == 204
async def test_regular_user_cannot_access_admin_operations(self, e2e_client):
"""Test that regular (non-superuser) users cannot access admin operations."""
email = f"regular-{uuid4().hex[:8]}@example.com"
password = "RegularUser123!"
# Register regular user
await e2e_client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": password,
"first_name": "Regular",
"last_name": "User",
},
)
# Login
login_resp = await e2e_client.post(
"/api/v1/auth/login",
json={"email": email, "password": password},
)
tokens = login_resp.json()
# Try to access admin endpoint
response = await e2e_client.get(
"/api/v1/context/stats",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
# Should be forbidden for non-superuser
assert response.status_code == 403
class TestMCPInputValidation:
"""Test input validation for MCP endpoints."""
async def test_server_name_max_length(self, e2e_client):
"""Test that server name has max length validation."""
long_name = "a" * 100 # Exceeds 64 char limit
response = await e2e_client.get(f"/api/v1/mcp/servers/{long_name}/tools")
assert response.status_code == 422
async def test_server_name_invalid_characters(self, e2e_client):
"""Test that server name rejects invalid characters."""
invalid_name = "server@name!invalid"
response = await e2e_client.get(f"/api/v1/mcp/servers/{invalid_name}/tools")
assert response.status_code == 422
async def test_token_count_empty_content(self, e2e_client):
"""Test token counting with empty content."""
response = await e2e_client.post(
"/api/v1/context/count-tokens",
json={"content": ""},
)
# Empty content may be accepted (returning 0 tokens) or rejected as invalid
if response.status_code == 200:
data = response.json()
assert data["token_count"] == 0
else:
# Or it might be rejected as invalid
assert response.status_code == 422
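The two validation tests pin down the server-name rules: at most 64 characters and a restricted character set. One plausible pattern that satisfies both tests (the exact regex is an assumption, not the project's actual validator):

```python
import re

# Assumed rule: lowercase alphanumerics plus "-" and "_", 1-64 characters.
SERVER_NAME = re.compile(r"^[a-z0-9][a-z0-9_-]{0,63}$")

def is_valid_server_name(name: str) -> bool:
    return SERVER_NAME.fullmatch(name) is not None

print(is_valid_server_name("llm-gateway"))          # True
print(is_valid_server_name("a" * 100))              # False: over 64 chars
print(is_valid_server_name("server@name!invalid"))  # False: bad characters
```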
class TestMCPWorkflowIntegration:
"""Test complete MCP workflows end-to-end."""
async def test_discovery_to_budget_workflow(self, e2e_client):
"""Test complete workflow: discover servers -> check budget -> ready for use."""
# 1. Discover available servers
servers_resp = await e2e_client.get("/api/v1/mcp/servers")
assert servers_resp.status_code == 200
servers = servers_resp.json()["servers"]
assert len(servers) > 0
# 2. Check context engine health
health_resp = await e2e_client.get("/api/v1/context/health")
assert health_resp.status_code == 200
health = health_resp.json()
assert health["status"] == "healthy"
# 3. Get token budget for a model
budget_resp = await e2e_client.get("/api/v1/context/budget/claude-3-sonnet")
assert budget_resp.status_code == 200
budget = budget_resp.json()
# 4. Verify system is ready for context assembly
assert budget["total_tokens"] > 0
assert health["mcp_connected"] is True
@pytest.mark.skip(
reason="Requires MCP servers (llm-gateway, knowledge-base) to be running"
)
async def test_full_context_assembly_workflow(self, e2e_client, e2e_superuser):
"""Test complete context assembly workflow with superuser."""
project_id = f"e2e-project-{uuid4().hex[:8]}"
agent_id = f"e2e-agent-{uuid4().hex[:8]}"
# 1. Check budget before assembly
budget_resp = await e2e_client.get("/api/v1/context/budget/claude-3-sonnet")
assert budget_resp.status_code == 200
_ = budget_resp.json() # Verify valid response
# 2. Count tokens in sample content
count_resp = await e2e_client.post(
"/api/v1/context/count-tokens",
json={"content": "This is a test message for context assembly."},
)
assert count_resp.status_code == 200
token_count = count_resp.json()["token_count"]
assert token_count > 0
# 3. Assemble context
assemble_resp = await e2e_client.post(
"/api/v1/context/assemble",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={
"project_id": project_id,
"agent_id": agent_id,
"query": "Summarize the current project status",
"model": "claude-3-sonnet",
"system_prompt": "You are a project management assistant.",
"task_description": "Generate a status report",
"conversation_history": [
{"role": "user", "content": "What's the project status?"},
{
"role": "assistant",
"content": "Let me check the current status.",
},
],
"compress": True,
"use_cache": False,
},
)
assert assemble_resp.status_code == 200
assembled = assemble_resp.json()
# 4. Verify assembly results
assert assembled["total_tokens"] > 0
assert assembled["context_count"] > 0
assert assembled["budget_used_percent"] > 0
assert assembled["budget_used_percent"] <= 100
# 5. Get stats to verify the operation was recorded
stats_resp = await e2e_client.get(
"/api/v1/context/stats",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
)
assert stats_resp.status_code == 200


@@ -0,0 +1,684 @@
"""
Project and Agent E2E Workflow Tests.
Tests complete project management workflows with real PostgreSQL:
- Project CRUD and lifecycle management
- Agent spawning and lifecycle
- Issue management within projects
- Sprint planning and execution
Usage:
make test-e2e # Run all E2E tests
"""
from datetime import date, timedelta
from uuid import uuid4
import pytest
pytestmark = [
pytest.mark.e2e,
pytest.mark.postgres,
pytest.mark.asyncio,
]
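The tests below repeatedly build unique emails and slugs inline with `uuid4`. A pair of small helpers (hypothetical, not part of the suite) would capture the pattern and cut the duplication:

```python
from uuid import uuid4


def unique_email(prefix: str) -> str:
    """Collision-free email for test registration, e.g. 'owner-1a2b3c4d@example.com'."""
    return f"{prefix}-{uuid4().hex[:8]}@example.com"


def unique_slug(prefix: str) -> str:
    """Collision-free project slug, e.g. 'test-project-1a2b3c4d'."""
    return f"{prefix}-{uuid4().hex[:8]}"
```

Unique identifiers per test keep E2E runs against a shared PostgreSQL instance independent of each other and re-runnable without cleanup.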
class TestProjectCRUDWorkflows:
"""Test complete project CRUD workflows."""
async def test_create_project_workflow(self, e2e_client):
"""Test creating a project as authenticated user."""
# Register and login
email = f"project-owner-{uuid4().hex[:8]}@example.com"
password = "SecurePass123!"
await e2e_client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": password,
"first_name": "Project",
"last_name": "Owner",
},
)
login_resp = await e2e_client.post(
"/api/v1/auth/login",
json={"email": email, "password": password},
)
tokens = login_resp.json()
# Create project
project_slug = f"test-project-{uuid4().hex[:8]}"
create_resp = await e2e_client.post(
"/api/v1/projects",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={
"name": "E2E Test Project",
"slug": project_slug,
"description": "A project for E2E testing",
"autonomy_level": "milestone",
},
)
assert create_resp.status_code == 201, f"Failed: {create_resp.text}"
project = create_resp.json()
assert project["name"] == "E2E Test Project"
assert project["slug"] == project_slug
assert project["status"] == "active"
assert project["agent_count"] == 0
assert project["issue_count"] == 0
async def test_list_projects_only_shows_owned(self, e2e_client):
"""Test that users only see their own projects."""
# Create two users with projects
users = []
for i in range(2):
email = f"user-{i}-{uuid4().hex[:8]}@example.com"
password = "SecurePass123!"
await e2e_client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": password,
"first_name": f"User{i}",
"last_name": "Test",
},
)
login_resp = await e2e_client.post(
"/api/v1/auth/login",
json={"email": email, "password": password},
)
tokens = login_resp.json()
# Each user creates their own project
project_slug = f"user{i}-project-{uuid4().hex[:8]}"
await e2e_client.post(
"/api/v1/projects",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={
"name": f"User {i} Project",
"slug": project_slug,
},
)
users.append({"email": email, "tokens": tokens, "slug": project_slug})
# User 0 should only see their project
list_resp = await e2e_client.get(
"/api/v1/projects",
headers={"Authorization": f"Bearer {users[0]['tokens']['access_token']}"},
)
assert list_resp.status_code == 200
data = list_resp.json()
slugs = [p["slug"] for p in data["data"]]
assert users[0]["slug"] in slugs
assert users[1]["slug"] not in slugs
async def test_project_lifecycle_pause_resume(self, e2e_client):
"""Test pausing and resuming a project."""
# Setup user and project
email = f"lifecycle-{uuid4().hex[:8]}@example.com"
password = "SecurePass123!"
await e2e_client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": password,
"first_name": "Lifecycle",
"last_name": "Test",
},
)
login_resp = await e2e_client.post(
"/api/v1/auth/login",
json={"email": email, "password": password},
)
tokens = login_resp.json()
project_slug = f"lifecycle-project-{uuid4().hex[:8]}"
create_resp = await e2e_client.post(
"/api/v1/projects",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={"name": "Lifecycle Project", "slug": project_slug},
)
project = create_resp.json()
project_id = project["id"]
# Pause the project
pause_resp = await e2e_client.post(
f"/api/v1/projects/{project_id}/pause",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
assert pause_resp.status_code == 200
assert pause_resp.json()["status"] == "paused"
# Resume the project
resume_resp = await e2e_client.post(
f"/api/v1/projects/{project_id}/resume",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
assert resume_resp.status_code == 200
assert resume_resp.json()["status"] == "active"
async def test_project_archive(self, e2e_client):
"""Test archiving a project (soft delete)."""
# Setup user and project
email = f"archive-{uuid4().hex[:8]}@example.com"
password = "SecurePass123!"
await e2e_client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": password,
"first_name": "Archive",
"last_name": "Test",
},
)
login_resp = await e2e_client.post(
"/api/v1/auth/login",
json={"email": email, "password": password},
)
tokens = login_resp.json()
project_slug = f"archive-project-{uuid4().hex[:8]}"
create_resp = await e2e_client.post(
"/api/v1/projects",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={"name": "Archive Project", "slug": project_slug},
)
project = create_resp.json()
project_id = project["id"]
# Archive the project
archive_resp = await e2e_client.delete(
f"/api/v1/projects/{project_id}",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
assert archive_resp.status_code == 200
assert archive_resp.json()["success"] is True
# Verify project is archived
get_resp = await e2e_client.get(
f"/api/v1/projects/{project_id}",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
assert get_resp.status_code == 200
assert get_resp.json()["status"] == "archived"
class TestIssueWorkflows:
"""Test issue management workflows within projects."""
async def test_create_and_list_issues(self, e2e_client):
"""Test creating and listing issues in a project."""
# Setup user and project
email = f"issue-test-{uuid4().hex[:8]}@example.com"
password = "SecurePass123!"
await e2e_client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": password,
"first_name": "Issue",
"last_name": "Tester",
},
)
login_resp = await e2e_client.post(
"/api/v1/auth/login",
json={"email": email, "password": password},
)
tokens = login_resp.json()
project_slug = f"issue-project-{uuid4().hex[:8]}"
create_resp = await e2e_client.post(
"/api/v1/projects",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={"name": "Issue Test Project", "slug": project_slug},
)
project = create_resp.json()
project_id = project["id"]
# Create multiple issues
issues = []
for i in range(3):
issue_resp = await e2e_client.post(
f"/api/v1/projects/{project_id}/issues",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={
"project_id": project_id,
"title": f"Test Issue {i + 1}",
"body": f"Description for issue {i + 1}",
"priority": ["low", "medium", "high"][i],
},
)
assert issue_resp.status_code == 201, f"Failed: {issue_resp.text}"
issues.append(issue_resp.json())
# List issues
list_resp = await e2e_client.get(
f"/api/v1/projects/{project_id}/issues",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
assert list_resp.status_code == 200
data = list_resp.json()
assert data["pagination"]["total"] == 3
async def test_issue_status_transitions(self, e2e_client):
"""Test issue status workflow transitions."""
# Setup user and project
email = f"status-test-{uuid4().hex[:8]}@example.com"
password = "SecurePass123!"
await e2e_client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": password,
"first_name": "Status",
"last_name": "Tester",
},
)
login_resp = await e2e_client.post(
"/api/v1/auth/login",
json={"email": email, "password": password},
)
tokens = login_resp.json()
project_slug = f"status-project-{uuid4().hex[:8]}"
create_resp = await e2e_client.post(
"/api/v1/projects",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={"name": "Status Test Project", "slug": project_slug},
)
project = create_resp.json()
project_id = project["id"]
# Create issue
issue_resp = await e2e_client.post(
f"/api/v1/projects/{project_id}/issues",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={
"project_id": project_id,
"title": "Status Workflow Issue",
"body": "Testing status transitions",
},
)
issue = issue_resp.json()
issue_id = issue["id"]
assert issue["status"] == "open"
# Transition through statuses
for new_status in ["in_progress", "in_review", "closed"]:
update_resp = await e2e_client.patch(
f"/api/v1/projects/{project_id}/issues/{issue_id}",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={"status": new_status},
)
assert update_resp.status_code == 200, f"Failed: {update_resp.text}"
assert update_resp.json()["status"] == new_status
async def test_issue_filtering(self, e2e_client):
"""Test issue filtering by status and priority."""
# Setup user and project
email = f"filter-test-{uuid4().hex[:8]}@example.com"
password = "SecurePass123!"
await e2e_client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": password,
"first_name": "Filter",
"last_name": "Tester",
},
)
login_resp = await e2e_client.post(
"/api/v1/auth/login",
json={"email": email, "password": password},
)
tokens = login_resp.json()
project_slug = f"filter-project-{uuid4().hex[:8]}"
create_resp = await e2e_client.post(
"/api/v1/projects",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={"name": "Filter Test Project", "slug": project_slug},
)
project = create_resp.json()
project_id = project["id"]
# Create issues with different priorities
for priority in ["low", "medium", "high"]:
await e2e_client.post(
f"/api/v1/projects/{project_id}/issues",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={
"project_id": project_id,
"title": f"{priority.title()} Priority Issue",
"priority": priority,
},
)
# Filter by high priority
filter_resp = await e2e_client.get(
f"/api/v1/projects/{project_id}/issues",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
params={"priority": "high"},
)
assert filter_resp.status_code == 200
data = filter_resp.json()
assert data["pagination"]["total"] == 1
assert data["data"][0]["priority"] == "high"
class TestSprintWorkflows:
"""Test sprint planning and execution workflows."""
async def test_sprint_lifecycle(self, e2e_client):
"""Test complete sprint lifecycle: plan -> start -> complete."""
# Setup user and project
email = f"sprint-test-{uuid4().hex[:8]}@example.com"
password = "SecurePass123!"
await e2e_client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": password,
"first_name": "Sprint",
"last_name": "Tester",
},
)
login_resp = await e2e_client.post(
"/api/v1/auth/login",
json={"email": email, "password": password},
)
tokens = login_resp.json()
project_slug = f"sprint-project-{uuid4().hex[:8]}"
create_resp = await e2e_client.post(
"/api/v1/projects",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={"name": "Sprint Test Project", "slug": project_slug},
)
project = create_resp.json()
project_id = project["id"]
# Create sprint
today = date.today()
sprint_resp = await e2e_client.post(
f"/api/v1/projects/{project_id}/sprints",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={
"project_id": project_id,
"name": "Sprint 1",
"number": 1,
"goal": "Complete initial features",
"start_date": today.isoformat(),
"end_date": (today + timedelta(days=14)).isoformat(),
},
)
assert sprint_resp.status_code == 201, f"Failed: {sprint_resp.text}"
sprint = sprint_resp.json()
sprint_id = sprint["id"]
assert sprint["status"] == "planned"
# Start sprint
start_resp = await e2e_client.post(
f"/api/v1/projects/{project_id}/sprints/{sprint_id}/start",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
assert start_resp.status_code == 200, f"Failed: {start_resp.text}"
assert start_resp.json()["status"] == "active"
# Complete sprint
complete_resp = await e2e_client.post(
f"/api/v1/projects/{project_id}/sprints/{sprint_id}/complete",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
assert complete_resp.status_code == 200, f"Failed: {complete_resp.text}"
assert complete_resp.json()["status"] == "completed"
async def test_add_issues_to_sprint(self, e2e_client):
"""Test adding issues to a sprint."""
# Setup user and project
email = f"sprint-issues-{uuid4().hex[:8]}@example.com"
password = "SecurePass123!"
await e2e_client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": password,
"first_name": "SprintIssues",
"last_name": "Tester",
},
)
login_resp = await e2e_client.post(
"/api/v1/auth/login",
json={"email": email, "password": password},
)
tokens = login_resp.json()
project_slug = f"sprint-issues-project-{uuid4().hex[:8]}"
create_resp = await e2e_client.post(
"/api/v1/projects",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={"name": "Sprint Issues Project", "slug": project_slug},
)
project = create_resp.json()
project_id = project["id"]
# Create sprint
today = date.today()
sprint_resp = await e2e_client.post(
f"/api/v1/projects/{project_id}/sprints",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={
"project_id": project_id,
"name": "Sprint 1",
"number": 1,
"start_date": today.isoformat(),
"end_date": (today + timedelta(days=14)).isoformat(),
},
)
assert sprint_resp.status_code == 201, f"Failed: {sprint_resp.text}"
sprint = sprint_resp.json()
sprint_id = sprint["id"]
# Create issue
issue_resp = await e2e_client.post(
f"/api/v1/projects/{project_id}/issues",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={
"project_id": project_id,
"title": "Sprint Issue",
"story_points": 5,
},
)
issue = issue_resp.json()
issue_id = issue["id"]
# Add issue to sprint
add_resp = await e2e_client.post(
f"/api/v1/projects/{project_id}/sprints/{sprint_id}/issues",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
params={"issue_id": issue_id},
)
assert add_resp.status_code == 200, f"Failed: {add_resp.text}"
# Verify issue is in sprint
issue_check = await e2e_client.get(
f"/api/v1/projects/{project_id}/issues/{issue_id}",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
assert issue_check.json()["sprint_id"] == sprint_id
class TestCrossEntityValidation:
"""Test validation across related entities."""
async def test_cannot_access_other_users_project(self, e2e_client):
"""Test that users cannot access projects they don't own."""
# Create two users
owner_email = f"owner-{uuid4().hex[:8]}@example.com"
other_email = f"other-{uuid4().hex[:8]}@example.com"
password = "SecurePass123!"
# Register owner
await e2e_client.post(
"/api/v1/auth/register",
json={
"email": owner_email,
"password": password,
"first_name": "Owner",
"last_name": "User",
},
)
owner_tokens = (
await e2e_client.post(
"/api/v1/auth/login",
json={"email": owner_email, "password": password},
)
).json()
# Register other user
await e2e_client.post(
"/api/v1/auth/register",
json={
"email": other_email,
"password": password,
"first_name": "Other",
"last_name": "User",
},
)
other_tokens = (
await e2e_client.post(
"/api/v1/auth/login",
json={"email": other_email, "password": password},
)
).json()
# Owner creates project
project_slug = f"private-project-{uuid4().hex[:8]}"
create_resp = await e2e_client.post(
"/api/v1/projects",
headers={"Authorization": f"Bearer {owner_tokens['access_token']}"},
json={"name": "Private Project", "slug": project_slug},
)
project = create_resp.json()
project_id = project["id"]
# Other user tries to access
access_resp = await e2e_client.get(
f"/api/v1/projects/{project_id}",
headers={"Authorization": f"Bearer {other_tokens['access_token']}"},
)
assert access_resp.status_code == 403
async def test_duplicate_project_slug_rejected(self, e2e_client):
"""Test that duplicate project slugs are rejected."""
email = f"dup-test-{uuid4().hex[:8]}@example.com"
password = "SecurePass123!"
await e2e_client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": password,
"first_name": "Dup",
"last_name": "Tester",
},
)
tokens = (
await e2e_client.post(
"/api/v1/auth/login",
json={"email": email, "password": password},
)
).json()
slug = f"unique-slug-{uuid4().hex[:8]}"
# First creation should succeed
resp1 = await e2e_client.post(
"/api/v1/projects",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={"name": "First Project", "slug": slug},
)
assert resp1.status_code == 201
# Second creation with same slug should fail
resp2 = await e2e_client.post(
"/api/v1/projects",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={"name": "Second Project", "slug": slug},
)
assert resp2.status_code == 409 # Conflict
class TestIssueStats:
"""Test issue statistics endpoints."""
async def test_issue_stats_aggregation(self, e2e_client):
"""Test that issue stats are correctly aggregated."""
email = f"stats-test-{uuid4().hex[:8]}@example.com"
password = "SecurePass123!"
await e2e_client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": password,
"first_name": "Stats",
"last_name": "Tester",
},
)
tokens = (
await e2e_client.post(
"/api/v1/auth/login",
json={"email": email, "password": password},
)
).json()
project_slug = f"stats-project-{uuid4().hex[:8]}"
create_resp = await e2e_client.post(
"/api/v1/projects",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={"name": "Stats Project", "slug": project_slug},
)
project = create_resp.json()
project_id = project["id"]
# Create issues with different priorities and story points
await e2e_client.post(
f"/api/v1/projects/{project_id}/issues",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={
"project_id": project_id,
"title": "High Priority",
"priority": "high",
"story_points": 8,
},
)
await e2e_client.post(
f"/api/v1/projects/{project_id}/issues",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={
"project_id": project_id,
"title": "Low Priority",
"priority": "low",
"story_points": 2,
},
)
# Get stats
stats_resp = await e2e_client.get(
f"/api/v1/projects/{project_id}/issues/stats",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
assert stats_resp.status_code == 200
stats = stats_resp.json()
assert stats["total"] == 2
assert stats["total_story_points"] == 10


@@ -0,0 +1 @@
"""Integration tests that require the full stack to be running."""


@@ -0,0 +1,322 @@
"""
Integration tests for MCP server connectivity.
These tests require the full stack to be running:
- docker compose -f docker-compose.dev.yml up
Run with:
RUN_INTEGRATION_TESTS=true pytest tests/integration/ -v
Or skip with:
pytest tests/ -v --ignore=tests/integration/
"""
import os
from typing import Any
import httpx
import pytest
# Skip all tests in this module if not running integration tests
pytestmark = pytest.mark.skipif(
os.getenv("RUN_INTEGRATION_TESTS", "false").lower() != "true",
reason="Integration tests require RUN_INTEGRATION_TESTS=true and running stack",
)
# Configuration from environment
BACKEND_URL = os.getenv("BACKEND_URL", "http://localhost:8000")
LLM_GATEWAY_URL = os.getenv("LLM_GATEWAY_URL", "http://localhost:8001")
KNOWLEDGE_BASE_URL = os.getenv("KNOWLEDGE_BASE_URL", "http://localhost:8002")
class TestMCPServerHealth:
"""Test that MCP servers are healthy and reachable."""
@pytest.mark.asyncio
async def test_llm_gateway_health(self) -> None:
"""Test LLM Gateway health endpoint."""
async with httpx.AsyncClient() as client:
response = await client.get(f"{LLM_GATEWAY_URL}/health", timeout=10.0)
assert response.status_code == 200
data = response.json()
assert data.get("status") == "healthy" or data.get("healthy") is True
@pytest.mark.asyncio
async def test_knowledge_base_health(self) -> None:
"""Test Knowledge Base health endpoint."""
async with httpx.AsyncClient() as client:
response = await client.get(f"{KNOWLEDGE_BASE_URL}/health", timeout=10.0)
assert response.status_code == 200
data = response.json()
assert data.get("status") == "healthy" or data.get("healthy") is True
@pytest.mark.asyncio
async def test_backend_health(self) -> None:
"""Test Backend health endpoint."""
async with httpx.AsyncClient() as client:
response = await client.get(f"{BACKEND_URL}/health", timeout=10.0)
assert response.status_code == 200
class TestMCPClientManagerIntegration:
"""Test MCPClientManager can connect to real MCP servers."""
@pytest.mark.asyncio
async def test_mcp_servers_list(self) -> None:
"""Test that backend can list MCP servers via API."""
async with httpx.AsyncClient() as client:
# This endpoint lists configured MCP servers
response = await client.get(
f"{BACKEND_URL}/api/v1/mcp/servers",
timeout=10.0,
)
# Should return 200 or 401 (if auth required)
assert response.status_code in [200, 401, 403]
@pytest.mark.asyncio
async def test_mcp_health_check_endpoint(self) -> None:
"""Test backend's MCP health check endpoint."""
async with httpx.AsyncClient() as client:
response = await client.get(
f"{BACKEND_URL}/api/v1/mcp/health",
timeout=30.0, # MCP health checks can take time
)
# Should return 200 or 401 (if auth required)
if response.status_code == 200:
data = response.json()
# Check structure
assert "servers" in data or "healthy" in data
class TestLLMGatewayIntegration:
"""Test LLM Gateway MCP server functionality."""
@pytest.mark.asyncio
async def test_list_models(self) -> None:
"""Test that LLM Gateway can list available models."""
async with httpx.AsyncClient() as client:
# MCP servers use JSON-RPC 2.0 protocol at /mcp endpoint
response = await client.post(
f"{LLM_GATEWAY_URL}/mcp",
json={
"jsonrpc": "2.0",
"id": 1,
"method": "tools/list",
"params": {},
},
timeout=10.0,
)
assert response.status_code == 200
data = response.json()
# Should have tools listed
assert "result" in data or "error" in data
@pytest.mark.asyncio
async def test_count_tokens(self) -> None:
"""Test token counting functionality."""
async with httpx.AsyncClient() as client:
response = await client.post(
f"{LLM_GATEWAY_URL}/mcp",
json={
"jsonrpc": "2.0",
"id": 1,
"method": "tools/call",
"params": {
"name": "count_tokens",
"arguments": {
"project_id": "test-project",
"agent_id": "test-agent",
"text": "Hello, world!",
},
},
},
timeout=10.0,
)
assert response.status_code == 200
data = response.json()
# Check for result or error
if "result" in data:
assert "content" in data["result"] or "token_count" in str(
data["result"]
)
class TestKnowledgeBaseIntegration:
"""Test Knowledge Base MCP server functionality."""
@pytest.mark.asyncio
async def test_list_tools(self) -> None:
"""Test that Knowledge Base can list available tools."""
async with httpx.AsyncClient() as client:
# Knowledge Base uses GET /mcp/tools for listing
response = await client.get(
f"{KNOWLEDGE_BASE_URL}/mcp/tools",
timeout=10.0,
)
assert response.status_code == 200
data = response.json()
assert "tools" in data
@pytest.mark.asyncio
async def test_search_knowledge_empty(self) -> None:
"""Test search on empty knowledge base."""
async with httpx.AsyncClient() as client:
# Knowledge Base uses direct tool name as method
response = await client.post(
f"{KNOWLEDGE_BASE_URL}/mcp",
json={
"jsonrpc": "2.0",
"id": 1,
"method": "search_knowledge",
"params": {
"project_id": "test-project",
"agent_id": "test-agent",
"query": "test query",
"limit": 5,
},
},
timeout=10.0,
)
assert response.status_code == 200
data = response.json()
# Should return empty results or error for no collection
assert "result" in data or "error" in data
class TestEndToEndMCPFlow:
"""End-to-end tests for MCP integration flow."""
@pytest.mark.asyncio
async def test_full_mcp_discovery_flow(self) -> None:
"""Test the full flow of discovering and listing MCP tools."""
async with httpx.AsyncClient() as client:
# 1. Check backend health
health = await client.get(f"{BACKEND_URL}/health", timeout=10.0)
assert health.status_code == 200
# 2. Check LLM Gateway health
llm_health = await client.get(f"{LLM_GATEWAY_URL}/health", timeout=10.0)
assert llm_health.status_code == 200
# 3. Check Knowledge Base health
kb_health = await client.get(f"{KNOWLEDGE_BASE_URL}/health", timeout=10.0)
assert kb_health.status_code == 200
# 4. List tools from LLM Gateway (uses JSON-RPC at /mcp)
llm_tools = await client.post(
f"{LLM_GATEWAY_URL}/mcp",
json={"jsonrpc": "2.0", "id": 1, "method": "tools/list", "params": {}},
timeout=10.0,
)
assert llm_tools.status_code == 200
# 5. List tools from Knowledge Base (uses GET /mcp/tools)
kb_tools = await client.get(
f"{KNOWLEDGE_BASE_URL}/mcp/tools",
timeout=10.0,
)
assert kb_tools.status_code == 200
class TestContextEngineIntegration:
"""Test Context Engine integration with MCP servers."""
@pytest.mark.asyncio
async def test_context_health_endpoint(self) -> None:
"""Test context engine health endpoint."""
async with httpx.AsyncClient() as client:
response = await client.get(
f"{BACKEND_URL}/api/v1/context/health",
timeout=10.0,
)
assert response.status_code == 200
data = response.json()
assert data.get("status") == "healthy"
assert "mcp_connected" in data
assert "cache_enabled" in data
@pytest.mark.asyncio
async def test_context_budget_endpoint(self) -> None:
"""Test token budget endpoint."""
async with httpx.AsyncClient() as client:
response = await client.get(
f"{BACKEND_URL}/api/v1/context/budget/claude-3-sonnet",
timeout=10.0,
)
assert response.status_code == 200
data = response.json()
assert "total_tokens" in data
assert "system_tokens" in data
assert data.get("model") == "claude-3-sonnet"
@pytest.mark.asyncio
async def test_context_assembly_requires_auth(self) -> None:
"""Test that context assembly requires authentication."""
async with httpx.AsyncClient() as client:
response = await client.post(
f"{BACKEND_URL}/api/v1/context/assemble",
json={
"project_id": "test-project",
"agent_id": "test-agent",
"query": "test query",
"model": "claude-3-sonnet",
},
timeout=10.0,
)
# Should require auth
assert response.status_code in [401, 403]
def run_quick_health_check() -> dict[str, Any]:
"""
Quick synchronous health check for all services.
Can be run standalone to verify the stack is up.
"""
results: dict[str, Any] = {
"backend": False,
"llm_gateway": False,
"knowledge_base": False,
}
try:
with httpx.Client(timeout=5.0) as client:
try:
r = client.get(f"{BACKEND_URL}/health")
results["backend"] = r.status_code == 200
except Exception:
pass
try:
r = client.get(f"{LLM_GATEWAY_URL}/health")
results["llm_gateway"] = r.status_code == 200
except Exception:
pass
try:
r = client.get(f"{KNOWLEDGE_BASE_URL}/health")
results["knowledge_base"] = r.status_code == 200
except Exception:
pass
except Exception:
pass
return results
if __name__ == "__main__":
print("Checking service health...")
results = run_quick_health_check()
for service, healthy in results.items():
status = "OK" if healthy else "FAILED"
print(f" {service}: {status}")
all_healthy = all(results.values())
if all_healthy:
print("\nAll services healthy! Run integration tests with:")
print(" RUN_INTEGRATION_TESTS=true pytest tests/integration/ -v")
else:
print("\nSome services are not healthy. Start the stack with:")
print(" make dev")


@@ -72,7 +72,7 @@ class TestContextSettings:
"""Test performance settings."""
settings = ContextSettings()
assert settings.max_assembly_time_ms == 2000
assert settings.parallel_scoring is True
assert settings.max_parallel_scores == 10


@@ -758,3 +758,136 @@ class TestBaseScorer:
# Boundaries
assert scorer.normalize_score(0.0) == 0.0
assert scorer.normalize_score(1.0) == 1.0
class TestCompositeScorerEdgeCases:
"""Tests for CompositeScorer edge cases and lock management."""
@pytest.mark.asyncio
async def test_score_with_zero_weights(self) -> None:
"""Test scoring when all weights are zero."""
scorer = CompositeScorer(
relevance_weight=0.0,
recency_weight=0.0,
priority_weight=0.0,
)
context = KnowledgeContext(
content="Test content",
source="docs",
relevance_score=0.8,
)
# Should return 0.0 when total weight is 0
score = await scorer.score(context, "test query")
assert score == 0.0
@pytest.mark.asyncio
async def test_score_batch_sequential(self) -> None:
"""Test batch scoring in sequential mode (parallel=False)."""
scorer = CompositeScorer()
contexts = [
KnowledgeContext(
content="Content 1",
source="docs",
relevance_score=0.8,
),
KnowledgeContext(
content="Content 2",
source="docs",
relevance_score=0.5,
),
]
# Use parallel=False to cover the sequential path
scored = await scorer.score_batch(contexts, "query", parallel=False)
assert len(scored) == 2
assert scored[0].relevance_score == 0.8
assert scored[1].relevance_score == 0.5
@pytest.mark.asyncio
async def test_lock_fast_path_reuse(self) -> None:
"""Test that existing locks are reused via fast path."""
scorer = CompositeScorer()
context = KnowledgeContext(
content="Test",
source="docs",
relevance_score=0.5,
)
# First access creates the lock
lock1 = await scorer._get_context_lock(context.id)
# Second access should hit the fast path (lock exists in dict)
lock2 = await scorer._get_context_lock(context.id)
assert lock2 is lock1 # Same lock object returned
@pytest.mark.asyncio
async def test_lock_cleanup_when_limit_reached(self) -> None:
"""Test that old locks are cleaned up when limit is reached."""
import time
# Create scorer with very low max_locks to trigger cleanup
scorer = CompositeScorer()
scorer._max_locks = 3
scorer._lock_ttl = 0.1 # 100ms TTL
# Create locks for several context IDs
context_ids = [f"ctx-{i}" for i in range(5)]
# Get locks for first 3 contexts (fill up to limit)
for ctx_id in context_ids[:3]:
await scorer._get_context_lock(ctx_id)
# Wait for TTL to expire
time.sleep(0.15)
# Getting a lock for a new context should trigger cleanup
await scorer._get_context_lock(context_ids[3])
# Some old locks should have been cleaned up
# The exact number depends on cleanup logic
assert len(scorer._context_locks) <= scorer._max_locks + 1
@pytest.mark.asyncio
async def test_lock_cleanup_preserves_held_locks(self) -> None:
"""Test that cleanup doesn't remove locks that are currently held."""
import time
scorer = CompositeScorer()
scorer._max_locks = 2
scorer._lock_ttl = 0.05 # 50ms TTL
# Get and hold lock1
lock1 = await scorer._get_context_lock("ctx-1")
async with lock1:
# While holding lock1, add more locks
await scorer._get_context_lock("ctx-2")
time.sleep(0.1) # Let TTL expire
# Adding another should trigger cleanup
await scorer._get_context_lock("ctx-3")
# lock1 should still exist (it's held)
assert any(lock is lock1 for lock, _ in scorer._context_locks.values())
@pytest.mark.asyncio
async def test_concurrent_lock_acquisition_double_check(self) -> None:
"""Test that concurrent lock acquisition uses double-check pattern."""
import asyncio
scorer = CompositeScorer()
context_id = "test-context-id"
# Simulate concurrent lock acquisition
async def get_lock():
return await scorer._get_context_lock(context_id)
locks = await asyncio.gather(*[get_lock() for _ in range(10)])
# All should get the same lock (double-check pattern ensures this)
assert all(lock is locks[0] for lock in locks)
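The lock tests above exercise a fast path, TTL-based cleanup that spares held locks, and a double-check under a registry lock. A minimal sketch of that pattern, assuming the production `CompositeScorer` stores `(lock, created_at)` tuples as the `_context_locks.values()` assertion suggests (the internals here are illustrative, not the real implementation):

```python
import asyncio
import time


class LockRegistry:
    """Per-context asyncio locks with bounded size and TTL eviction."""

    def __init__(self, max_locks: int = 1000, lock_ttl: float = 300.0) -> None:
        self._context_locks: dict[str, tuple[asyncio.Lock, float]] = {}
        self._registry_lock = asyncio.Lock()
        self._max_locks = max_locks
        self._lock_ttl = lock_ttl

    async def _get_context_lock(self, context_id: str) -> asyncio.Lock:
        # Fast path: the lock already exists, no registry lock needed.
        entry = self._context_locks.get(context_id)
        if entry is not None:
            return entry[0]
        async with self._registry_lock:
            # Double-check: another coroutine may have created it while
            # we were waiting on the registry lock.
            entry = self._context_locks.get(context_id)
            if entry is not None:
                return entry[0]
            if len(self._context_locks) >= self._max_locks:
                now = time.monotonic()
                # Evict only expired locks that no coroutine currently holds.
                for cid, (lock, created) in list(self._context_locks.items()):
                    if now - created > self._lock_ttl and not lock.locked():
                        del self._context_locks[cid]
            lock = asyncio.Lock()
            self._context_locks[context_id] = (lock, time.monotonic())
            return lock
```

The double-check after acquiring `_registry_lock` is what guarantees ten concurrent callers all receive the same lock object, and the `lock.locked()` guard is what keeps cleanup from dropping a lock someone is inside.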


@@ -0,0 +1,989 @@
"""
Tests for Audit Logger.
Tests cover:
- AuditLogger initialization and lifecycle
- Event logging and hash chain
- Query and filtering
- Retention policy enforcement
- Handler management
- Singleton pattern
"""
import asyncio
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch
import pytest
import pytest_asyncio
from app.services.safety.audit.logger import (
AuditLogger,
get_audit_logger,
reset_audit_logger,
shutdown_audit_logger,
)
from app.services.safety.models import (
ActionMetadata,
ActionRequest,
ActionType,
AuditEventType,
AutonomyLevel,
SafetyDecision,
)
class TestAuditLoggerInit:
"""Tests for AuditLogger initialization."""
def test_init_default_values(self):
"""Test initialization with default values."""
logger = AuditLogger()
assert logger._flush_interval == 10.0
assert logger._enable_hash_chain is True
assert logger._last_hash is None
assert logger._running is False
def test_init_custom_values(self):
"""Test initialization with custom values."""
logger = AuditLogger(
max_buffer_size=500,
flush_interval_seconds=5.0,
enable_hash_chain=False,
)
assert logger._flush_interval == 5.0
assert logger._enable_hash_chain is False
class TestAuditLoggerLifecycle:
"""Tests for AuditLogger start/stop."""
@pytest.mark.asyncio
async def test_start_creates_flush_task(self):
"""Test that start creates the periodic flush task."""
logger = AuditLogger(flush_interval_seconds=1.0)
await logger.start()
assert logger._running is True
assert logger._flush_task is not None
await logger.stop()
@pytest.mark.asyncio
async def test_start_idempotent(self):
"""Test that multiple starts don't create multiple tasks."""
logger = AuditLogger()
await logger.start()
task1 = logger._flush_task
await logger.start() # Second start
task2 = logger._flush_task
assert task1 is task2
await logger.stop()
@pytest.mark.asyncio
async def test_stop_cancels_task_and_flushes(self):
"""Test that stop cancels the task and flushes events."""
logger = AuditLogger()
await logger.start()
# Add an event
await logger.log(AuditEventType.ACTION_REQUESTED, agent_id="agent-1")
await logger.stop()
assert logger._running is False
# Event should be flushed
assert len(logger._persisted) == 1
@pytest.mark.asyncio
async def test_stop_without_start(self):
"""Test stopping without starting doesn't error."""
logger = AuditLogger()
await logger.stop() # Should not raise
class TestAuditLoggerLog:
"""Tests for the log method."""
@pytest_asyncio.fixture
async def logger(self):
"""Create a logger instance."""
return AuditLogger(enable_hash_chain=True)
@pytest.mark.asyncio
async def test_log_creates_event(self, logger):
"""Test logging creates an event."""
event = await logger.log(
AuditEventType.ACTION_REQUESTED,
agent_id="agent-1",
project_id="proj-1",
)
assert event.event_type == AuditEventType.ACTION_REQUESTED
assert event.agent_id == "agent-1"
assert event.project_id == "proj-1"
assert event.id is not None
assert event.timestamp is not None
@pytest.mark.asyncio
async def test_log_adds_hash_chain(self, logger):
"""Test logging adds hash chain."""
event = await logger.log(
AuditEventType.ACTION_REQUESTED,
agent_id="agent-1",
)
assert "_hash" in event.details
assert "_prev_hash" in event.details
assert event.details["_prev_hash"] is None # First event
@pytest.mark.asyncio
async def test_log_chain_links_events(self, logger):
"""Test hash chain links events."""
event1 = await logger.log(AuditEventType.ACTION_REQUESTED)
event2 = await logger.log(AuditEventType.ACTION_EXECUTED)
assert event2.details["_prev_hash"] == event1.details["_hash"]
@pytest.mark.asyncio
async def test_log_without_hash_chain(self):
"""Test logging without hash chain."""
logger = AuditLogger(enable_hash_chain=False)
event = await logger.log(AuditEventType.ACTION_REQUESTED)
assert "_hash" not in event.details
assert "_prev_hash" not in event.details
@pytest.mark.asyncio
async def test_log_with_all_fields(self, logger):
"""Test logging with all optional fields."""
event = await logger.log(
AuditEventType.ACTION_EXECUTED,
agent_id="agent-1",
action_id="action-1",
project_id="proj-1",
session_id="sess-1",
user_id="user-1",
decision=SafetyDecision.ALLOW,
details={"custom": "data"},
correlation_id="corr-1",
)
assert event.agent_id == "agent-1"
assert event.action_id == "action-1"
assert event.project_id == "proj-1"
assert event.session_id == "sess-1"
assert event.user_id == "user-1"
assert event.decision == SafetyDecision.ALLOW
assert event.details["custom"] == "data"
assert event.correlation_id == "corr-1"
@pytest.mark.asyncio
async def test_log_buffers_event(self, logger):
"""Test logging adds event to buffer."""
await logger.log(AuditEventType.ACTION_REQUESTED)
assert len(logger._buffer) == 1
class TestAuditLoggerConvenienceMethods:
"""Tests for convenience logging methods."""
@pytest_asyncio.fixture
async def logger(self):
"""Create a logger instance."""
return AuditLogger(enable_hash_chain=False)
@pytest_asyncio.fixture
def action(self):
"""Create a test action request."""
metadata = ActionMetadata(
agent_id="agent-1",
session_id="sess-1",
project_id="proj-1",
autonomy_level=AutonomyLevel.MILESTONE,
user_id="user-1",
correlation_id="corr-1",
)
return ActionRequest(
action_type=ActionType.FILE_WRITE,
tool_name="file_write",
arguments={"path": "/test.txt"},
resource="/test.txt",
metadata=metadata,
)
@pytest.mark.asyncio
async def test_log_action_request_allowed(self, logger, action):
"""Test logging allowed action request."""
event = await logger.log_action_request(
action,
SafetyDecision.ALLOW,
reasons=["Within budget"],
)
assert event.event_type == AuditEventType.ACTION_VALIDATED
assert event.decision == SafetyDecision.ALLOW
assert event.details["reasons"] == ["Within budget"]
@pytest.mark.asyncio
async def test_log_action_request_denied(self, logger, action):
"""Test logging denied action request."""
event = await logger.log_action_request(
action,
SafetyDecision.DENY,
reasons=["Rate limit exceeded"],
)
assert event.event_type == AuditEventType.ACTION_DENIED
assert event.decision == SafetyDecision.DENY
@pytest.mark.asyncio
async def test_log_action_executed_success(self, logger, action):
"""Test logging successful action execution."""
event = await logger.log_action_executed(
action,
success=True,
execution_time_ms=50.0,
)
assert event.event_type == AuditEventType.ACTION_EXECUTED
assert event.details["success"] is True
assert event.details["execution_time_ms"] == 50.0
assert event.details["error"] is None
@pytest.mark.asyncio
async def test_log_action_executed_failure(self, logger, action):
"""Test logging failed action execution."""
event = await logger.log_action_executed(
action,
success=False,
execution_time_ms=100.0,
error="File not found",
)
assert event.event_type == AuditEventType.ACTION_FAILED
assert event.details["success"] is False
assert event.details["error"] == "File not found"
@pytest.mark.asyncio
async def test_log_approval_event(self, logger, action):
"""Test logging approval event."""
event = await logger.log_approval_event(
AuditEventType.APPROVAL_GRANTED,
approval_id="approval-1",
action=action,
decided_by="admin",
reason="Approved by admin",
)
assert event.event_type == AuditEventType.APPROVAL_GRANTED
assert event.details["approval_id"] == "approval-1"
assert event.details["decided_by"] == "admin"
@pytest.mark.asyncio
async def test_log_budget_event(self, logger):
"""Test logging budget event."""
event = await logger.log_budget_event(
AuditEventType.BUDGET_WARNING,
agent_id="agent-1",
scope="daily",
current_usage=8000.0,
limit=10000.0,
unit="tokens",
)
assert event.event_type == AuditEventType.BUDGET_WARNING
assert event.details["scope"] == "daily"
assert event.details["usage_percent"] == 80.0
@pytest.mark.asyncio
async def test_log_budget_event_zero_limit(self, logger):
"""Test logging budget event with zero limit."""
event = await logger.log_budget_event(
AuditEventType.BUDGET_WARNING,
agent_id="agent-1",
scope="daily",
current_usage=100.0,
limit=0.0, # Zero limit
)
assert event.details["usage_percent"] == 0
@pytest.mark.asyncio
async def test_log_emergency_stop(self, logger):
"""Test logging emergency stop."""
event = await logger.log_emergency_stop(
stop_type="global",
triggered_by="admin",
reason="Security incident",
affected_agents=["agent-1", "agent-2"],
)
assert event.event_type == AuditEventType.EMERGENCY_STOP
assert event.details["stop_type"] == "global"
assert event.details["affected_agents"] == ["agent-1", "agent-2"]
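The budget tests above assert an 80% figure for 8000/10000 and a guarded 0 for a zero limit. A minimal sketch of that computation (a hypothetical helper, not the actual `log_budget_event` internals):

```python
def usage_percent(current_usage: float, limit: float) -> float:
    """Percentage of a budget consumed, guarding against a zero (unlimited) limit."""
    if limit <= 0:
        # A zero or negative limit has no meaningful percentage; avoid ZeroDivisionError.
        return 0.0
    return (current_usage / limit) * 100.0
```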
class TestAuditLoggerFlush:
"""Tests for flush functionality."""
@pytest.mark.asyncio
async def test_flush_persists_events(self):
"""Test flush moves events to persisted storage."""
logger = AuditLogger(enable_hash_chain=False)
await logger.log(AuditEventType.ACTION_REQUESTED)
await logger.log(AuditEventType.ACTION_EXECUTED)
assert len(logger._buffer) == 2
assert len(logger._persisted) == 0
count = await logger.flush()
assert count == 2
assert len(logger._buffer) == 0
assert len(logger._persisted) == 2
@pytest.mark.asyncio
async def test_flush_empty_buffer(self):
"""Test flush with empty buffer."""
logger = AuditLogger()
count = await logger.flush()
assert count == 0
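One common way to implement the flush semantics tested above is to swap the buffer out under a lock and persist the snapshot afterwards, so producers are never blocked on persistence I/O. This is an assumption about the design, sketched with hypothetical names, not the actual `AuditLogger.flush` body:

```python
import asyncio


class BufferedSink:
    """Flush by swapping the buffer under a lock, then persisting the snapshot (sketch)."""

    def __init__(self) -> None:
        self._buffer: list = []
        self._persisted: list = []
        self._lock = asyncio.Lock()

    async def log(self, event) -> None:
        async with self._lock:
            self._buffer.append(event)

    async def flush(self) -> int:
        async with self._lock:
            # Swap atomically; new log() calls go to the fresh list.
            pending, self._buffer = self._buffer, []
        # Persist outside the lock so logging is not blocked on I/O.
        self._persisted.extend(pending)
        return len(pending)
```

Flushing twice returns the pending count and then zero, matching the `test_flush_persists_events` / `test_flush_empty_buffer` pair above.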
class TestAuditLoggerQuery:
"""Tests for query functionality."""
@pytest_asyncio.fixture
async def logger_with_events(self):
"""Create a logger with some test events."""
logger = AuditLogger(enable_hash_chain=False)
# Add various events
await logger.log(
AuditEventType.ACTION_REQUESTED,
agent_id="agent-1",
project_id="proj-1",
)
await logger.log(
AuditEventType.ACTION_EXECUTED,
agent_id="agent-1",
project_id="proj-1",
)
await logger.log(
AuditEventType.ACTION_DENIED,
agent_id="agent-2",
project_id="proj-2",
)
await logger.log(
AuditEventType.BUDGET_WARNING,
agent_id="agent-1",
project_id="proj-1",
user_id="user-1",
)
return logger
@pytest.mark.asyncio
async def test_query_all(self, logger_with_events):
"""Test querying all events."""
events = await logger_with_events.query()
assert len(events) == 4
@pytest.mark.asyncio
async def test_query_by_event_type(self, logger_with_events):
"""Test filtering by event type."""
events = await logger_with_events.query(
event_types=[AuditEventType.ACTION_REQUESTED]
)
assert len(events) == 1
assert events[0].event_type == AuditEventType.ACTION_REQUESTED
@pytest.mark.asyncio
async def test_query_by_agent_id(self, logger_with_events):
"""Test filtering by agent ID."""
events = await logger_with_events.query(agent_id="agent-1")
assert len(events) == 3
assert all(e.agent_id == "agent-1" for e in events)
@pytest.mark.asyncio
async def test_query_by_project_id(self, logger_with_events):
"""Test filtering by project ID."""
events = await logger_with_events.query(project_id="proj-2")
assert len(events) == 1
@pytest.mark.asyncio
async def test_query_by_user_id(self, logger_with_events):
"""Test filtering by user ID."""
events = await logger_with_events.query(user_id="user-1")
assert len(events) == 1
assert events[0].event_type == AuditEventType.BUDGET_WARNING
@pytest.mark.asyncio
async def test_query_with_limit(self, logger_with_events):
"""Test query with limit."""
events = await logger_with_events.query(limit=2)
assert len(events) == 2
@pytest.mark.asyncio
async def test_query_with_offset(self, logger_with_events):
"""Test query with offset."""
all_events = await logger_with_events.query()
offset_events = await logger_with_events.query(offset=2)
assert len(offset_events) == 2
assert offset_events[0] == all_events[2]
@pytest.mark.asyncio
async def test_query_by_time_range(self):
"""Test filtering by time range."""
logger = AuditLogger(enable_hash_chain=False)
now = datetime.utcnow()
await logger.log(AuditEventType.ACTION_REQUESTED)
# Query with start time
events = await logger.query(
start_time=now - timedelta(seconds=1),
end_time=now + timedelta(seconds=1),
)
assert len(events) == 1
@pytest.mark.asyncio
async def test_query_by_correlation_id(self):
"""Test filtering by correlation ID."""
logger = AuditLogger(enable_hash_chain=False)
await logger.log(
AuditEventType.ACTION_REQUESTED,
correlation_id="corr-123",
)
await logger.log(
AuditEventType.ACTION_EXECUTED,
correlation_id="corr-456",
)
events = await logger.query(correlation_id="corr-123")
assert len(events) == 1
assert events[0].correlation_id == "corr-123"
@pytest.mark.asyncio
async def test_query_combined_filters(self, logger_with_events):
"""Test combined filters."""
events = await logger_with_events.query(
agent_id="agent-1",
project_id="proj-1",
event_types=[
AuditEventType.ACTION_REQUESTED,
AuditEventType.ACTION_EXECUTED,
],
)
assert len(events) == 2
@pytest.mark.asyncio
async def test_get_action_history(self, logger_with_events):
"""Test get_action_history method."""
events = await logger_with_events.get_action_history("agent-1")
# Should only return action-related events
assert len(events) == 2
assert all(
e.event_type
in {AuditEventType.ACTION_REQUESTED, AuditEventType.ACTION_EXECUTED}
for e in events
)
class TestAuditLoggerIntegrity:
"""Tests for hash chain integrity verification."""
@pytest.mark.asyncio
async def test_verify_integrity_valid(self):
"""Test integrity verification with valid chain."""
logger = AuditLogger(enable_hash_chain=True)
await logger.log(AuditEventType.ACTION_REQUESTED)
await logger.log(AuditEventType.ACTION_EXECUTED)
is_valid, issues = await logger.verify_integrity()
assert is_valid is True
assert len(issues) == 0
@pytest.mark.asyncio
async def test_verify_integrity_disabled(self):
"""Test integrity verification when hash chain disabled."""
logger = AuditLogger(enable_hash_chain=False)
await logger.log(AuditEventType.ACTION_REQUESTED)
is_valid, issues = await logger.verify_integrity()
assert is_valid is True
assert len(issues) == 0
@pytest.mark.asyncio
async def test_verify_integrity_broken_chain(self):
"""Test integrity verification with broken chain."""
logger = AuditLogger(enable_hash_chain=True)
event1 = await logger.log(AuditEventType.ACTION_REQUESTED)
await logger.log(AuditEventType.ACTION_EXECUTED)
# Tamper with first event's hash
event1.details["_hash"] = "tampered_hash"
is_valid, issues = await logger.verify_integrity()
assert is_valid is False
assert len(issues) > 0
class TestAuditLoggerHandlers:
"""Tests for event handler management."""
@pytest.mark.asyncio
async def test_add_sync_handler(self):
"""Test adding synchronous handler."""
logger = AuditLogger(enable_hash_chain=False)
events_received = []
def handler(event):
events_received.append(event)
logger.add_handler(handler)
await logger.log(AuditEventType.ACTION_REQUESTED)
assert len(events_received) == 1
@pytest.mark.asyncio
async def test_add_async_handler(self):
"""Test adding async handler."""
logger = AuditLogger(enable_hash_chain=False)
events_received = []
async def handler(event):
events_received.append(event)
logger.add_handler(handler)
await logger.log(AuditEventType.ACTION_REQUESTED)
assert len(events_received) == 1
@pytest.mark.asyncio
async def test_remove_handler(self):
"""Test removing handler."""
logger = AuditLogger(enable_hash_chain=False)
events_received = []
def handler(event):
events_received.append(event)
logger.add_handler(handler)
await logger.log(AuditEventType.ACTION_REQUESTED)
logger.remove_handler(handler)
await logger.log(AuditEventType.ACTION_EXECUTED)
assert len(events_received) == 1
@pytest.mark.asyncio
async def test_handler_error_caught(self):
"""Test that handler errors are caught."""
logger = AuditLogger(enable_hash_chain=False)
def failing_handler(event):
raise ValueError("Handler error")
logger.add_handler(failing_handler)
# Should not raise
event = await logger.log(AuditEventType.ACTION_REQUESTED)
assert event is not None
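The handler tests above require three properties: sync and async handlers both fire, and a raising handler never propagates. A minimal dispatch loop with those properties (a sketch with a hypothetical `dispatch` function, not the logger's actual notification code):

```python
import asyncio
import inspect


async def dispatch(handlers, event) -> None:
    """Invoke each handler, awaiting coroutines and swallowing handler errors (sketch)."""
    for handler in list(handlers):
        try:
            result = handler(event)
            if inspect.isawaitable(result):
                await result
        except Exception:
            # A misbehaving handler must not break audit logging;
            # real code would log the failure instead of silently passing.
            pass
```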
class TestAuditLoggerSanitization:
"""Tests for sensitive data sanitization."""
@pytest.mark.asyncio
async def test_sanitize_sensitive_keys(self):
"""Test sanitization of sensitive keys."""
with patch("app.services.safety.audit.logger.get_safety_config") as mock_config:
mock_cfg = MagicMock()
mock_cfg.audit_retention_days = 30
mock_cfg.audit_include_sensitive = False
mock_config.return_value = mock_cfg
logger = AuditLogger(enable_hash_chain=False)
event = await logger.log(
AuditEventType.ACTION_EXECUTED,
details={
"password": "secret123",
"api_key": "key123",
"token": "token123",
"normal_field": "visible",
},
)
assert event.details["password"] == "[REDACTED]"
assert event.details["api_key"] == "[REDACTED]"
assert event.details["token"] == "[REDACTED]"
assert event.details["normal_field"] == "visible"
@pytest.mark.asyncio
async def test_sanitize_nested_dict(self):
"""Test sanitization of nested dictionaries."""
with patch("app.services.safety.audit.logger.get_safety_config") as mock_config:
mock_cfg = MagicMock()
mock_cfg.audit_retention_days = 30
mock_cfg.audit_include_sensitive = False
mock_config.return_value = mock_cfg
logger = AuditLogger(enable_hash_chain=False)
event = await logger.log(
AuditEventType.ACTION_EXECUTED,
details={
"config": {
"api_secret": "secret",
"name": "test",
}
},
)
assert event.details["config"]["api_secret"] == "[REDACTED]"
assert event.details["config"]["name"] == "test"
@pytest.mark.asyncio
async def test_include_sensitive_when_enabled(self):
"""Test sensitive data included when enabled."""
with patch("app.services.safety.audit.logger.get_safety_config") as mock_config:
mock_cfg = MagicMock()
mock_cfg.audit_retention_days = 30
mock_cfg.audit_include_sensitive = True
mock_config.return_value = mock_cfg
logger = AuditLogger(enable_hash_chain=False)
event = await logger.log(
AuditEventType.ACTION_EXECUTED,
details={"password": "secret123"},
)
assert event.details["password"] == "secret123"
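The sanitization behavior above (sensitive keys redacted, nested dicts handled recursively, everything else preserved) can be sketched as a small recursive pass. The marker list and helper name are assumptions for illustration, not the logger's actual configuration:

```python
SENSITIVE_MARKERS = ("password", "secret", "token", "api_key")


def sanitize(details: dict) -> dict:
    """Return a copy with values under sensitive-looking keys replaced (sketch)."""
    clean: dict = {}
    for key, value in details.items():
        if isinstance(value, dict):
            clean[key] = sanitize(value)  # recurse into nested dicts
        elif any(marker in key.lower() for marker in SENSITIVE_MARKERS):
            clean[key] = "[REDACTED]"
        else:
            clean[key] = value
    return clean
```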
class TestAuditLoggerRetention:
"""Tests for retention policy enforcement."""
@pytest.mark.asyncio
async def test_retention_removes_old_events(self):
"""Test that retention removes old events."""
with patch("app.services.safety.audit.logger.get_safety_config") as mock_config:
mock_cfg = MagicMock()
mock_cfg.audit_retention_days = 7
mock_cfg.audit_include_sensitive = False
mock_config.return_value = mock_cfg
logger = AuditLogger(enable_hash_chain=False)
# Add an old event directly to persisted
from app.services.safety.models import AuditEvent
old_event = AuditEvent(
id="old-event",
event_type=AuditEventType.ACTION_REQUESTED,
timestamp=datetime.utcnow() - timedelta(days=10),
details={},
)
logger._persisted.append(old_event)
# Add a recent event
await logger.log(AuditEventType.ACTION_EXECUTED)
# Flush will trigger retention enforcement
await logger.flush()
# Old event should be removed
assert len(logger._persisted) == 1
assert logger._persisted[0].id != "old-event"
@pytest.mark.asyncio
async def test_retention_keeps_recent_events(self):
"""Test that retention keeps recent events."""
with patch("app.services.safety.audit.logger.get_safety_config") as mock_config:
mock_cfg = MagicMock()
mock_cfg.audit_retention_days = 7
mock_cfg.audit_include_sensitive = False
mock_config.return_value = mock_cfg
logger = AuditLogger(enable_hash_chain=False)
await logger.log(AuditEventType.ACTION_REQUESTED)
await logger.log(AuditEventType.ACTION_EXECUTED)
await logger.flush()
assert len(logger._persisted) == 2
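The retention policy these tests enforce amounts to a cutoff filter over persisted events. A sketch under the assumption that events expose a `timestamp` (here modeled as plain dicts; the helper name is hypothetical):

```python
from datetime import datetime, timedelta


def enforce_retention(events: list[dict], retention_days: int, now: datetime) -> list[dict]:
    """Drop events older than the retention window (sketch)."""
    cutoff = now - timedelta(days=retention_days)
    return [event for event in events if event["timestamp"] >= cutoff]
```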
class TestAuditLoggerSingleton:
"""Tests for singleton pattern."""
@pytest.mark.asyncio
async def test_get_audit_logger_creates_instance(self):
"""Test get_audit_logger creates singleton."""
reset_audit_logger()
logger1 = await get_audit_logger()
logger2 = await get_audit_logger()
assert logger1 is logger2
await shutdown_audit_logger()
@pytest.mark.asyncio
async def test_shutdown_audit_logger(self):
"""Test shutdown_audit_logger stops and clears singleton."""
import app.services.safety.audit.logger as audit_module
reset_audit_logger()
_logger = await get_audit_logger()
await shutdown_audit_logger()
assert audit_module._audit_logger is None
def test_reset_audit_logger(self):
"""Test reset_audit_logger clears singleton."""
import app.services.safety.audit.logger as audit_module
audit_module._audit_logger = AuditLogger()
reset_audit_logger()
assert audit_module._audit_logger is None
class TestAuditLoggerPeriodicFlush:
"""Tests for periodic flush background task."""
@pytest.mark.asyncio
async def test_periodic_flush_runs(self):
"""Test periodic flush runs and flushes events."""
logger = AuditLogger(flush_interval_seconds=0.1, enable_hash_chain=False)
await logger.start()
# Log an event
await logger.log(AuditEventType.ACTION_REQUESTED)
assert len(logger._buffer) == 1
# Wait for periodic flush
await asyncio.sleep(0.15)
# Event should be flushed
assert len(logger._buffer) == 0
assert len(logger._persisted) == 1
await logger.stop()
@pytest.mark.asyncio
async def test_periodic_flush_handles_errors(self):
"""Test periodic flush handles errors gracefully."""
logger = AuditLogger(flush_interval_seconds=0.1)
await logger.start()
# Mock flush to raise an error
original_flush = logger.flush
async def failing_flush():
raise Exception("Flush error")
logger.flush = failing_flush
# Wait for flush attempt
await asyncio.sleep(0.15)
# Should still be running
assert logger._running is True
logger.flush = original_flush
await logger.stop()
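The periodic-flush behavior these tests pin down (idempotent start, a loop that survives flush errors, a final flush on stop) can be sketched as a small background-task wrapper. Names are hypothetical; this is not the `AuditLogger` source:

```python
import asyncio


class PeriodicFlusher:
    """Background task that flushes on an interval and survives flush errors (sketch)."""

    def __init__(self, flush, interval: float) -> None:
        self._flush = flush
        self._interval = interval
        self._task: asyncio.Task | None = None
        self._running = False

    async def start(self) -> None:
        if self._running:
            return  # idempotent: never spawn a second task
        self._running = True
        self._task = asyncio.create_task(self._loop())

    async def _loop(self) -> None:
        while self._running:
            await asyncio.sleep(self._interval)
            try:
                await self._flush()
            except Exception:
                # A failed flush must not kill the loop; real code would log it.
                pass

    async def stop(self) -> None:
        self._running = False
        if self._task is not None:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass
            self._task = None
        await self._flush()  # final flush on shutdown
```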
class TestAuditLoggerLogging:
"""Tests for standard logger output."""
@pytest.mark.asyncio
async def test_log_warning_for_denied(self):
"""Test warning level for denied events."""
with patch("app.services.safety.audit.logger.logger") as mock_logger:
audit_logger = AuditLogger(enable_hash_chain=False)
await audit_logger.log(
AuditEventType.ACTION_DENIED,
agent_id="agent-1",
)
mock_logger.warning.assert_called()
@pytest.mark.asyncio
async def test_log_error_for_failed(self):
"""Test error level for failed events."""
with patch("app.services.safety.audit.logger.logger") as mock_logger:
audit_logger = AuditLogger(enable_hash_chain=False)
await audit_logger.log(
AuditEventType.ACTION_FAILED,
agent_id="agent-1",
)
mock_logger.error.assert_called()
@pytest.mark.asyncio
async def test_log_info_for_normal(self):
"""Test info level for normal events."""
with patch("app.services.safety.audit.logger.logger") as mock_logger:
audit_logger = AuditLogger(enable_hash_chain=False)
await audit_logger.log(
AuditEventType.ACTION_EXECUTED,
agent_id="agent-1",
)
mock_logger.info.assert_called()
class TestAuditLoggerEdgeCases:
"""Tests for edge cases."""
@pytest.mark.asyncio
async def test_log_with_none_details(self):
"""Test logging with None details."""
logger = AuditLogger(enable_hash_chain=False)
event = await logger.log(
AuditEventType.ACTION_REQUESTED,
details=None,
)
assert event.details == {}
@pytest.mark.asyncio
async def test_query_with_action_id(self):
"""Test querying by action ID."""
logger = AuditLogger(enable_hash_chain=False)
await logger.log(
AuditEventType.ACTION_REQUESTED,
action_id="action-1",
)
await logger.log(
AuditEventType.ACTION_EXECUTED,
action_id="action-2",
)
events = await logger.query(action_id="action-1")
assert len(events) == 1
assert events[0].action_id == "action-1"
@pytest.mark.asyncio
async def test_query_with_session_id(self):
"""Test querying by session ID."""
logger = AuditLogger(enable_hash_chain=False)
await logger.log(
AuditEventType.ACTION_REQUESTED,
session_id="sess-1",
)
await logger.log(
AuditEventType.ACTION_EXECUTED,
session_id="sess-2",
)
events = await logger.query(session_id="sess-1")
assert len(events) == 1
@pytest.mark.asyncio
async def test_query_includes_buffer_and_persisted(self):
"""Test query includes both buffer and persisted events."""
logger = AuditLogger(enable_hash_chain=False)
# Add event to buffer
await logger.log(AuditEventType.ACTION_REQUESTED)
# Flush to persisted
await logger.flush()
# Add another to buffer
await logger.log(AuditEventType.ACTION_EXECUTED)
# Query should return both
events = await logger.query()
assert len(events) == 2
@pytest.mark.asyncio
async def test_remove_nonexistent_handler(self):
"""Test removing handler that doesn't exist."""
logger = AuditLogger()
def handler(event):
pass
# Should not raise
logger.remove_handler(handler)
@pytest.mark.asyncio
async def test_query_time_filter_excludes_events(self):
"""Test time filters exclude events correctly."""
logger = AuditLogger(enable_hash_chain=False)
await logger.log(AuditEventType.ACTION_REQUESTED)
# Query with future start time
future = datetime.utcnow() + timedelta(hours=1)
events = await logger.query(start_time=future)
assert len(events) == 0
@pytest.mark.asyncio
async def test_query_end_time_filter(self):
"""Test end time filter."""
logger = AuditLogger(enable_hash_chain=False)
await logger.log(AuditEventType.ACTION_REQUESTED)
# Query with past end time
past = datetime.utcnow() - timedelta(hours=1)
events = await logger.query(end_time=past)
assert len(events) == 0

File diff suppressed because it is too large.



@@ -0,0 +1,874 @@
"""
Tests for MCP Safety Integration.
Tests cover:
- MCPToolCall and MCPToolResult data structures
- MCPSafetyWrapper: tool registration, execution, safety checks
- Tool classification and action type mapping
- SafeToolExecutor context manager
- Factory function create_mcp_wrapper
"""
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from app.services.safety.exceptions import EmergencyStopError
from app.services.safety.mcp.integration import (
MCPSafetyWrapper,
MCPToolCall,
MCPToolResult,
SafeToolExecutor,
create_mcp_wrapper,
)
from app.services.safety.models import (
ActionType,
AutonomyLevel,
SafetyDecision,
)
class TestMCPToolCall:
"""Tests for MCPToolCall dataclass."""
def test_tool_call_creation(self):
"""Test creating a tool call."""
call = MCPToolCall(
tool_name="file_read",
arguments={"path": "/tmp/test.txt"}, # noqa: S108
server_name="file-server",
project_id="proj-1",
context={"session_id": "sess-1"},
)
assert call.tool_name == "file_read"
assert call.arguments == {"path": "/tmp/test.txt"} # noqa: S108
assert call.server_name == "file-server"
assert call.project_id == "proj-1"
assert call.context == {"session_id": "sess-1"}
def test_tool_call_defaults(self):
"""Test tool call default values."""
call = MCPToolCall(
tool_name="test",
arguments={},
)
assert call.server_name is None
assert call.project_id is None
assert call.context == {}
class TestMCPToolResult:
"""Tests for MCPToolResult dataclass."""
def test_tool_result_success(self):
"""Test creating a successful result."""
result = MCPToolResult(
success=True,
result={"data": "test"},
safety_decision=SafetyDecision.ALLOW,
execution_time_ms=50.0,
)
assert result.success is True
assert result.result == {"data": "test"}
assert result.error is None
assert result.safety_decision == SafetyDecision.ALLOW
assert result.execution_time_ms == 50.0
def test_tool_result_failure(self):
"""Test creating a failed result."""
result = MCPToolResult(
success=False,
error="Permission denied",
safety_decision=SafetyDecision.DENY,
)
assert result.success is False
assert result.error == "Permission denied"
assert result.result is None
def test_tool_result_with_ids(self):
"""Test result with approval and checkpoint IDs."""
result = MCPToolResult(
success=True,
approval_id="approval-123",
checkpoint_id="checkpoint-456",
)
assert result.approval_id == "approval-123"
assert result.checkpoint_id == "checkpoint-456"
def test_tool_result_defaults(self):
"""Test result default values."""
result = MCPToolResult(success=True)
assert result.result is None
assert result.error is None
assert result.safety_decision == SafetyDecision.ALLOW
assert result.execution_time_ms == 0.0
assert result.approval_id is None
assert result.checkpoint_id is None
assert result.metadata == {}
class TestMCPSafetyWrapperClassification:
"""Tests for tool classification."""
def test_classify_file_read(self):
"""Test classifying file read tools."""
wrapper = MCPSafetyWrapper()
assert wrapper._classify_tool("file_read") == ActionType.FILE_READ
assert wrapper._classify_tool("get_file") == ActionType.FILE_READ
assert wrapper._classify_tool("list_files") == ActionType.FILE_READ
assert wrapper._classify_tool("search_file") == ActionType.FILE_READ
def test_classify_file_write(self):
"""Test classifying file write tools."""
wrapper = MCPSafetyWrapper()
assert wrapper._classify_tool("file_write") == ActionType.FILE_WRITE
assert wrapper._classify_tool("create_file") == ActionType.FILE_WRITE
assert wrapper._classify_tool("update_file") == ActionType.FILE_WRITE
def test_classify_file_delete(self):
"""Test classifying file delete tools."""
wrapper = MCPSafetyWrapper()
assert wrapper._classify_tool("file_delete") == ActionType.FILE_DELETE
assert wrapper._classify_tool("remove_file") == ActionType.FILE_DELETE
def test_classify_database_read(self):
"""Test classifying database read tools."""
wrapper = MCPSafetyWrapper()
assert wrapper._classify_tool("database_query") == ActionType.DATABASE_QUERY
assert wrapper._classify_tool("db_read") == ActionType.DATABASE_QUERY
assert wrapper._classify_tool("query_database") == ActionType.DATABASE_QUERY
def test_classify_database_mutate(self):
"""Test classifying database mutate tools."""
wrapper = MCPSafetyWrapper()
assert wrapper._classify_tool("database_write") == ActionType.DATABASE_MUTATE
assert wrapper._classify_tool("db_update") == ActionType.DATABASE_MUTATE
assert wrapper._classify_tool("database_delete") == ActionType.DATABASE_MUTATE
def test_classify_shell_command(self):
"""Test classifying shell command tools."""
wrapper = MCPSafetyWrapper()
assert wrapper._classify_tool("shell_execute") == ActionType.SHELL_COMMAND
assert wrapper._classify_tool("exec_command") == ActionType.SHELL_COMMAND
assert wrapper._classify_tool("bash_run") == ActionType.SHELL_COMMAND
def test_classify_git_operation(self):
"""Test classifying git tools."""
wrapper = MCPSafetyWrapper()
assert wrapper._classify_tool("git_commit") == ActionType.GIT_OPERATION
assert wrapper._classify_tool("git_push") == ActionType.GIT_OPERATION
assert wrapper._classify_tool("git_status") == ActionType.GIT_OPERATION
def test_classify_network_request(self):
"""Test classifying network tools."""
wrapper = MCPSafetyWrapper()
assert wrapper._classify_tool("http_get") == ActionType.NETWORK_REQUEST
assert wrapper._classify_tool("fetch_url") == ActionType.NETWORK_REQUEST
assert wrapper._classify_tool("api_request") == ActionType.NETWORK_REQUEST
def test_classify_llm_call(self):
"""Test classifying LLM tools."""
wrapper = MCPSafetyWrapper()
assert wrapper._classify_tool("llm_generate") == ActionType.LLM_CALL
assert wrapper._classify_tool("ai_complete") == ActionType.LLM_CALL
assert wrapper._classify_tool("claude_chat") == ActionType.LLM_CALL
def test_classify_default(self):
"""Test default classification for unknown tools."""
wrapper = MCPSafetyWrapper()
assert wrapper._classify_tool("unknown_tool") == ActionType.TOOL_CALL
assert wrapper._classify_tool("custom_action") == ActionType.TOOL_CALL
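The classification tests above imply a keyword-based mapping with a generic fallback. A simplified sketch covering a subset of the action types (the rule table is an assumption for illustration; the real `_classify_tool` covers more categories and may order its checks differently):

```python
from enum import Enum


class ActionType(str, Enum):
    FILE_READ = "file_read"
    FILE_WRITE = "file_write"
    FILE_DELETE = "file_delete"
    SHELL_COMMAND = "shell_command"
    TOOL_CALL = "tool_call"


# Ordered rules: more destructive patterns are checked first so
# "file_delete" is not swallowed by a broader rule.
_RULES: list[tuple[tuple[str, ...], ActionType]] = [
    (("delete", "remove"), ActionType.FILE_DELETE),
    (("write", "create", "update"), ActionType.FILE_WRITE),
    (("read", "get", "list", "search"), ActionType.FILE_READ),
    (("shell", "exec", "bash"), ActionType.SHELL_COMMAND),
]


def classify_tool(tool_name: str) -> ActionType:
    """Map a tool name onto an action type by keyword; unknown names stay generic."""
    name = tool_name.lower()
    for keywords, action in _RULES:
        if any(keyword in name for keyword in keywords):
            return action
    return ActionType.TOOL_CALL
```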
class TestMCPSafetyWrapperToolHandlers:
"""Tests for tool handler registration."""
def test_register_tool_handler(self):
"""Test registering a tool handler."""
wrapper = MCPSafetyWrapper()
def handler(path: str) -> str:
return f"Read: {path}"
wrapper.register_tool_handler("file_read", handler)
assert "file_read" in wrapper._tool_handlers
assert wrapper._tool_handlers["file_read"] is handler
def test_register_multiple_handlers(self):
"""Test registering multiple handlers."""
wrapper = MCPSafetyWrapper()
wrapper.register_tool_handler("tool1", lambda: None)
wrapper.register_tool_handler("tool2", lambda: None)
wrapper.register_tool_handler("tool3", lambda: None)
assert len(wrapper._tool_handlers) == 3
def test_overwrite_handler(self):
"""Test overwriting a handler."""
wrapper = MCPSafetyWrapper()
handler1 = lambda: "first" # noqa: E731
handler2 = lambda: "second" # noqa: E731
wrapper.register_tool_handler("tool", handler1)
wrapper.register_tool_handler("tool", handler2)
assert wrapper._tool_handlers["tool"] is handler2
class TestMCPSafetyWrapperExecution:
"""Tests for tool execution."""
@pytest_asyncio.fixture
async def mock_guardian(self):
"""Create a mock SafetyGuardian."""
guardian = AsyncMock()
guardian.validate = AsyncMock()
return guardian
@pytest_asyncio.fixture
async def mock_emergency(self):
"""Create a mock EmergencyControls."""
emergency = AsyncMock()
emergency.check_allowed = AsyncMock()
return emergency
@pytest.mark.asyncio
async def test_execute_allowed(self, mock_guardian, mock_emergency):
"""Test executing an allowed tool call."""
mock_guardian.validate.return_value = MagicMock(
decision=SafetyDecision.ALLOW,
reasons=[],
approval_id=None,
checkpoint_id=None,
)
wrapper = MCPSafetyWrapper(
guardian=mock_guardian,
emergency_controls=mock_emergency,
)
async def handler(path: str) -> dict:
return {"content": f"Data from {path}"}
wrapper.register_tool_handler("file_read", handler)
call = MCPToolCall(
tool_name="file_read",
arguments={"path": "/test.txt"},
project_id="proj-1",
)
result = await wrapper.execute(call, "agent-1")
assert result.success is True
assert result.result == {"content": "Data from /test.txt"}
assert result.safety_decision == SafetyDecision.ALLOW
@pytest.mark.asyncio
async def test_execute_denied(self, mock_guardian, mock_emergency):
"""Test executing a denied tool call."""
mock_guardian.validate.return_value = MagicMock(
decision=SafetyDecision.DENY,
reasons=["Permission denied", "Rate limit exceeded"],
)
wrapper = MCPSafetyWrapper(
guardian=mock_guardian,
emergency_controls=mock_emergency,
)
call = MCPToolCall(
tool_name="file_write",
arguments={"path": "/etc/passwd"},
)
result = await wrapper.execute(call, "agent-1")
assert result.success is False
assert "Permission denied" in result.error
assert "Rate limit exceeded" in result.error
assert result.safety_decision == SafetyDecision.DENY
@pytest.mark.asyncio
async def test_execute_requires_approval(self, mock_guardian, mock_emergency):
"""Test executing a tool that requires approval."""
mock_guardian.validate.return_value = MagicMock(
decision=SafetyDecision.REQUIRE_APPROVAL,
reasons=["Destructive operation requires approval"],
approval_id="approval-123",
)
wrapper = MCPSafetyWrapper(
guardian=mock_guardian,
emergency_controls=mock_emergency,
)
call = MCPToolCall(
tool_name="file_delete",
arguments={"path": "/important.txt"},
)
result = await wrapper.execute(call, "agent-1")
assert result.success is False
assert result.safety_decision == SafetyDecision.REQUIRE_APPROVAL
assert result.approval_id == "approval-123"
assert "requires human approval" in result.error
@pytest.mark.asyncio
async def test_execute_emergency_stop(self, mock_guardian, mock_emergency):
"""Test execution blocked by emergency stop."""
mock_emergency.check_allowed.side_effect = EmergencyStopError(
"Emergency stop active"
)
wrapper = MCPSafetyWrapper(
guardian=mock_guardian,
emergency_controls=mock_emergency,
)
call = MCPToolCall(
tool_name="file_write",
arguments={"path": "/test.txt"},
project_id="proj-1",
)
result = await wrapper.execute(call, "agent-1")
assert result.success is False
assert result.safety_decision == SafetyDecision.DENY
assert result.metadata.get("emergency_stop") is True
@pytest.mark.asyncio
async def test_execute_bypass_safety(self, mock_guardian, mock_emergency):
"""Test executing with safety bypass."""
wrapper = MCPSafetyWrapper(
guardian=mock_guardian,
emergency_controls=mock_emergency,
)
async def handler(data: str) -> str:
return f"Processed: {data}"
wrapper.register_tool_handler("custom_tool", handler)
call = MCPToolCall(
tool_name="custom_tool",
arguments={"data": "test"},
)
result = await wrapper.execute(call, "agent-1", bypass_safety=True)
assert result.success is True
assert result.result == "Processed: test"
# Guardian should not be called when bypassing
mock_guardian.validate.assert_not_called()
@pytest.mark.asyncio
async def test_execute_no_handler(self, mock_guardian, mock_emergency):
"""Test executing a tool with no registered handler."""
mock_guardian.validate.return_value = MagicMock(
decision=SafetyDecision.ALLOW,
reasons=[],
approval_id=None,
checkpoint_id=None,
)
wrapper = MCPSafetyWrapper(
guardian=mock_guardian,
emergency_controls=mock_emergency,
)
call = MCPToolCall(
tool_name="unregistered_tool",
arguments={},
)
result = await wrapper.execute(call, "agent-1")
assert result.success is False
assert "No handler registered" in result.error
@pytest.mark.asyncio
async def test_execute_handler_exception(self, mock_guardian, mock_emergency):
"""Test handling exceptions from tool handler."""
mock_guardian.validate.return_value = MagicMock(
decision=SafetyDecision.ALLOW,
reasons=[],
approval_id=None,
checkpoint_id=None,
)
wrapper = MCPSafetyWrapper(
guardian=mock_guardian,
emergency_controls=mock_emergency,
)
async def failing_handler() -> None:
raise ValueError("Handler failed!")
wrapper.register_tool_handler("failing_tool", failing_handler)
call = MCPToolCall(
tool_name="failing_tool",
arguments={},
)
result = await wrapper.execute(call, "agent-1")
assert result.success is False
assert "Handler failed!" in result.error
# Decision is still ALLOW because the safety check passed
assert result.safety_decision == SafetyDecision.ALLOW
@pytest.mark.asyncio
async def test_execute_sync_handler(self, mock_guardian, mock_emergency):
"""Test executing a synchronous handler."""
mock_guardian.validate.return_value = MagicMock(
decision=SafetyDecision.ALLOW,
reasons=[],
approval_id=None,
checkpoint_id=None,
)
wrapper = MCPSafetyWrapper(
guardian=mock_guardian,
emergency_controls=mock_emergency,
)
def sync_handler(value: int) -> int:
return value * 2
wrapper.register_tool_handler("sync_tool", sync_handler)
call = MCPToolCall(
tool_name="sync_tool",
arguments={"value": 21},
)
result = await wrapper.execute(call, "agent-1")
assert result.success is True
assert result.result == 42
class TestBuildActionRequest:
"""Tests for _build_action_request."""
def test_build_action_request_basic(self):
"""Test building a basic action request."""
wrapper = MCPSafetyWrapper()
call = MCPToolCall(
tool_name="file_read",
arguments={"path": "/test.txt"},
project_id="proj-1",
)
action = wrapper._build_action_request(call, "agent-1", AutonomyLevel.MILESTONE)
assert action.action_type == ActionType.FILE_READ
assert action.tool_name == "file_read"
assert action.arguments == {"path": "/test.txt"}
assert action.resource == "/test.txt"
assert action.metadata.agent_id == "agent-1"
assert action.metadata.project_id == "proj-1"
assert action.metadata.autonomy_level == AutonomyLevel.MILESTONE
def test_build_action_request_with_context(self):
"""Test building action request with session context."""
wrapper = MCPSafetyWrapper()
call = MCPToolCall(
tool_name="database_query",
arguments={"resource": "users", "query": "SELECT *"},
context={"session_id": "sess-123"},
project_id="proj-2",
)
action = wrapper._build_action_request(
call, "agent-2", AutonomyLevel.AUTONOMOUS
)
assert action.resource == "users"
assert action.metadata.session_id == "sess-123"
assert action.metadata.autonomy_level == AutonomyLevel.AUTONOMOUS
def test_build_action_request_no_resource(self):
"""Test building action request without resource."""
wrapper = MCPSafetyWrapper()
call = MCPToolCall(
tool_name="llm_generate",
arguments={"prompt": "Hello"},
)
action = wrapper._build_action_request(
call, "agent-1", AutonomyLevel.FULL_CONTROL
)
assert action.resource is None
class TestElapsedTime:
"""Tests for _elapsed_ms helper."""
def test_elapsed_ms(self):
"""Test calculating elapsed time."""
wrapper = MCPSafetyWrapper()
start = datetime.utcnow() - timedelta(milliseconds=100)
elapsed = wrapper._elapsed_ms(start)
# Should be at least 100ms, but allow some tolerance
assert elapsed >= 99
assert elapsed < 200
class TestSafeToolExecutor:
"""Tests for SafeToolExecutor context manager."""
@pytest.mark.asyncio
async def test_executor_execute(self):
"""Test executing within context manager."""
mock_guardian = AsyncMock()
mock_guardian.validate.return_value = MagicMock(
decision=SafetyDecision.ALLOW,
reasons=[],
approval_id=None,
checkpoint_id=None,
)
mock_emergency = AsyncMock()
wrapper = MCPSafetyWrapper(
guardian=mock_guardian,
emergency_controls=mock_emergency,
)
async def handler() -> str:
return "success"
wrapper.register_tool_handler("test_tool", handler)
call = MCPToolCall(tool_name="test_tool", arguments={})
async with SafeToolExecutor(wrapper, call, "agent-1") as executor:
result = await executor.execute()
assert result.success is True
assert result.result == "success"
@pytest.mark.asyncio
async def test_executor_result_property(self):
"""Test accessing result via property."""
mock_guardian = AsyncMock()
mock_guardian.validate.return_value = MagicMock(
decision=SafetyDecision.ALLOW,
reasons=[],
approval_id=None,
checkpoint_id=None,
)
mock_emergency = AsyncMock()
wrapper = MCPSafetyWrapper(
guardian=mock_guardian,
emergency_controls=mock_emergency,
)
wrapper.register_tool_handler("tool", lambda: "data")
call = MCPToolCall(tool_name="tool", arguments={})
executor = SafeToolExecutor(wrapper, call, "agent-1")
# Before execution
assert executor.result is None
async with executor:
await executor.execute()
# After execution
assert executor.result is not None
assert executor.result.success is True
@pytest.mark.asyncio
async def test_executor_with_autonomy_level(self):
"""Test executor with custom autonomy level."""
mock_guardian = AsyncMock()
mock_guardian.validate.return_value = MagicMock(
decision=SafetyDecision.ALLOW,
reasons=[],
approval_id=None,
checkpoint_id=None,
)
mock_emergency = AsyncMock()
wrapper = MCPSafetyWrapper(
guardian=mock_guardian,
emergency_controls=mock_emergency,
)
wrapper.register_tool_handler("tool", lambda: None)
call = MCPToolCall(tool_name="tool", arguments={})
async with SafeToolExecutor(
wrapper, call, "agent-1", AutonomyLevel.AUTONOMOUS
) as executor:
await executor.execute()
# Check that guardian was called with correct autonomy level
mock_guardian.validate.assert_called_once()
action = mock_guardian.validate.call_args[0][0]
assert action.metadata.autonomy_level == AutonomyLevel.AUTONOMOUS
class TestCreateMCPWrapper:
"""Tests for create_mcp_wrapper factory function."""
@pytest.mark.asyncio
async def test_create_wrapper_with_guardian(self):
"""Test creating wrapper with provided guardian."""
mock_guardian = AsyncMock()
with patch(
"app.services.safety.mcp.integration.get_emergency_controls"
) as mock_get_emergency:
mock_get_emergency.return_value = AsyncMock()
wrapper = await create_mcp_wrapper(guardian=mock_guardian)
assert wrapper._guardian is mock_guardian
@pytest.mark.asyncio
async def test_create_wrapper_default_guardian(self):
"""Test creating wrapper with default guardian."""
with (
patch(
"app.services.safety.mcp.integration.get_safety_guardian"
) as mock_get_guardian,
patch(
"app.services.safety.mcp.integration.get_emergency_controls"
) as mock_get_emergency,
):
mock_guardian = AsyncMock()
mock_get_guardian.return_value = mock_guardian
mock_get_emergency.return_value = AsyncMock()
wrapper = await create_mcp_wrapper()
assert wrapper._guardian is mock_guardian
mock_get_guardian.assert_called_once()
class TestLazyGetters:
"""Tests for lazy getter methods."""
@pytest.mark.asyncio
async def test_get_guardian_lazy(self):
"""Test lazy guardian initialization."""
wrapper = MCPSafetyWrapper()
with patch(
"app.services.safety.mcp.integration.get_safety_guardian"
) as mock_get:
mock_guardian = AsyncMock()
mock_get.return_value = mock_guardian
guardian = await wrapper._get_guardian()
assert guardian is mock_guardian
mock_get.assert_called_once()
@pytest.mark.asyncio
async def test_get_guardian_cached(self):
"""Test guardian is cached after first access."""
mock_guardian = AsyncMock()
wrapper = MCPSafetyWrapper(guardian=mock_guardian)
guardian = await wrapper._get_guardian()
assert guardian is mock_guardian
@pytest.mark.asyncio
async def test_get_emergency_controls_lazy(self):
"""Test lazy emergency controls initialization."""
wrapper = MCPSafetyWrapper()
with patch(
"app.services.safety.mcp.integration.get_emergency_controls"
) as mock_get:
mock_emergency = AsyncMock()
mock_get.return_value = mock_emergency
emergency = await wrapper._get_emergency_controls()
assert emergency is mock_emergency
mock_get.assert_called_once()
@pytest.mark.asyncio
async def test_get_emergency_controls_cached(self):
"""Test emergency controls is cached after first access."""
mock_emergency = AsyncMock()
wrapper = MCPSafetyWrapper(emergency_controls=mock_emergency)
emergency = await wrapper._get_emergency_controls()
assert emergency is mock_emergency
class TestEdgeCases:
"""Tests for edge cases and error handling."""
@pytest.mark.asyncio
async def test_execute_with_safety_error(self):
"""Test handling SafetyError from guardian."""
from app.services.safety.exceptions import SafetyError
mock_guardian = AsyncMock()
mock_guardian.validate.side_effect = SafetyError("Internal safety error")
mock_emergency = AsyncMock()
wrapper = MCPSafetyWrapper(
guardian=mock_guardian,
emergency_controls=mock_emergency,
)
call = MCPToolCall(tool_name="test", arguments={})
result = await wrapper.execute(call, "agent-1")
assert result.success is False
assert "Internal safety error" in result.error
assert result.safety_decision == SafetyDecision.DENY
@pytest.mark.asyncio
async def test_execute_with_checkpoint_id(self):
"""Test that checkpoint_id is propagated to result."""
mock_guardian = AsyncMock()
mock_guardian.validate.return_value = MagicMock(
decision=SafetyDecision.ALLOW,
reasons=[],
approval_id=None,
checkpoint_id="checkpoint-abc",
)
mock_emergency = AsyncMock()
wrapper = MCPSafetyWrapper(
guardian=mock_guardian,
emergency_controls=mock_emergency,
)
wrapper.register_tool_handler("tool", lambda: "result")
call = MCPToolCall(tool_name="tool", arguments={})
result = await wrapper.execute(call, "agent-1")
assert result.success is True
assert result.checkpoint_id == "checkpoint-abc"
def test_destructive_tools_constant(self):
"""Test DESTRUCTIVE_TOOLS class constant."""
assert "file_write" in MCPSafetyWrapper.DESTRUCTIVE_TOOLS
assert "file_delete" in MCPSafetyWrapper.DESTRUCTIVE_TOOLS
assert "shell_execute" in MCPSafetyWrapper.DESTRUCTIVE_TOOLS
assert "git_push" in MCPSafetyWrapper.DESTRUCTIVE_TOOLS
def test_read_only_tools_constant(self):
"""Test READ_ONLY_TOOLS class constant."""
assert "file_read" in MCPSafetyWrapper.READ_ONLY_TOOLS
assert "database_query" in MCPSafetyWrapper.READ_ONLY_TOOLS
assert "git_status" in MCPSafetyWrapper.READ_ONLY_TOOLS
assert "search" in MCPSafetyWrapper.READ_ONLY_TOOLS
@pytest.mark.asyncio
async def test_scope_with_project_id(self):
"""Test that scope is set correctly with project_id."""
mock_guardian = AsyncMock()
mock_guardian.validate.return_value = MagicMock(
decision=SafetyDecision.ALLOW,
reasons=[],
approval_id=None,
checkpoint_id=None,
)
mock_emergency = AsyncMock()
wrapper = MCPSafetyWrapper(
guardian=mock_guardian,
emergency_controls=mock_emergency,
)
wrapper.register_tool_handler("tool", lambda: None)
call = MCPToolCall(
tool_name="tool",
arguments={},
project_id="proj-123",
)
await wrapper.execute(call, "agent-1")
# Verify emergency check was called with project scope
mock_emergency.check_allowed.assert_called_once()
call_args = mock_emergency.check_allowed.call_args
assert "project:proj-123" in str(call_args)
@pytest.mark.asyncio
async def test_scope_without_project_id(self):
"""Test that scope falls back to agent when no project_id."""
mock_guardian = AsyncMock()
mock_guardian.validate.return_value = MagicMock(
decision=SafetyDecision.ALLOW,
reasons=[],
approval_id=None,
checkpoint_id=None,
)
mock_emergency = AsyncMock()
wrapper = MCPSafetyWrapper(
guardian=mock_guardian,
emergency_controls=mock_emergency,
)
wrapper.register_tool_handler("tool", lambda: None)
call = MCPToolCall(
tool_name="tool",
arguments={},
# No project_id
)
await wrapper.execute(call, "agent-555")
# Verify emergency check was called with agent scope
mock_emergency.check_allowed.assert_called_once()
call_args = mock_emergency.check_allowed.call_args
assert "agent:agent-555" in str(call_args)


"""
Tests for Safety Metrics Collector.
Tests cover:
- MetricType, MetricValue, HistogramBucket data structures
- SafetyMetrics counters, gauges, histograms
- Prometheus format export
- Summary and reset operations
- Singleton pattern and convenience functions
"""
import pytest
import pytest_asyncio
from app.services.safety.metrics.collector import (
HistogramBucket,
MetricType,
MetricValue,
SafetyMetrics,
get_safety_metrics,
record_mcp_call,
record_validation,
)
class TestMetricType:
"""Tests for MetricType enum."""
def test_metric_types_exist(self):
"""Test all metric types are defined."""
assert MetricType.COUNTER == "counter"
assert MetricType.GAUGE == "gauge"
assert MetricType.HISTOGRAM == "histogram"
def test_metric_type_is_string(self):
"""Test MetricType values are strings."""
assert isinstance(MetricType.COUNTER.value, str)
assert isinstance(MetricType.GAUGE.value, str)
assert isinstance(MetricType.HISTOGRAM.value, str)
class TestMetricValue:
"""Tests for MetricValue dataclass."""
def test_metric_value_creation(self):
"""Test creating a metric value."""
mv = MetricValue(
name="test_metric",
metric_type=MetricType.COUNTER,
value=42.0,
labels={"env": "test"},
)
assert mv.name == "test_metric"
assert mv.metric_type == MetricType.COUNTER
assert mv.value == 42.0
assert mv.labels == {"env": "test"}
assert mv.timestamp is not None
def test_metric_value_defaults(self):
"""Test metric value default values."""
mv = MetricValue(
name="test",
metric_type=MetricType.GAUGE,
value=0.0,
)
assert mv.labels == {}
assert mv.timestamp is not None
class TestHistogramBucket:
"""Tests for HistogramBucket dataclass."""
def test_histogram_bucket_creation(self):
"""Test creating a histogram bucket."""
bucket = HistogramBucket(le=0.5, count=10)
assert bucket.le == 0.5
assert bucket.count == 10
def test_histogram_bucket_defaults(self):
"""Test histogram bucket default count."""
bucket = HistogramBucket(le=1.0)
assert bucket.le == 1.0
assert bucket.count == 0
def test_histogram_bucket_infinity(self):
"""Test histogram bucket with infinity."""
bucket = HistogramBucket(le=float("inf"))
assert bucket.le == float("inf")
class TestSafetyMetricsCounters:
"""Tests for SafetyMetrics counter methods."""
@pytest_asyncio.fixture
async def metrics(self):
"""Create fresh metrics instance."""
return SafetyMetrics()
@pytest.mark.asyncio
async def test_inc_validations(self, metrics):
"""Test incrementing validations counter."""
await metrics.inc_validations("allow")
await metrics.inc_validations("allow")
await metrics.inc_validations("deny", agent_id="agent-1")
summary = await metrics.get_summary()
assert summary["total_validations"] == 3
assert summary["denied_validations"] == 1
@pytest.mark.asyncio
async def test_inc_approvals_requested(self, metrics):
"""Test incrementing approval requests counter."""
await metrics.inc_approvals_requested("normal")
await metrics.inc_approvals_requested("urgent")
await metrics.inc_approvals_requested() # default
summary = await metrics.get_summary()
assert summary["approval_requests"] == 3
@pytest.mark.asyncio
async def test_inc_approvals_granted(self, metrics):
"""Test incrementing approvals granted counter."""
await metrics.inc_approvals_granted()
await metrics.inc_approvals_granted()
summary = await metrics.get_summary()
assert summary["approvals_granted"] == 2
@pytest.mark.asyncio
async def test_inc_approvals_denied(self, metrics):
"""Test incrementing approvals denied counter."""
await metrics.inc_approvals_denied("timeout")
await metrics.inc_approvals_denied("policy")
await metrics.inc_approvals_denied() # default manual
summary = await metrics.get_summary()
assert summary["approvals_denied"] == 3
@pytest.mark.asyncio
async def test_inc_rate_limit_exceeded(self, metrics):
"""Test incrementing rate limit exceeded counter."""
await metrics.inc_rate_limit_exceeded("requests_per_minute")
await metrics.inc_rate_limit_exceeded("tokens_per_hour")
summary = await metrics.get_summary()
assert summary["rate_limit_hits"] == 2
@pytest.mark.asyncio
async def test_inc_budget_exceeded(self, metrics):
"""Test incrementing budget exceeded counter."""
await metrics.inc_budget_exceeded("daily_cost")
await metrics.inc_budget_exceeded("monthly_tokens")
summary = await metrics.get_summary()
assert summary["budget_exceeded"] == 2
@pytest.mark.asyncio
async def test_inc_loops_detected(self, metrics):
"""Test incrementing loops detected counter."""
await metrics.inc_loops_detected("repetition")
await metrics.inc_loops_detected("pattern")
summary = await metrics.get_summary()
assert summary["loops_detected"] == 2
@pytest.mark.asyncio
async def test_inc_emergency_events(self, metrics):
"""Test incrementing emergency events counter."""
await metrics.inc_emergency_events("pause", "project-1")
await metrics.inc_emergency_events("stop", "agent-2")
summary = await metrics.get_summary()
assert summary["emergency_events"] == 2
@pytest.mark.asyncio
async def test_inc_content_filtered(self, metrics):
"""Test incrementing content filtered counter."""
await metrics.inc_content_filtered("profanity", "blocked")
await metrics.inc_content_filtered("pii", "redacted")
summary = await metrics.get_summary()
assert summary["content_filtered"] == 2
@pytest.mark.asyncio
async def test_inc_checkpoints_created(self, metrics):
"""Test incrementing checkpoints created counter."""
await metrics.inc_checkpoints_created()
await metrics.inc_checkpoints_created()
await metrics.inc_checkpoints_created()
summary = await metrics.get_summary()
assert summary["checkpoints_created"] == 3
@pytest.mark.asyncio
async def test_inc_rollbacks_executed(self, metrics):
"""Test incrementing rollbacks executed counter."""
await metrics.inc_rollbacks_executed(success=True)
await metrics.inc_rollbacks_executed(success=False)
summary = await metrics.get_summary()
assert summary["rollbacks_executed"] == 2
@pytest.mark.asyncio
async def test_inc_mcp_calls(self, metrics):
"""Test incrementing MCP calls counter."""
await metrics.inc_mcp_calls("search_knowledge", success=True)
await metrics.inc_mcp_calls("run_code", success=False)
summary = await metrics.get_summary()
assert summary["mcp_calls"] == 2
class TestSafetyMetricsGauges:
"""Tests for SafetyMetrics gauge methods."""
@pytest_asyncio.fixture
async def metrics(self):
"""Create fresh metrics instance."""
return SafetyMetrics()
@pytest.mark.asyncio
async def test_set_budget_remaining(self, metrics):
"""Test setting budget remaining gauge."""
await metrics.set_budget_remaining("project-1", "daily_cost", 50.0)
all_metrics = await metrics.get_all_metrics()
gauge_metrics = [m for m in all_metrics if m.name == "safety_budget_remaining"]
assert len(gauge_metrics) == 1
assert gauge_metrics[0].value == 50.0
assert gauge_metrics[0].labels["scope"] == "project-1"
assert gauge_metrics[0].labels["budget_type"] == "daily_cost"
@pytest.mark.asyncio
async def test_set_rate_limit_remaining(self, metrics):
"""Test setting rate limit remaining gauge."""
await metrics.set_rate_limit_remaining("agent-1", "requests_per_minute", 45)
all_metrics = await metrics.get_all_metrics()
gauge_metrics = [
m for m in all_metrics if m.name == "safety_rate_limit_remaining"
]
assert len(gauge_metrics) == 1
assert gauge_metrics[0].value == 45.0
@pytest.mark.asyncio
async def test_set_pending_approvals(self, metrics):
"""Test setting pending approvals gauge."""
await metrics.set_pending_approvals(5)
summary = await metrics.get_summary()
assert summary["pending_approvals"] == 5
@pytest.mark.asyncio
async def test_set_active_checkpoints(self, metrics):
"""Test setting active checkpoints gauge."""
await metrics.set_active_checkpoints(3)
summary = await metrics.get_summary()
assert summary["active_checkpoints"] == 3
@pytest.mark.asyncio
async def test_set_emergency_state(self, metrics):
"""Test setting emergency state gauge."""
await metrics.set_emergency_state("project-1", "normal")
await metrics.set_emergency_state("project-2", "paused")
await metrics.set_emergency_state("project-3", "stopped")
await metrics.set_emergency_state("project-4", "unknown")
all_metrics = await metrics.get_all_metrics()
state_metrics = [m for m in all_metrics if m.name == "safety_emergency_state"]
assert len(state_metrics) == 4
# Check state values
values_by_scope = {m.labels["scope"]: m.value for m in state_metrics}
assert values_by_scope["project-1"] == 0.0 # normal
assert values_by_scope["project-2"] == 1.0 # paused
assert values_by_scope["project-3"] == 2.0 # stopped
assert values_by_scope["project-4"] == -1.0 # unknown
class TestSafetyMetricsHistograms:
"""Tests for SafetyMetrics histogram methods."""
@pytest_asyncio.fixture
async def metrics(self):
"""Create fresh metrics instance."""
return SafetyMetrics()
@pytest.mark.asyncio
async def test_observe_validation_latency(self, metrics):
"""Test observing validation latency."""
await metrics.observe_validation_latency(0.05)
await metrics.observe_validation_latency(0.15)
await metrics.observe_validation_latency(0.5)
all_metrics = await metrics.get_all_metrics()
count_metric = next(
(m for m in all_metrics if m.name == "validation_latency_seconds_count"),
None,
)
assert count_metric is not None
assert count_metric.value == 3.0
sum_metric = next(
(m for m in all_metrics if m.name == "validation_latency_seconds_sum"),
None,
)
assert sum_metric is not None
assert abs(sum_metric.value - 0.7) < 0.001
@pytest.mark.asyncio
async def test_observe_approval_latency(self, metrics):
"""Test observing approval latency."""
await metrics.observe_approval_latency(1.5)
await metrics.observe_approval_latency(3.0)
all_metrics = await metrics.get_all_metrics()
count_metric = next(
(m for m in all_metrics if m.name == "approval_latency_seconds_count"),
None,
)
assert count_metric is not None
assert count_metric.value == 2.0
@pytest.mark.asyncio
async def test_observe_mcp_execution_latency(self, metrics):
"""Test observing MCP execution latency."""
await metrics.observe_mcp_execution_latency(0.02)
all_metrics = await metrics.get_all_metrics()
count_metric = next(
(m for m in all_metrics if m.name == "mcp_execution_latency_seconds_count"),
None,
)
assert count_metric is not None
assert count_metric.value == 1.0
@pytest.mark.asyncio
async def test_histogram_bucket_updates(self, metrics):
"""Test that histogram buckets are updated correctly."""
# Add values to test bucket distribution
await metrics.observe_validation_latency(0.005) # <= 0.01
await metrics.observe_validation_latency(0.03) # <= 0.05
await metrics.observe_validation_latency(0.07) # <= 0.1
await metrics.observe_validation_latency(15.0) # <= inf
prometheus = await metrics.get_prometheus_format()
# Check that bucket counts are in output
assert "validation_latency_seconds_bucket" in prometheus
assert "le=" in prometheus
class TestSafetyMetricsExport:
"""Tests for SafetyMetrics export methods."""
@pytest_asyncio.fixture
async def metrics(self):
"""Create fresh metrics instance with some data."""
m = SafetyMetrics()
# Add some counters
await m.inc_validations("allow")
await m.inc_validations("deny", agent_id="agent-1")
# Add some gauges
await m.set_pending_approvals(3)
await m.set_budget_remaining("proj-1", "daily", 100.0)
# Add some histogram values
await m.observe_validation_latency(0.1)
return m
@pytest.mark.asyncio
async def test_get_all_metrics(self, metrics):
"""Test getting all metrics."""
all_metrics = await metrics.get_all_metrics()
assert len(all_metrics) > 0
assert all(isinstance(m, MetricValue) for m in all_metrics)
# Check we have different types
types = {m.metric_type for m in all_metrics}
assert MetricType.COUNTER in types
assert MetricType.GAUGE in types
@pytest.mark.asyncio
async def test_get_prometheus_format(self, metrics):
"""Test Prometheus format export."""
output = await metrics.get_prometheus_format()
assert isinstance(output, str)
assert "# TYPE" in output
assert "counter" in output
assert "gauge" in output
assert "safety_validations_total" in output
assert "safety_pending_approvals" in output
@pytest.mark.asyncio
async def test_prometheus_format_with_labels(self, metrics):
"""Test Prometheus format includes labels correctly."""
output = await metrics.get_prometheus_format()
# Counter with labels
assert "decision=allow" in output or "decision=deny" in output
@pytest.mark.asyncio
async def test_prometheus_format_histogram_buckets(self, metrics):
"""Test Prometheus format includes histogram buckets."""
output = await metrics.get_prometheus_format()
assert "histogram" in output
assert "_bucket" in output
assert "le=" in output
assert "+Inf" in output
@pytest.mark.asyncio
async def test_get_summary(self, metrics):
"""Test getting summary."""
summary = await metrics.get_summary()
assert "total_validations" in summary
assert "denied_validations" in summary
assert "approval_requests" in summary
assert "pending_approvals" in summary
assert "active_checkpoints" in summary
assert summary["total_validations"] == 2
assert summary["denied_validations"] == 1
assert summary["pending_approvals"] == 3
@pytest.mark.asyncio
async def test_summary_empty_counters(self):
"""Test summary with no data."""
metrics = SafetyMetrics()
summary = await metrics.get_summary()
assert summary["total_validations"] == 0
assert summary["denied_validations"] == 0
assert summary["pending_approvals"] == 0
class TestSafetyMetricsReset:
"""Tests for SafetyMetrics reset."""
@pytest.mark.asyncio
async def test_reset_clears_counters(self):
"""Test reset clears all counters."""
metrics = SafetyMetrics()
await metrics.inc_validations("allow")
await metrics.inc_approvals_granted()
await metrics.set_pending_approvals(5)
await metrics.observe_validation_latency(0.1)
await metrics.reset()
summary = await metrics.get_summary()
assert summary["total_validations"] == 0
assert summary["approvals_granted"] == 0
assert summary["pending_approvals"] == 0
@pytest.mark.asyncio
async def test_reset_reinitializes_histogram_buckets(self):
"""Test reset reinitializes histogram buckets."""
metrics = SafetyMetrics()
await metrics.observe_validation_latency(0.1)
await metrics.reset()
# After reset, histogram buckets should be reinitialized
prometheus = await metrics.get_prometheus_format()
assert "validation_latency_seconds" in prometheus
class TestParseLabels:
"""Tests for _parse_labels helper method."""
def test_parse_empty_labels(self):
"""Test parsing empty labels string."""
metrics = SafetyMetrics()
result = metrics._parse_labels("")
assert result == {}
def test_parse_single_label(self):
"""Test parsing single label."""
metrics = SafetyMetrics()
result = metrics._parse_labels("key=value")
assert result == {"key": "value"}
def test_parse_multiple_labels(self):
"""Test parsing multiple labels."""
metrics = SafetyMetrics()
result = metrics._parse_labels("a=1,b=2,c=3")
assert result == {"a": "1", "b": "2", "c": "3"}
def test_parse_labels_with_spaces(self):
"""Test parsing labels with spaces."""
metrics = SafetyMetrics()
result = metrics._parse_labels(" key = value , foo = bar ")
assert result == {"key": "value", "foo": "bar"}
def test_parse_labels_with_equals_in_value(self):
"""Test parsing labels with = in value."""
metrics = SafetyMetrics()
result = metrics._parse_labels("query=a=b")
assert result == {"query": "a=b"}
def test_parse_invalid_label(self):
"""Test parsing invalid label without equals."""
metrics = SafetyMetrics()
result = metrics._parse_labels("no_equals")
assert result == {}
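
The `_parse_labels` tests fully specify the parser's contract: split on commas, split each entry on the first `=` only, strip whitespace, and drop entries without `=`. A parser along these lines satisfies every case above (a sketch of the contract, not the module's actual implementation):

```python
def parse_labels(s: str) -> dict:
    """Parse 'a=1,b=2' style label strings into a dict."""
    labels = {}
    for part in s.split(","):
        if "=" not in part:          # entries without '=' are ignored
            continue
        key, _, value = part.partition("=")  # split on the FIRST '=' only
        if key.strip():
            labels[key.strip()] = value.strip()
    return labels

print(parse_labels("query=a=b"))  # {'query': 'a=b'}
```

`str.partition` rather than `str.split("=")` is what preserves `=` inside values, as `test_parse_labels_with_equals_in_value` requires.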
class TestHistogramBucketInit:
"""Tests for histogram bucket initialization."""
def test_histogram_buckets_initialized(self):
"""Test that histogram buckets are initialized."""
metrics = SafetyMetrics()
assert "validation_latency_seconds" in metrics._histogram_buckets
assert "approval_latency_seconds" in metrics._histogram_buckets
assert "mcp_execution_latency_seconds" in metrics._histogram_buckets
def test_histogram_buckets_have_correct_values(self):
"""Test histogram buckets have correct boundary values."""
metrics = SafetyMetrics()
buckets = metrics._histogram_buckets["validation_latency_seconds"]
# Check first few and last bucket
assert buckets[0].le == 0.01
assert buckets[1].le == 0.05
assert buckets[-1].le == float("inf")
# Check all have zero initial count
assert all(b.count == 0 for b in buckets)
class TestSingletonAndConvenience:
"""Tests for singleton pattern and convenience functions."""
@pytest.mark.asyncio
async def test_get_safety_metrics_returns_same_instance(self):
"""Test get_safety_metrics returns singleton."""
# Reset the module-level singleton for this test
import app.services.safety.metrics.collector as collector_module
collector_module._metrics = None
m1 = await get_safety_metrics()
m2 = await get_safety_metrics()
assert m1 is m2
@pytest.mark.asyncio
async def test_record_validation_convenience(self):
"""Test record_validation convenience function."""
import app.services.safety.metrics.collector as collector_module
collector_module._metrics = None # Reset
await record_validation("allow")
await record_validation("deny", agent_id="test-agent")
metrics = await get_safety_metrics()
summary = await metrics.get_summary()
assert summary["total_validations"] == 2
assert summary["denied_validations"] == 1
@pytest.mark.asyncio
async def test_record_mcp_call_convenience(self):
"""Test record_mcp_call convenience function."""
import app.services.safety.metrics.collector as collector_module
collector_module._metrics = None # Reset
await record_mcp_call("search_knowledge", success=True, latency_ms=50)
await record_mcp_call("run_code", success=False, latency_ms=100)
metrics = await get_safety_metrics()
summary = await metrics.get_summary()
assert summary["mcp_calls"] == 2
class TestConcurrency:
"""Tests for concurrent metric updates."""
@pytest.mark.asyncio
async def test_concurrent_counter_increments(self):
"""Test concurrent counter increments are safe."""
import asyncio
metrics = SafetyMetrics()
async def increment_many():
for _ in range(100):
await metrics.inc_validations("allow")
# Run 10 concurrent tasks each incrementing 100 times
await asyncio.gather(*[increment_many() for _ in range(10)])
summary = await metrics.get_summary()
assert summary["total_validations"] == 1000
@pytest.mark.asyncio
async def test_concurrent_gauge_updates(self):
"""Test concurrent gauge updates are safe."""
import asyncio
metrics = SafetyMetrics()
async def update_gauge(value):
await metrics.set_pending_approvals(value)
# Run concurrent gauge updates
await asyncio.gather(*[update_gauge(i) for i in range(100)])
# Final value should be one of the updates (last one wins)
summary = await metrics.get_summary()
assert 0 <= summary["pending_approvals"] < 100
@pytest.mark.asyncio
async def test_concurrent_histogram_observations(self):
"""Test concurrent histogram observations are safe."""
import asyncio
metrics = SafetyMetrics()
async def observe_many():
for i in range(100):
await metrics.observe_validation_latency(i / 1000)
await asyncio.gather(*[observe_many() for _ in range(10)])
all_metrics = await metrics.get_all_metrics()
count_metric = next(
(m for m in all_metrics if m.name == "validation_latency_seconds_count"),
None,
)
assert count_metric is not None
assert count_metric.value == 1000.0
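The pattern these concurrency tests exercise can be sketched in isolation: an `asyncio.Lock` around the increment makes the read-modify-write atomic across tasks. A minimal, self-contained sketch (the `Counter` class here is hypothetical, not part of the collector module):

```python
import asyncio


class Counter:
    """Minimal lock-protected counter, mirroring the metrics pattern."""

    def __init__(self) -> None:
        self.value = 0
        self._lock = asyncio.Lock()

    async def inc(self) -> None:
        # The lock serializes the read-modify-write, so concurrent
        # tasks cannot interleave and lose increments.
        async with self._lock:
            self.value += 1


async def main() -> int:
    counter = Counter()

    async def inc_many() -> None:
        for _ in range(100):
            await counter.inc()

    # 10 tasks x 100 increments each, matching the test above.
    await asyncio.gather(*[inc_many() for _ in range(10)])
    return counter.value


print(asyncio.run(main()))  # 1000
```

The same structure covers the gauge and histogram cases: gauges take the last write to win, histograms take the lock once per observation.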
class TestEdgeCases:
"""Tests for edge cases."""
@pytest.mark.asyncio
async def test_very_large_counter_value(self):
"""Test handling very large counter values."""
metrics = SafetyMetrics()
for _ in range(10000):
await metrics.inc_validations("allow")
summary = await metrics.get_summary()
assert summary["total_validations"] == 10000
@pytest.mark.asyncio
async def test_zero_and_negative_gauge_values(self):
"""Test zero and negative gauge values."""
metrics = SafetyMetrics()
await metrics.set_budget_remaining("project", "cost", 0.0)
await metrics.set_budget_remaining("project2", "cost", -10.0)
all_metrics = await metrics.get_all_metrics()
gauges = [m for m in all_metrics if m.name == "safety_budget_remaining"]
values = {m.labels.get("scope"): m.value for m in gauges}
assert values["project"] == 0.0
assert values["project2"] == -10.0
@pytest.mark.asyncio
async def test_very_small_histogram_values(self):
"""Test very small histogram values."""
metrics = SafetyMetrics()
await metrics.observe_validation_latency(0.0001) # 0.1ms
all_metrics = await metrics.get_all_metrics()
sum_metric = next(
(m for m in all_metrics if m.name == "validation_latency_seconds_sum"),
None,
)
assert sum_metric is not None
assert abs(sum_metric.value - 0.0001) < 0.00001
@pytest.mark.asyncio
async def test_special_characters_in_labels(self):
"""Test special characters in label values."""
metrics = SafetyMetrics()
await metrics.inc_validations("allow", agent_id="agent/with/slashes")
all_metrics = await metrics.get_all_metrics()
counters = [m for m in all_metrics if m.name == "safety_validations_total"]
# Should have the metric with special chars
assert len(counters) > 0
@pytest.mark.asyncio
async def test_empty_histogram_export(self):
"""Test exporting histogram with no observations."""
metrics = SafetyMetrics()
# No observations, but histogram buckets should still exist
prometheus = await metrics.get_prometheus_format()
assert "validation_latency_seconds" in prometheus
assert "le=" in prometheus
@pytest.mark.asyncio
async def test_prometheus_format_empty_label_value(self):
"""Test Prometheus format with empty label metrics."""
metrics = SafetyMetrics()
await metrics.inc_approvals_granted() # Uses empty string as label
prometheus = await metrics.get_prometheus_format()
assert "safety_approvals_granted_total" in prometheus
@pytest.mark.asyncio
async def test_multiple_resets(self):
"""Test multiple resets don't cause issues."""
metrics = SafetyMetrics()
await metrics.inc_validations("allow")
await metrics.reset()
await metrics.reset()
await metrics.reset()
summary = await metrics.get_summary()
assert summary["total_validations"] == 0


@@ -0,0 +1,933 @@
"""Tests for Permission Manager.
Tests cover:
- PermissionGrant: creation, expiry, matching, hierarchy
- PermissionManager: grant, revoke, check, require, list, defaults
- Edge cases: wildcards, expiration, default deny/allow
"""
from datetime import datetime, timedelta
import pytest
import pytest_asyncio
from app.services.safety.exceptions import PermissionDeniedError
from app.services.safety.models import (
ActionMetadata,
ActionRequest,
ActionType,
PermissionLevel,
ResourceType,
)
from app.services.safety.permissions.manager import PermissionGrant, PermissionManager
# ============================================================================
# Fixtures
# ============================================================================
@pytest.fixture
def action_metadata() -> ActionMetadata:
"""Create standard action metadata for tests."""
return ActionMetadata(
agent_id="test-agent",
project_id="test-project",
session_id="test-session",
)
@pytest_asyncio.fixture
async def permission_manager() -> PermissionManager:
"""Create a PermissionManager for testing."""
return PermissionManager(default_deny=True)
@pytest_asyncio.fixture
async def permissive_manager() -> PermissionManager:
"""Create a PermissionManager with default_deny=False."""
return PermissionManager(default_deny=False)
# ============================================================================
# PermissionGrant Tests
# ============================================================================
class TestPermissionGrant:
"""Tests for the PermissionGrant class."""
def test_grant_creation(self) -> None:
"""Test basic grant creation."""
grant = PermissionGrant(
agent_id="agent-1",
resource_pattern="/data/*",
resource_type=ResourceType.FILE,
level=PermissionLevel.READ,
granted_by="admin",
reason="Read access to data directory",
)
assert grant.id is not None
assert grant.agent_id == "agent-1"
assert grant.resource_pattern == "/data/*"
assert grant.resource_type == ResourceType.FILE
assert grant.level == PermissionLevel.READ
assert grant.granted_by == "admin"
assert grant.reason == "Read access to data directory"
assert grant.expires_at is None
assert grant.created_at is not None
def test_grant_with_expiration(self) -> None:
"""Test grant with expiration time."""
future = datetime.utcnow() + timedelta(hours=1)
grant = PermissionGrant(
agent_id="agent-1",
resource_pattern="*",
resource_type=ResourceType.API,
level=PermissionLevel.EXECUTE,
expires_at=future,
)
assert grant.expires_at == future
assert grant.is_expired() is False
def test_is_expired_no_expiration(self) -> None:
"""Test is_expired with no expiration set."""
grant = PermissionGrant(
agent_id="agent-1",
resource_pattern="*",
resource_type=ResourceType.FILE,
level=PermissionLevel.READ,
)
assert grant.is_expired() is False
def test_is_expired_future(self) -> None:
"""Test is_expired with future expiration."""
grant = PermissionGrant(
agent_id="agent-1",
resource_pattern="*",
resource_type=ResourceType.FILE,
level=PermissionLevel.READ,
expires_at=datetime.utcnow() + timedelta(hours=1),
)
assert grant.is_expired() is False
def test_is_expired_past(self) -> None:
"""Test is_expired with past expiration."""
grant = PermissionGrant(
agent_id="agent-1",
resource_pattern="*",
resource_type=ResourceType.FILE,
level=PermissionLevel.READ,
expires_at=datetime.utcnow() - timedelta(hours=1),
)
assert grant.is_expired() is True
def test_matches_exact(self) -> None:
"""Test matching with exact pattern."""
grant = PermissionGrant(
agent_id="agent-1",
resource_pattern="/data/file.txt",
resource_type=ResourceType.FILE,
level=PermissionLevel.READ,
)
assert grant.matches("/data/file.txt", ResourceType.FILE) is True
assert grant.matches("/data/other.txt", ResourceType.FILE) is False
def test_matches_wildcard(self) -> None:
"""Test matching with wildcard pattern."""
grant = PermissionGrant(
agent_id="agent-1",
resource_pattern="/data/*",
resource_type=ResourceType.FILE,
level=PermissionLevel.READ,
)
assert grant.matches("/data/file.txt", ResourceType.FILE) is True
# fnmatch's * matches everything including /
assert grant.matches("/data/subdir/file.txt", ResourceType.FILE) is True
assert grant.matches("/other/file.txt", ResourceType.FILE) is False
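The comment above is worth verifying in isolation: unlike shell globbing, `fnmatch` is not path-aware, so `*` matches across `/` boundaries. A quick standalone check of that behavior:

```python
from fnmatch import fnmatch

# "*" in fnmatch matches any run of characters, "/" included,
# which is why "/data/*" also matches nested paths.
print(fnmatch("/data/file.txt", "/data/*"))         # True
print(fnmatch("/data/subdir/file.txt", "/data/*"))  # True
print(fnmatch("/other/file.txt", "/data/*"))        # False
```

This is also why the `**` test below behaves identically to `*`: fnmatch has no special recursive-glob syntax.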
def test_matches_recursive_wildcard(self) -> None:
"""Test matching with recursive pattern."""
grant = PermissionGrant(
agent_id="agent-1",
resource_pattern="/data/**",
resource_type=ResourceType.FILE,
level=PermissionLevel.READ,
)
# fnmatch treats ** similar to * - both match everything including /
assert grant.matches("/data/file.txt", ResourceType.FILE) is True
assert grant.matches("/data/subdir/file.txt", ResourceType.FILE) is True
def test_matches_wrong_resource_type(self) -> None:
"""Test matching fails with wrong resource type."""
grant = PermissionGrant(
agent_id="agent-1",
resource_pattern="/data/*",
resource_type=ResourceType.FILE,
level=PermissionLevel.READ,
)
# Same pattern but different resource type
assert grant.matches("/data/table", ResourceType.DATABASE) is False
def test_allows_hierarchy(self) -> None:
"""Test permission level hierarchy."""
admin_grant = PermissionGrant(
agent_id="agent-1",
resource_pattern="*",
resource_type=ResourceType.FILE,
level=PermissionLevel.ADMIN,
)
# ADMIN allows all levels
assert admin_grant.allows(PermissionLevel.NONE) is True
assert admin_grant.allows(PermissionLevel.READ) is True
assert admin_grant.allows(PermissionLevel.WRITE) is True
assert admin_grant.allows(PermissionLevel.EXECUTE) is True
assert admin_grant.allows(PermissionLevel.DELETE) is True
assert admin_grant.allows(PermissionLevel.ADMIN) is True
def test_allows_read_only(self) -> None:
"""Test READ grant only allows READ and NONE."""
read_grant = PermissionGrant(
agent_id="agent-1",
resource_pattern="*",
resource_type=ResourceType.FILE,
level=PermissionLevel.READ,
)
assert read_grant.allows(PermissionLevel.NONE) is True
assert read_grant.allows(PermissionLevel.READ) is True
assert read_grant.allows(PermissionLevel.WRITE) is False
assert read_grant.allows(PermissionLevel.EXECUTE) is False
assert read_grant.allows(PermissionLevel.DELETE) is False
assert read_grant.allows(PermissionLevel.ADMIN) is False
def test_allows_write_includes_read(self) -> None:
"""Test WRITE grant includes READ."""
write_grant = PermissionGrant(
agent_id="agent-1",
resource_pattern="*",
resource_type=ResourceType.FILE,
level=PermissionLevel.WRITE,
)
assert write_grant.allows(PermissionLevel.READ) is True
assert write_grant.allows(PermissionLevel.WRITE) is True
assert write_grant.allows(PermissionLevel.EXECUTE) is False
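The hierarchy asserted in these tests can be implemented with an ordered enum, where `allows` reduces to an integer comparison. A plausible sketch (the real `PermissionLevel` may be defined differently; this ordering is inferred from the assertions above):

```python
from enum import IntEnum


class Level(IntEnum):
    # Hypothetical ordering mirroring the tests: each level
    # implies every level below it.
    NONE = 0
    READ = 1
    WRITE = 2
    EXECUTE = 3
    DELETE = 4
    ADMIN = 5


def allows(granted: "Level", required: "Level") -> bool:
    # A grant allows any required level at or below its own.
    return granted >= required


print(allows(Level.WRITE, Level.READ))    # True
print(allows(Level.READ, Level.WRITE))    # False
print(allows(Level.ADMIN, Level.DELETE))  # True
```

Under this ordering, `ADMIN` allows everything and `READ` allows only `READ` and `NONE`, matching `test_allows_hierarchy` and `test_allows_read_only`.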
# ============================================================================
# PermissionManager Tests
# ============================================================================
class TestPermissionManager:
"""Tests for the PermissionManager class."""
@pytest.mark.asyncio
async def test_grant_creates_permission(
self,
permission_manager: PermissionManager,
) -> None:
"""Test granting a permission."""
grant = await permission_manager.grant(
agent_id="agent-1",
resource_pattern="/data/*",
resource_type=ResourceType.FILE,
level=PermissionLevel.READ,
granted_by="admin",
reason="Read access",
)
assert grant.id is not None
assert grant.agent_id == "agent-1"
assert grant.resource_pattern == "/data/*"
@pytest.mark.asyncio
async def test_grant_with_duration(
self,
permission_manager: PermissionManager,
) -> None:
"""Test granting a temporary permission."""
grant = await permission_manager.grant(
agent_id="agent-1",
resource_pattern="*",
resource_type=ResourceType.API,
level=PermissionLevel.EXECUTE,
duration_seconds=3600, # 1 hour
)
assert grant.expires_at is not None
assert grant.is_expired() is False
@pytest.mark.asyncio
async def test_revoke_by_id(
self,
permission_manager: PermissionManager,
) -> None:
"""Test revoking a grant by ID."""
grant = await permission_manager.grant(
agent_id="agent-1",
resource_pattern="*",
resource_type=ResourceType.FILE,
level=PermissionLevel.READ,
)
success = await permission_manager.revoke(grant.id)
assert success is True
# Verify grant is removed
grants = await permission_manager.list_grants(agent_id="agent-1")
assert len(grants) == 0
@pytest.mark.asyncio
async def test_revoke_nonexistent(
self,
permission_manager: PermissionManager,
) -> None:
"""Test revoking a non-existent grant."""
success = await permission_manager.revoke("nonexistent-id")
assert success is False
@pytest.mark.asyncio
async def test_revoke_all_for_agent(
self,
permission_manager: PermissionManager,
) -> None:
"""Test revoking all permissions for an agent."""
# Grant multiple permissions
await permission_manager.grant(
agent_id="agent-1",
resource_pattern="/data/*",
resource_type=ResourceType.FILE,
level=PermissionLevel.READ,
)
await permission_manager.grant(
agent_id="agent-1",
resource_pattern="/api/*",
resource_type=ResourceType.API,
level=PermissionLevel.EXECUTE,
)
await permission_manager.grant(
agent_id="agent-2",
resource_pattern="*",
resource_type=ResourceType.FILE,
level=PermissionLevel.READ,
)
revoked = await permission_manager.revoke_all("agent-1")
assert revoked == 2
# Verify agent-1 grants are gone
grants = await permission_manager.list_grants(agent_id="agent-1")
assert len(grants) == 0
# Verify agent-2 grant remains
grants = await permission_manager.list_grants(agent_id="agent-2")
assert len(grants) == 1
@pytest.mark.asyncio
async def test_revoke_all_no_grants(
self,
permission_manager: PermissionManager,
) -> None:
"""Test revoking all when no grants exist."""
revoked = await permission_manager.revoke_all("nonexistent-agent")
assert revoked == 0
@pytest.mark.asyncio
async def test_check_granted(
self,
permission_manager: PermissionManager,
) -> None:
"""Test checking a granted permission."""
await permission_manager.grant(
agent_id="agent-1",
resource_pattern="/data/*",
resource_type=ResourceType.FILE,
level=PermissionLevel.READ,
)
allowed = await permission_manager.check(
agent_id="agent-1",
resource="/data/file.txt",
resource_type=ResourceType.FILE,
required_level=PermissionLevel.READ,
)
assert allowed is True
@pytest.mark.asyncio
async def test_check_denied_default_deny(
self,
permission_manager: PermissionManager,
) -> None:
"""Test checking denied with default_deny=True."""
# No grants, should be denied
allowed = await permission_manager.check(
agent_id="agent-1",
resource="/data/file.txt",
resource_type=ResourceType.FILE,
required_level=PermissionLevel.READ,
)
assert allowed is False
@pytest.mark.asyncio
async def test_check_uses_default_permissions(
self,
permissive_manager: PermissionManager,
) -> None:
"""Test that default permissions apply when default_deny=False."""
# No explicit grants, but FILE default is READ
allowed = await permissive_manager.check(
agent_id="agent-1",
resource="/data/file.txt",
resource_type=ResourceType.FILE,
required_level=PermissionLevel.READ,
)
assert allowed is True
# But WRITE should fail
allowed = await permissive_manager.check(
agent_id="agent-1",
resource="/data/file.txt",
resource_type=ResourceType.FILE,
required_level=PermissionLevel.WRITE,
)
assert allowed is False
@pytest.mark.asyncio
async def test_check_shell_denied_by_default(
self,
permissive_manager: PermissionManager,
) -> None:
"""Test SHELL is denied by default (NONE level)."""
allowed = await permissive_manager.check(
agent_id="agent-1",
resource="rm -rf /",
resource_type=ResourceType.SHELL,
required_level=PermissionLevel.EXECUTE,
)
assert allowed is False
@pytest.mark.asyncio
async def test_check_expired_grant_ignored(
self,
permission_manager: PermissionManager,
) -> None:
"""Test that expired grants are ignored in checks."""
# Create an already-expired grant
grant = await permission_manager.grant(
agent_id="agent-1",
resource_pattern="/data/*",
resource_type=ResourceType.FILE,
level=PermissionLevel.READ,
duration_seconds=1, # Very short
)
# Manually expire it
grant.expires_at = datetime.utcnow() - timedelta(seconds=10)
allowed = await permission_manager.check(
agent_id="agent-1",
resource="/data/file.txt",
resource_type=ResourceType.FILE,
required_level=PermissionLevel.READ,
)
assert allowed is False
@pytest.mark.asyncio
async def test_check_insufficient_level(
self,
permission_manager: PermissionManager,
) -> None:
"""Test check fails when grant level is insufficient."""
await permission_manager.grant(
agent_id="agent-1",
resource_pattern="/data/*",
resource_type=ResourceType.FILE,
level=PermissionLevel.READ,
)
# Try to get WRITE access with only READ grant
allowed = await permission_manager.check(
agent_id="agent-1",
resource="/data/file.txt",
resource_type=ResourceType.FILE,
required_level=PermissionLevel.WRITE,
)
assert allowed is False
@pytest.mark.asyncio
async def test_check_action_file_read(
self,
permission_manager: PermissionManager,
action_metadata: ActionMetadata,
) -> None:
"""Test check_action for file read."""
await permission_manager.grant(
agent_id="test-agent",
resource_pattern="/data/*",
resource_type=ResourceType.FILE,
level=PermissionLevel.READ,
)
action = ActionRequest(
action_type=ActionType.FILE_READ,
resource="/data/file.txt",
metadata=action_metadata,
)
allowed = await permission_manager.check_action(action)
assert allowed is True
@pytest.mark.asyncio
async def test_check_action_file_write(
self,
permission_manager: PermissionManager,
action_metadata: ActionMetadata,
) -> None:
"""Test check_action for file write."""
await permission_manager.grant(
agent_id="test-agent",
resource_pattern="/data/*",
resource_type=ResourceType.FILE,
level=PermissionLevel.WRITE,
)
action = ActionRequest(
action_type=ActionType.FILE_WRITE,
resource="/data/file.txt",
metadata=action_metadata,
)
allowed = await permission_manager.check_action(action)
assert allowed is True
@pytest.mark.asyncio
async def test_check_action_uses_tool_name_as_resource(
self,
permission_manager: PermissionManager,
action_metadata: ActionMetadata,
) -> None:
"""Test check_action uses tool_name when resource is None."""
await permission_manager.grant(
agent_id="test-agent",
resource_pattern="search_*",
resource_type=ResourceType.CUSTOM,
level=PermissionLevel.EXECUTE,
)
action = ActionRequest(
action_type=ActionType.TOOL_CALL,
tool_name="search_documents",
resource=None,
metadata=action_metadata,
)
allowed = await permission_manager.check_action(action)
assert allowed is True
@pytest.mark.asyncio
async def test_require_permission_granted(
self,
permission_manager: PermissionManager,
) -> None:
"""Test require_permission doesn't raise when granted."""
await permission_manager.grant(
agent_id="agent-1",
resource_pattern="/data/*",
resource_type=ResourceType.FILE,
level=PermissionLevel.READ,
)
# Should not raise
await permission_manager.require_permission(
agent_id="agent-1",
resource="/data/file.txt",
resource_type=ResourceType.FILE,
required_level=PermissionLevel.READ,
)
@pytest.mark.asyncio
async def test_require_permission_denied(
self,
permission_manager: PermissionManager,
) -> None:
"""Test require_permission raises when denied."""
with pytest.raises(PermissionDeniedError) as exc_info:
await permission_manager.require_permission(
agent_id="agent-1",
resource="/secret/file.txt",
resource_type=ResourceType.FILE,
required_level=PermissionLevel.READ,
)
assert "/secret/file.txt" in str(exc_info.value)
assert exc_info.value.agent_id == "agent-1"
assert exc_info.value.required_permission == "read"
@pytest.mark.asyncio
async def test_list_grants_all(
self,
permission_manager: PermissionManager,
) -> None:
"""Test listing all grants."""
await permission_manager.grant(
agent_id="agent-1",
resource_pattern="/data/*",
resource_type=ResourceType.FILE,
level=PermissionLevel.READ,
)
await permission_manager.grant(
agent_id="agent-2",
resource_pattern="/api/*",
resource_type=ResourceType.API,
level=PermissionLevel.EXECUTE,
)
grants = await permission_manager.list_grants()
assert len(grants) == 2
@pytest.mark.asyncio
async def test_list_grants_by_agent(
self,
permission_manager: PermissionManager,
) -> None:
"""Test listing grants filtered by agent."""
await permission_manager.grant(
agent_id="agent-1",
resource_pattern="/data/*",
resource_type=ResourceType.FILE,
level=PermissionLevel.READ,
)
await permission_manager.grant(
agent_id="agent-2",
resource_pattern="/api/*",
resource_type=ResourceType.API,
level=PermissionLevel.EXECUTE,
)
grants = await permission_manager.list_grants(agent_id="agent-1")
assert len(grants) == 1
assert grants[0].agent_id == "agent-1"
@pytest.mark.asyncio
async def test_list_grants_by_resource_type(
self,
permission_manager: PermissionManager,
) -> None:
"""Test listing grants filtered by resource type."""
await permission_manager.grant(
agent_id="agent-1",
resource_pattern="/data/*",
resource_type=ResourceType.FILE,
level=PermissionLevel.READ,
)
await permission_manager.grant(
agent_id="agent-1",
resource_pattern="/api/*",
resource_type=ResourceType.API,
level=PermissionLevel.EXECUTE,
)
grants = await permission_manager.list_grants(resource_type=ResourceType.FILE)
assert len(grants) == 1
assert grants[0].resource_type == ResourceType.FILE
@pytest.mark.asyncio
async def test_list_grants_excludes_expired(
self,
permission_manager: PermissionManager,
) -> None:
"""Test that list_grants excludes expired grants."""
# Create expired grant
grant = await permission_manager.grant(
agent_id="agent-1",
resource_pattern="/old/*",
resource_type=ResourceType.FILE,
level=PermissionLevel.READ,
duration_seconds=1,
)
grant.expires_at = datetime.utcnow() - timedelta(seconds=10)
# Create valid grant
await permission_manager.grant(
agent_id="agent-1",
resource_pattern="/new/*",
resource_type=ResourceType.FILE,
level=PermissionLevel.READ,
)
grants = await permission_manager.list_grants()
assert len(grants) == 1
assert grants[0].resource_pattern == "/new/*"
def test_set_default_permission(
self,
) -> None:
"""Test setting default permission level."""
manager = PermissionManager(default_deny=False)
# Default for SHELL is NONE
assert manager._default_permissions[ResourceType.SHELL] == PermissionLevel.NONE
# Change it
manager.set_default_permission(ResourceType.SHELL, PermissionLevel.EXECUTE)
assert (
manager._default_permissions[ResourceType.SHELL] == PermissionLevel.EXECUTE
)
@pytest.mark.asyncio
async def test_set_default_permission_affects_checks(
self,
permissive_manager: PermissionManager,
) -> None:
"""Test that changing default permissions affects checks."""
# Initially SHELL is NONE
allowed = await permissive_manager.check(
agent_id="agent-1",
resource="ls",
resource_type=ResourceType.SHELL,
required_level=PermissionLevel.EXECUTE,
)
assert allowed is False
# Change default
permissive_manager.set_default_permission(
ResourceType.SHELL, PermissionLevel.EXECUTE
)
# Now should be allowed
allowed = await permissive_manager.check(
agent_id="agent-1",
resource="ls",
resource_type=ResourceType.SHELL,
required_level=PermissionLevel.EXECUTE,
)
assert allowed is True
# ============================================================================
# Edge Cases
# ============================================================================
class TestPermissionEdgeCases:
"""Edge cases that could reveal hidden bugs."""
@pytest.mark.asyncio
async def test_multiple_matching_grants(
self,
permission_manager: PermissionManager,
) -> None:
"""Test when multiple grants match - first sufficient one wins."""
# Grant READ on all files
await permission_manager.grant(
agent_id="agent-1",
resource_pattern="/data/*",
resource_type=ResourceType.FILE,
level=PermissionLevel.READ,
)
# Also grant WRITE on specific path
await permission_manager.grant(
agent_id="agent-1",
resource_pattern="/data/writable/*",
resource_type=ResourceType.FILE,
level=PermissionLevel.WRITE,
)
# Write on writable path should work
allowed = await permission_manager.check(
agent_id="agent-1",
resource="/data/writable/file.txt",
resource_type=ResourceType.FILE,
required_level=PermissionLevel.WRITE,
)
assert allowed is True
@pytest.mark.asyncio
async def test_wildcard_all_pattern(
self,
permission_manager: PermissionManager,
) -> None:
"""Test * pattern matches everything."""
await permission_manager.grant(
agent_id="agent-1",
resource_pattern="*",
resource_type=ResourceType.FILE,
level=PermissionLevel.ADMIN,
)
allowed = await permission_manager.check(
agent_id="agent-1",
resource="/any/path/anywhere/file.txt",
resource_type=ResourceType.FILE,
required_level=PermissionLevel.DELETE,
)
# fnmatch's * matches everything including /
assert allowed is True
@pytest.mark.asyncio
async def test_question_mark_wildcard(
self,
permission_manager: PermissionManager,
) -> None:
"""Test ? wildcard matches single character."""
await permission_manager.grant(
agent_id="agent-1",
resource_pattern="file?.txt",
resource_type=ResourceType.FILE,
level=PermissionLevel.READ,
)
assert (
await permission_manager.check(
agent_id="agent-1",
resource="file1.txt",
resource_type=ResourceType.FILE,
required_level=PermissionLevel.READ,
)
is True
)
assert (
await permission_manager.check(
agent_id="agent-1",
resource="file10.txt", # Two characters, won't match
resource_type=ResourceType.FILE,
required_level=PermissionLevel.READ,
)
is False
)
@pytest.mark.asyncio
async def test_concurrent_grant_revoke(
self,
permission_manager: PermissionManager,
) -> None:
"""Test concurrent grant and revoke operations."""
async def grant_many():
grants = []
for i in range(10):
g = await permission_manager.grant(
agent_id="agent-1",
resource_pattern=f"/path{i}/*",
resource_type=ResourceType.FILE,
level=PermissionLevel.READ,
)
grants.append(g)
return grants
async def revoke_many(grants):
for g in grants:
await permission_manager.revoke(g.id)
grants = await grant_many()
await revoke_many(grants)
# All should be revoked
remaining = await permission_manager.list_grants(agent_id="agent-1")
assert len(remaining) == 0
@pytest.mark.asyncio
async def test_check_action_with_no_resource_or_tool(
self,
permission_manager: PermissionManager,
action_metadata: ActionMetadata,
) -> None:
"""Test check_action when both resource and tool_name are None."""
await permission_manager.grant(
agent_id="test-agent",
resource_pattern="*",
resource_type=ResourceType.LLM,
level=PermissionLevel.EXECUTE,
)
action = ActionRequest(
action_type=ActionType.LLM_CALL,
resource=None,
tool_name=None,
metadata=action_metadata,
)
# Should use "*" as fallback
allowed = await permission_manager.check_action(action)
assert allowed is True
@pytest.mark.asyncio
async def test_cleanup_expired_called_on_check(
self,
permission_manager: PermissionManager,
) -> None:
"""Test that expired grants are cleaned up during check."""
# Create expired grant
grant = await permission_manager.grant(
agent_id="agent-1",
resource_pattern="/old/*",
resource_type=ResourceType.FILE,
level=PermissionLevel.READ,
duration_seconds=1,
)
grant.expires_at = datetime.utcnow() - timedelta(seconds=10)
# Create valid grant
await permission_manager.grant(
agent_id="agent-1",
resource_pattern="/new/*",
resource_type=ResourceType.FILE,
level=PermissionLevel.READ,
)
# Run a check - this should trigger cleanup
await permission_manager.check(
agent_id="agent-1",
resource="/new/file.txt",
resource_type=ResourceType.FILE,
required_level=PermissionLevel.READ,
)
# Now verify expired grant was cleaned up
async with permission_manager._lock:
assert len(permission_manager._grants) == 1
assert permission_manager._grants[0].resource_pattern == "/new/*"
@pytest.mark.asyncio
async def test_check_wrong_agent_id(
self,
permission_manager: PermissionManager,
) -> None:
"""Test check fails for different agent."""
await permission_manager.grant(
agent_id="agent-1",
resource_pattern="/data/*",
resource_type=ResourceType.FILE,
level=PermissionLevel.READ,
)
# Different agent should not have access
allowed = await permission_manager.check(
agent_id="agent-2",
resource="/data/file.txt",
resource_type=ResourceType.FILE,
required_level=PermissionLevel.READ,
)
assert allowed is False


@@ -0,0 +1,823 @@
"""Tests for Rollback Manager.
Tests cover:
- FileCheckpoint: state storage
- RollbackManager: checkpoint, rollback, cleanup
- TransactionContext: auto-rollback, commit, manual rollback
- Edge cases: non-existent files, partial failures, expiration
"""
import tempfile
from datetime import datetime, timedelta
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
import pytest_asyncio
from app.services.safety.exceptions import RollbackError
from app.services.safety.models import (
ActionMetadata,
ActionRequest,
ActionType,
CheckpointType,
)
from app.services.safety.rollback.manager import (
FileCheckpoint,
RollbackManager,
TransactionContext,
)
# ============================================================================
# Fixtures
# ============================================================================
@pytest.fixture
def action_metadata() -> ActionMetadata:
"""Create standard action metadata for tests."""
return ActionMetadata(
agent_id="test-agent",
project_id="test-project",
session_id="test-session",
)
@pytest.fixture
def action_request(action_metadata: ActionMetadata) -> ActionRequest:
"""Create a standard action request for tests."""
return ActionRequest(
id="action-123",
action_type=ActionType.FILE_WRITE,
tool_name="file_write",
resource="/tmp/test_file.txt", # noqa: S108
metadata=action_metadata,
is_destructive=True,
)
@pytest_asyncio.fixture
async def rollback_manager() -> RollbackManager:
"""Create a RollbackManager for testing."""
with tempfile.TemporaryDirectory() as tmpdir:
with patch("app.services.safety.rollback.manager.get_safety_config") as mock:
mock.return_value = MagicMock(
checkpoint_dir=tmpdir,
checkpoint_retention_hours=24,
)
manager = RollbackManager(checkpoint_dir=tmpdir, retention_hours=24)
yield manager
@pytest.fixture
def temp_dir() -> Path:
"""Create a temporary directory for file operations."""
with tempfile.TemporaryDirectory() as tmpdir:
yield Path(tmpdir)
# ============================================================================
# FileCheckpoint Tests
# ============================================================================
class TestFileCheckpoint:
"""Tests for the FileCheckpoint class."""
def test_file_checkpoint_creation(self) -> None:
"""Test creating a file checkpoint."""
fc = FileCheckpoint(
checkpoint_id="cp-123",
file_path="/path/to/file.txt",
original_content=b"original content",
existed=True,
)
assert fc.checkpoint_id == "cp-123"
assert fc.file_path == "/path/to/file.txt"
assert fc.original_content == b"original content"
assert fc.existed is True
assert fc.created_at is not None
def test_file_checkpoint_nonexistent_file(self) -> None:
"""Test checkpoint for non-existent file."""
fc = FileCheckpoint(
checkpoint_id="cp-123",
file_path="/path/to/new_file.txt",
original_content=None,
existed=False,
)
assert fc.original_content is None
assert fc.existed is False
# ============================================================================
# RollbackManager Tests
# ============================================================================
class TestRollbackManager:
"""Tests for the RollbackManager class."""
@pytest.mark.asyncio
async def test_create_checkpoint(
self,
rollback_manager: RollbackManager,
action_request: ActionRequest,
) -> None:
"""Test creating a checkpoint."""
checkpoint = await rollback_manager.create_checkpoint(
action=action_request,
checkpoint_type=CheckpointType.FILE,
description="Test checkpoint",
)
assert checkpoint.id is not None
assert checkpoint.action_id == action_request.id
assert checkpoint.checkpoint_type == CheckpointType.FILE
assert checkpoint.description == "Test checkpoint"
assert checkpoint.expires_at is not None
assert checkpoint.is_valid is True
@pytest.mark.asyncio
async def test_create_checkpoint_default_description(
self,
rollback_manager: RollbackManager,
action_request: ActionRequest,
) -> None:
"""Test checkpoint with default description."""
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
assert "file_write" in checkpoint.description
@pytest.mark.asyncio
async def test_checkpoint_file_exists(
self,
rollback_manager: RollbackManager,
action_request: ActionRequest,
temp_dir: Path,
) -> None:
"""Test checkpointing an existing file."""
# Create a file
test_file = temp_dir / "test.txt"
test_file.write_text("original content")
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))
# Verify checkpoint was stored
async with rollback_manager._lock:
file_checkpoints = rollback_manager._file_checkpoints.get(checkpoint.id, [])
assert len(file_checkpoints) == 1
assert file_checkpoints[0].existed is True
assert file_checkpoints[0].original_content == b"original content"
@pytest.mark.asyncio
async def test_checkpoint_file_not_exists(
self,
rollback_manager: RollbackManager,
action_request: ActionRequest,
temp_dir: Path,
) -> None:
"""Test checkpointing a non-existent file."""
test_file = temp_dir / "new_file.txt"
assert not test_file.exists()
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))
# Verify checkpoint was stored
async with rollback_manager._lock:
file_checkpoints = rollback_manager._file_checkpoints.get(checkpoint.id, [])
assert len(file_checkpoints) == 1
assert file_checkpoints[0].existed is False
assert file_checkpoints[0].original_content is None
@pytest.mark.asyncio
async def test_checkpoint_files_multiple(
self,
rollback_manager: RollbackManager,
action_request: ActionRequest,
temp_dir: Path,
) -> None:
"""Test checkpointing multiple files."""
# Create files
file1 = temp_dir / "file1.txt"
file2 = temp_dir / "file2.txt"
file1.write_text("content 1")
file2.write_text("content 2")
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
await rollback_manager.checkpoint_files(
checkpoint.id,
[str(file1), str(file2)],
)
async with rollback_manager._lock:
file_checkpoints = rollback_manager._file_checkpoints.get(checkpoint.id, [])
assert len(file_checkpoints) == 2
@pytest.mark.asyncio
async def test_rollback_restore_modified_file(
self,
rollback_manager: RollbackManager,
action_request: ActionRequest,
temp_dir: Path,
) -> None:
"""Test rollback restores modified file content."""
test_file = temp_dir / "test.txt"
test_file.write_text("original content")
# Create checkpoint
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))
# Modify file
test_file.write_text("modified content")
assert test_file.read_text() == "modified content"
# Rollback
result = await rollback_manager.rollback(checkpoint.id)
assert result.success is True
assert len(result.actions_rolled_back) == 1
assert test_file.read_text() == "original content"
@pytest.mark.asyncio
async def test_rollback_delete_new_file(
self,
rollback_manager: RollbackManager,
action_request: ActionRequest,
temp_dir: Path,
) -> None:
"""Test rollback deletes file that didn't exist before."""
test_file = temp_dir / "new_file.txt"
assert not test_file.exists()
# Create checkpoint before file exists
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))
# Create the file
test_file.write_text("new content")
assert test_file.exists()
# Rollback
result = await rollback_manager.rollback(checkpoint.id)
assert result.success is True
assert not test_file.exists()
@pytest.mark.asyncio
async def test_rollback_not_found(
self,
rollback_manager: RollbackManager,
) -> None:
"""Test rollback with non-existent checkpoint."""
with pytest.raises(RollbackError) as exc_info:
await rollback_manager.rollback("nonexistent-id")
assert "not found" in str(exc_info.value)
@pytest.mark.asyncio
async def test_rollback_invalid_checkpoint(
self,
rollback_manager: RollbackManager,
action_request: ActionRequest,
temp_dir: Path,
) -> None:
"""Test rollback with invalidated checkpoint."""
test_file = temp_dir / "test.txt"
test_file.write_text("original")
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))
# Rollback once (invalidates checkpoint)
await rollback_manager.rollback(checkpoint.id)
# Try to rollback again
with pytest.raises(RollbackError) as exc_info:
await rollback_manager.rollback(checkpoint.id)
assert "no longer valid" in str(exc_info.value)
@pytest.mark.asyncio
async def test_discard_checkpoint(
self,
rollback_manager: RollbackManager,
action_request: ActionRequest,
) -> None:
"""Test discarding a checkpoint."""
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
result = await rollback_manager.discard_checkpoint(checkpoint.id)
assert result is True
# Verify it's gone
cp = await rollback_manager.get_checkpoint(checkpoint.id)
assert cp is None
@pytest.mark.asyncio
async def test_discard_checkpoint_nonexistent(
self,
rollback_manager: RollbackManager,
) -> None:
"""Test discarding a non-existent checkpoint."""
result = await rollback_manager.discard_checkpoint("nonexistent-id")
assert result is False
@pytest.mark.asyncio
async def test_get_checkpoint(
self,
rollback_manager: RollbackManager,
action_request: ActionRequest,
) -> None:
"""Test getting a checkpoint by ID."""
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
retrieved = await rollback_manager.get_checkpoint(checkpoint.id)
assert retrieved is not None
assert retrieved.id == checkpoint.id
@pytest.mark.asyncio
async def test_get_checkpoint_nonexistent(
self,
rollback_manager: RollbackManager,
) -> None:
"""Test getting a non-existent checkpoint."""
retrieved = await rollback_manager.get_checkpoint("nonexistent-id")
assert retrieved is None
@pytest.mark.asyncio
async def test_list_checkpoints(
self,
rollback_manager: RollbackManager,
action_request: ActionRequest,
) -> None:
"""Test listing checkpoints."""
await rollback_manager.create_checkpoint(action=action_request)
await rollback_manager.create_checkpoint(action=action_request)
checkpoints = await rollback_manager.list_checkpoints()
assert len(checkpoints) == 2
@pytest.mark.asyncio
async def test_list_checkpoints_by_action(
self,
rollback_manager: RollbackManager,
action_metadata: ActionMetadata,
) -> None:
"""Test listing checkpoints filtered by action."""
action1 = ActionRequest(
id="action-1",
action_type=ActionType.FILE_WRITE,
metadata=action_metadata,
)
action2 = ActionRequest(
id="action-2",
action_type=ActionType.FILE_WRITE,
metadata=action_metadata,
)
await rollback_manager.create_checkpoint(action=action1)
await rollback_manager.create_checkpoint(action=action2)
checkpoints = await rollback_manager.list_checkpoints(action_id="action-1")
assert len(checkpoints) == 1
assert checkpoints[0].action_id == "action-1"
@pytest.mark.asyncio
async def test_list_checkpoints_excludes_expired(
self,
rollback_manager: RollbackManager,
action_request: ActionRequest,
) -> None:
"""Test list_checkpoints excludes expired by default."""
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
# Manually expire it
async with rollback_manager._lock:
rollback_manager._checkpoints[checkpoint.id].expires_at = (
datetime.utcnow() - timedelta(hours=1)
)
checkpoints = await rollback_manager.list_checkpoints()
assert len(checkpoints) == 0
# With include_expired=True
checkpoints = await rollback_manager.list_checkpoints(include_expired=True)
assert len(checkpoints) == 1
@pytest.mark.asyncio
async def test_cleanup_expired(
self,
rollback_manager: RollbackManager,
action_request: ActionRequest,
temp_dir: Path,
) -> None:
"""Test cleaning up expired checkpoints."""
# Create checkpoints
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
test_file = temp_dir / "test.txt"
test_file.write_text("content")
await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))
# Expire it
async with rollback_manager._lock:
rollback_manager._checkpoints[checkpoint.id].expires_at = (
datetime.utcnow() - timedelta(hours=1)
)
# Cleanup
count = await rollback_manager.cleanup_expired()
assert count == 1
# Verify it's gone
async with rollback_manager._lock:
assert checkpoint.id not in rollback_manager._checkpoints
assert checkpoint.id not in rollback_manager._file_checkpoints
# ============================================================================
# TransactionContext Tests
# ============================================================================
class TestTransactionContext:
"""Tests for the TransactionContext class."""
@pytest.mark.asyncio
async def test_context_creates_checkpoint(
self,
rollback_manager: RollbackManager,
action_request: ActionRequest,
) -> None:
"""Test that entering context creates a checkpoint."""
async with TransactionContext(rollback_manager, action_request) as tx:
assert tx.checkpoint_id is not None
# Verify checkpoint exists
cp = await rollback_manager.get_checkpoint(tx.checkpoint_id)
assert cp is not None
@pytest.mark.asyncio
async def test_context_checkpoint_file(
self,
rollback_manager: RollbackManager,
action_request: ActionRequest,
temp_dir: Path,
) -> None:
"""Test checkpointing files through context."""
test_file = temp_dir / "test.txt"
test_file.write_text("original")
async with TransactionContext(rollback_manager, action_request) as tx:
await tx.checkpoint_file(str(test_file))
# Modify file
test_file.write_text("modified")
# Manual rollback
result = await tx.rollback()
assert result is not None
assert result.success is True
assert test_file.read_text() == "original"
@pytest.mark.asyncio
async def test_context_checkpoint_files(
self,
rollback_manager: RollbackManager,
action_request: ActionRequest,
temp_dir: Path,
) -> None:
"""Test checkpointing multiple files through context."""
file1 = temp_dir / "file1.txt"
file2 = temp_dir / "file2.txt"
file1.write_text("content 1")
file2.write_text("content 2")
async with TransactionContext(rollback_manager, action_request) as tx:
await tx.checkpoint_files([str(file1), str(file2)])
cp_id = tx.checkpoint_id
async with rollback_manager._lock:
file_cps = rollback_manager._file_checkpoints.get(cp_id, [])
assert len(file_cps) == 2
tx.commit()
@pytest.mark.asyncio
async def test_context_auto_rollback_on_exception(
self,
rollback_manager: RollbackManager,
action_request: ActionRequest,
temp_dir: Path,
) -> None:
"""Test auto-rollback when exception occurs."""
test_file = temp_dir / "test.txt"
test_file.write_text("original")
with pytest.raises(ValueError):
async with TransactionContext(rollback_manager, action_request) as tx:
await tx.checkpoint_file(str(test_file))
test_file.write_text("modified")
raise ValueError("Simulated error")
# Should have been rolled back
assert test_file.read_text() == "original"
@pytest.mark.asyncio
async def test_context_commit_prevents_rollback(
self,
rollback_manager: RollbackManager,
action_request: ActionRequest,
temp_dir: Path,
) -> None:
"""Test that commit prevents auto-rollback."""
test_file = temp_dir / "test.txt"
test_file.write_text("original")
with pytest.raises(ValueError):
async with TransactionContext(rollback_manager, action_request) as tx:
await tx.checkpoint_file(str(test_file))
test_file.write_text("modified")
tx.commit()
raise ValueError("Simulated error after commit")
# Should NOT have been rolled back
assert test_file.read_text() == "modified"
@pytest.mark.asyncio
async def test_context_discards_checkpoint_on_commit(
self,
rollback_manager: RollbackManager,
action_request: ActionRequest,
) -> None:
"""Test that checkpoint is discarded after successful commit."""
checkpoint_id = None
async with TransactionContext(rollback_manager, action_request) as tx:
checkpoint_id = tx.checkpoint_id
tx.commit()
# Checkpoint should be discarded
cp = await rollback_manager.get_checkpoint(checkpoint_id)
assert cp is None
@pytest.mark.asyncio
async def test_context_no_auto_rollback_when_disabled(
self,
rollback_manager: RollbackManager,
action_request: ActionRequest,
temp_dir: Path,
) -> None:
"""Test that auto_rollback=False disables auto-rollback."""
test_file = temp_dir / "test.txt"
test_file.write_text("original")
with pytest.raises(ValueError):
async with TransactionContext(
rollback_manager,
action_request,
auto_rollback=False,
) as tx:
await tx.checkpoint_file(str(test_file))
test_file.write_text("modified")
raise ValueError("Simulated error")
# Should NOT have been rolled back
assert test_file.read_text() == "modified"
@pytest.mark.asyncio
async def test_context_manual_rollback(
self,
rollback_manager: RollbackManager,
action_request: ActionRequest,
temp_dir: Path,
) -> None:
"""Test manual rollback within context."""
test_file = temp_dir / "test.txt"
test_file.write_text("original")
async with TransactionContext(rollback_manager, action_request) as tx:
await tx.checkpoint_file(str(test_file))
test_file.write_text("modified")
# Manual rollback
result = await tx.rollback()
assert result is not None
assert result.success is True
assert test_file.read_text() == "original"
@pytest.mark.asyncio
async def test_context_rollback_without_checkpoint(
self,
rollback_manager: RollbackManager,
action_request: ActionRequest,
) -> None:
"""Test rollback when checkpoint is None."""
tx = TransactionContext(rollback_manager, action_request)
# Don't enter context, so _checkpoint is None
result = await tx.rollback()
assert result is None
@pytest.mark.asyncio
async def test_context_checkpoint_file_without_checkpoint(
self,
rollback_manager: RollbackManager,
action_request: ActionRequest,
temp_dir: Path,
) -> None:
"""Test checkpoint_file when checkpoint is None (no-op)."""
tx = TransactionContext(rollback_manager, action_request)
test_file = temp_dir / "test.txt"
test_file.write_text("content")
# Should not raise - just a no-op
await tx.checkpoint_file(str(test_file))
await tx.checkpoint_files([str(test_file)])
# ============================================================================
# Edge Cases
# ============================================================================
class TestRollbackEdgeCases:
"""Edge cases that could reveal hidden bugs."""
@pytest.mark.asyncio
async def test_checkpoint_file_for_unknown_checkpoint(
self,
rollback_manager: RollbackManager,
temp_dir: Path,
) -> None:
"""Test checkpointing file for non-existent checkpoint."""
test_file = temp_dir / "test.txt"
test_file.write_text("content")
# Should create the list if it doesn't exist
await rollback_manager.checkpoint_file("unknown-checkpoint", str(test_file))
async with rollback_manager._lock:
assert "unknown-checkpoint" in rollback_manager._file_checkpoints
@pytest.mark.asyncio
async def test_rollback_with_partial_failure(
self,
rollback_manager: RollbackManager,
action_request: ActionRequest,
temp_dir: Path,
) -> None:
"""Test rollback when some files fail to restore."""
file1 = temp_dir / "file1.txt"
file1.write_text("original 1")
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
await rollback_manager.checkpoint_file(checkpoint.id, str(file1))
# Add a file checkpoint with a path that will fail
async with rollback_manager._lock:
# Create a checkpoint for a file in a non-writable location
bad_fc = FileCheckpoint(
checkpoint_id=checkpoint.id,
file_path="/nonexistent/path/file.txt",
original_content=b"content",
existed=True,
)
rollback_manager._file_checkpoints[checkpoint.id].append(bad_fc)
# Rollback - partial failure expected
result = await rollback_manager.rollback(checkpoint.id)
assert result.success is False
assert len(result.actions_rolled_back) == 1
assert len(result.failed_actions) == 1
assert "Failed to rollback" in result.error
@pytest.mark.asyncio
async def test_rollback_file_creates_parent_dirs(
self,
rollback_manager: RollbackManager,
action_request: ActionRequest,
temp_dir: Path,
) -> None:
"""Test that rollback creates parent directories if needed."""
nested_file = temp_dir / "subdir" / "nested" / "file.txt"
nested_file.parent.mkdir(parents=True)
nested_file.write_text("original")
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
await rollback_manager.checkpoint_file(checkpoint.id, str(nested_file))
# Delete the entire directory structure
nested_file.unlink()
(temp_dir / "subdir" / "nested").rmdir()
(temp_dir / "subdir").rmdir()
# Rollback should recreate
result = await rollback_manager.rollback(checkpoint.id)
assert result.success is True
assert nested_file.exists()
assert nested_file.read_text() == "original"
@pytest.mark.asyncio
async def test_rollback_file_already_correct(
self,
rollback_manager: RollbackManager,
action_request: ActionRequest,
temp_dir: Path,
) -> None:
"""Test rollback when file already has correct content."""
test_file = temp_dir / "test.txt"
test_file.write_text("original")
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))
# Don't modify file - rollback should still succeed
result = await rollback_manager.rollback(checkpoint.id)
assert result.success is True
assert test_file.read_text() == "original"
@pytest.mark.asyncio
async def test_checkpoint_with_none_expires_at(
self,
rollback_manager: RollbackManager,
action_request: ActionRequest,
) -> None:
"""Test list_checkpoints handles None expires_at."""
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
# Set expires_at to None
async with rollback_manager._lock:
rollback_manager._checkpoints[checkpoint.id].expires_at = None
# Should still be listed
checkpoints = await rollback_manager.list_checkpoints()
assert len(checkpoints) == 1
@pytest.mark.asyncio
async def test_auto_rollback_failure_logged(
self,
rollback_manager: RollbackManager,
action_request: ActionRequest,
temp_dir: Path,
) -> None:
"""Test that auto-rollback failure is logged, not raised."""
test_file = temp_dir / "test.txt"
test_file.write_text("original")
with patch.object(
rollback_manager, "rollback", side_effect=Exception("Rollback failed!")
):
with patch("app.services.safety.rollback.manager.logger") as mock_logger:
with pytest.raises(ValueError):
async with TransactionContext(
rollback_manager, action_request
) as tx:
await tx.checkpoint_file(str(test_file))
test_file.write_text("modified")
raise ValueError("Original error")
# Rollback error should be logged
mock_logger.error.assert_called()
@pytest.mark.asyncio
async def test_multiple_checkpoints_same_action(
self,
rollback_manager: RollbackManager,
action_request: ActionRequest,
) -> None:
"""Test creating multiple checkpoints for the same action."""
cp1 = await rollback_manager.create_checkpoint(action=action_request)
cp2 = await rollback_manager.create_checkpoint(action=action_request)
assert cp1.id != cp2.id
checkpoints = await rollback_manager.list_checkpoints(
action_id=action_request.id
)
assert len(checkpoints) == 2
@pytest.mark.asyncio
async def test_cleanup_expired_with_no_expired(
self,
rollback_manager: RollbackManager,
action_request: ActionRequest,
) -> None:
"""Test cleanup when no checkpoints are expired."""
await rollback_manager.create_checkpoint(action=action_request)
count = await rollback_manager.cleanup_expired()
assert count == 0
# Checkpoint should still exist
checkpoints = await rollback_manager.list_checkpoints()
assert len(checkpoints) == 1


@@ -363,6 +363,365 @@ class TestValidationBatch:
assert results[1].decision == SafetyDecision.DENY
class TestValidationCache:
"""Tests for ValidationCache class."""
@pytest.mark.asyncio
async def test_cache_get_miss(self) -> None:
"""Test cache miss."""
from app.services.safety.validation.validator import ValidationCache
cache = ValidationCache(max_size=10, ttl_seconds=60)
result = await cache.get("nonexistent")
assert result is None
@pytest.mark.asyncio
async def test_cache_get_hit(self) -> None:
"""Test cache hit."""
from app.services.safety.models import ValidationResult
from app.services.safety.validation.validator import ValidationCache
cache = ValidationCache(max_size=10, ttl_seconds=60)
vr = ValidationResult(
action_id="action-1",
decision=SafetyDecision.ALLOW,
applied_rules=[],
reasons=["test"],
)
await cache.set("key1", vr)
result = await cache.get("key1")
assert result is not None
assert result.action_id == "action-1"
@pytest.mark.asyncio
async def test_cache_ttl_expiry(self) -> None:
"""Test cache TTL expiry."""
import time
from unittest.mock import patch
from app.services.safety.models import ValidationResult
from app.services.safety.validation.validator import ValidationCache
cache = ValidationCache(max_size=10, ttl_seconds=1)
vr = ValidationResult(
action_id="action-1",
decision=SafetyDecision.ALLOW,
applied_rules=[],
reasons=["test"],
)
await cache.set("key1", vr)
# Advance time past TTL
with patch("time.time", return_value=time.time() + 2):
result = await cache.get("key1")
assert result is None # Should be expired
@pytest.mark.asyncio
async def test_cache_eviction_on_full(self) -> None:
"""Test cache eviction when full."""
from app.services.safety.models import ValidationResult
from app.services.safety.validation.validator import ValidationCache
cache = ValidationCache(max_size=2, ttl_seconds=60)
vr1 = ValidationResult(action_id="a1", decision=SafetyDecision.ALLOW)
vr2 = ValidationResult(action_id="a2", decision=SafetyDecision.ALLOW)
vr3 = ValidationResult(action_id="a3", decision=SafetyDecision.ALLOW)
await cache.set("key1", vr1)
await cache.set("key2", vr2)
await cache.set("key3", vr3) # Should evict key1
# key1 should be evicted
assert await cache.get("key1") is None
assert await cache.get("key2") is not None
assert await cache.get("key3") is not None
@pytest.mark.asyncio
async def test_cache_update_existing_key(self) -> None:
"""Test updating existing key in cache."""
from app.services.safety.models import ValidationResult
from app.services.safety.validation.validator import ValidationCache
cache = ValidationCache(max_size=10, ttl_seconds=60)
vr1 = ValidationResult(action_id="a1", decision=SafetyDecision.ALLOW)
vr2 = ValidationResult(action_id="a1-updated", decision=SafetyDecision.DENY)
await cache.set("key1", vr1)
await cache.set("key1", vr2)  # Re-set the existing key
result = await cache.get("key1")
assert result is not None
# The cache moves an existing key to the end without replacing its value,
# so the original entry is still returned
assert result.action_id == "a1"
@pytest.mark.asyncio
async def test_cache_clear(self) -> None:
"""Test clearing cache."""
from app.services.safety.models import ValidationResult
from app.services.safety.validation.validator import ValidationCache
cache = ValidationCache(max_size=10, ttl_seconds=60)
vr = ValidationResult(action_id="a1", decision=SafetyDecision.ALLOW)
await cache.set("key1", vr)
await cache.set("key2", vr)
await cache.clear()
assert await cache.get("key1") is None
assert await cache.get("key2") is None
class TestValidatorCaching:
"""Tests for validator caching functionality."""
@pytest.mark.asyncio
async def test_cache_hit(self) -> None:
"""Test that cache is used for repeated validations."""
validator = ActionValidator(cache_enabled=True, cache_ttl=60)
metadata = ActionMetadata(agent_id="test-agent", session_id="session-1")
action = ActionRequest(
action_type=ActionType.FILE_READ,
tool_name="file_read",
resource="/tmp/test.txt", # noqa: S108
metadata=metadata,
)
# First call populates cache
result1 = await validator.validate(action)
# Second call should use cache
result2 = await validator.validate(action)
assert result1.decision == result2.decision
@pytest.mark.asyncio
async def test_clear_cache(self) -> None:
"""Test clearing the validation cache."""
validator = ActionValidator(cache_enabled=True)
metadata = ActionMetadata(agent_id="test-agent", session_id="session-1")
action = ActionRequest(
action_type=ActionType.FILE_READ,
tool_name="file_read",
metadata=metadata,
)
await validator.validate(action)
await validator.clear_cache()
# Cache should be empty now (no error)
result = await validator.validate(action)
assert result.decision == SafetyDecision.ALLOW
class TestRuleMatching:
"""Tests for rule matching edge cases."""
@pytest.mark.asyncio
async def test_action_type_mismatch(self) -> None:
"""Test that rule doesn't match when action type doesn't match."""
validator = ActionValidator(cache_enabled=False)
validator.add_rule(
ValidationRule(
name="file_only",
action_types=[ActionType.FILE_READ],
decision=SafetyDecision.DENY,
)
)
metadata = ActionMetadata(agent_id="test-agent")
action = ActionRequest(
action_type=ActionType.SHELL_COMMAND, # Different type
tool_name="shell_exec",
metadata=metadata,
)
result = await validator.validate(action)
assert result.decision == SafetyDecision.ALLOW # Rule didn't match
@pytest.mark.asyncio
async def test_tool_pattern_no_tool_name(self) -> None:
"""Test rule with tool pattern when action has no tool_name."""
validator = ActionValidator(cache_enabled=False)
validator.add_rule(
create_deny_rule(
name="deny_files",
tool_patterns=["file_*"],
)
)
metadata = ActionMetadata(agent_id="test-agent")
action = ActionRequest(
action_type=ActionType.FILE_READ,
tool_name=None, # No tool name
metadata=metadata,
)
result = await validator.validate(action)
assert result.decision == SafetyDecision.ALLOW # Rule didn't match
@pytest.mark.asyncio
async def test_resource_pattern_no_resource(self) -> None:
"""Test rule with resource pattern when action has no resource."""
validator = ActionValidator(cache_enabled=False)
validator.add_rule(
create_deny_rule(
name="deny_secrets",
resource_patterns=["/secret/*"],
)
)
metadata = ActionMetadata(agent_id="test-agent")
action = ActionRequest(
action_type=ActionType.FILE_READ,
tool_name="file_read",
resource=None, # No resource
metadata=metadata,
)
result = await validator.validate(action)
assert result.decision == SafetyDecision.ALLOW # Rule didn't match
@pytest.mark.asyncio
async def test_resource_pattern_no_match(self) -> None:
"""Test rule with resource pattern that doesn't match."""
validator = ActionValidator(cache_enabled=False)
validator.add_rule(
create_deny_rule(
name="deny_secrets",
resource_patterns=["/secret/*"],
)
)
metadata = ActionMetadata(agent_id="test-agent")
action = ActionRequest(
action_type=ActionType.FILE_READ,
tool_name="file_read",
resource="/public/file.txt", # Doesn't match
metadata=metadata,
)
result = await validator.validate(action)
assert result.decision == SafetyDecision.ALLOW # Pattern didn't match
class TestPolicyLoading:
"""Tests for policy loading edge cases."""
@pytest.mark.asyncio
async def test_load_rules_from_policy_with_validation_rules(self) -> None:
"""Test loading policy with explicit validation rules."""
validator = ActionValidator(cache_enabled=False)
rule = ValidationRule(
name="policy_rule",
tool_patterns=["test_*"],
decision=SafetyDecision.DENY,
reason="From policy",
)
policy = SafetyPolicy(
name="test",
validation_rules=[rule],
require_approval_for=[], # Clear defaults
denied_tools=[], # Clear defaults
)
validator.load_rules_from_policy(policy)
assert len(validator._rules) == 1
assert validator._rules[0].name == "policy_rule"
@pytest.mark.asyncio
async def test_load_approval_all_pattern(self) -> None:
"""Test loading policy with * approval pattern (all actions)."""
validator = ActionValidator(cache_enabled=False)
policy = SafetyPolicy(
name="test",
require_approval_for=["*"], # All actions require approval
denied_tools=[], # Clear defaults
)
validator.load_rules_from_policy(policy)
approval_rules = [
r for r in validator._rules if r.decision == SafetyDecision.REQUIRE_APPROVAL
]
assert len(approval_rules) == 1
assert approval_rules[0].name == "require_approval_all"
assert approval_rules[0].action_types == list(ActionType)
@pytest.mark.asyncio
async def test_validate_with_policy_loads_rules(self) -> None:
"""Test that validate() loads rules from policy if none exist."""
validator = ActionValidator(cache_enabled=False)
policy = SafetyPolicy(
name="test",
denied_tools=["dangerous_*"],
)
metadata = ActionMetadata(agent_id="test-agent")
action = ActionRequest(
action_type=ActionType.SHELL_COMMAND,
tool_name="dangerous_exec",
metadata=metadata,
)
# Validate with policy - should load rules
result = await validator.validate(action, policy=policy)
assert result.decision == SafetyDecision.DENY
class TestCacheKeyGeneration:
"""Tests for cache key generation."""
def test_get_cache_key(self) -> None:
"""Test cache key generation."""
validator = ActionValidator(cache_enabled=True)
metadata = ActionMetadata(
agent_id="test-agent",
autonomy_level=AutonomyLevel.MILESTONE,
)
action = ActionRequest(
action_type=ActionType.FILE_READ,
tool_name="file_read",
resource="/tmp/test.txt", # noqa: S108
metadata=metadata,
)
key = validator._get_cache_key(action)
assert "file_read" in key  # covers both the action type and the tool name
assert "/tmp/test.txt" in key # noqa: S108
assert "test-agent" in key
assert "milestone" in key
def test_get_cache_key_no_resource(self) -> None:
"""Test cache key generation without resource."""
validator = ActionValidator(cache_enabled=True)
metadata = ActionMetadata(agent_id="agent-1")
action = ActionRequest(
action_type=ActionType.SHELL_COMMAND,
tool_name="shell_exec",
resource=None,
metadata=metadata,
)
key = validator._get_cache_key(action)
# Should not error with None resource
assert "shell" in key
assert "agent-1" in key
class TestHelperFunctions:
"""Tests for rule creation helper functions."""


@@ -48,6 +48,80 @@ services:
- app-network
restart: unless-stopped
# ==========================================================================
# MCP Servers - Model Context Protocol servers for AI agent capabilities
# ==========================================================================
mcp-llm-gateway:
# REPLACE THIS with your actual image from your container registry
image: YOUR_REGISTRY/YOUR_PROJECT_MCP_LLM_GATEWAY:latest
env_file:
- .env
environment:
- LLM_GATEWAY_HOST=0.0.0.0
- LLM_GATEWAY_PORT=8001
- REDIS_URL=redis://redis:6379/1
- ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY}
- OPENAI_API_KEY=${OPENAI_API_KEY}
- ENVIRONMENT=production
depends_on:
redis:
condition: service_healthy
healthcheck:
test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:8001/health').raise_for_status()"]
interval: 30s
timeout: 10s
retries: 3
start_period: 10s
networks:
- app-network
restart: unless-stopped
deploy:
resources:
limits:
cpus: '2.0'
memory: 2G
reservations:
cpus: '0.5'
memory: 512M
mcp-knowledge-base:
# REPLACE THIS with your actual image from your container registry
image: YOUR_REGISTRY/YOUR_PROJECT_MCP_KNOWLEDGE_BASE:latest
env_file:
- .env
environment:
# KB_ prefix required by pydantic-settings config
- KB_HOST=0.0.0.0
- KB_PORT=8002
- KB_DATABASE_URL=postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@db:5432/${POSTGRES_DB}
- KB_REDIS_URL=redis://redis:6379/2
- KB_LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
- OPENAI_API_KEY=${OPENAI_API_KEY}
- ENVIRONMENT=production
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
healthcheck:
test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:8002/health').raise_for_status()"]
interval: 30s
timeout: 10s
retries: 3
start_period: 10s
networks:
- app-network
restart: unless-stopped
deploy:
resources:
limits:
cpus: '1.0'
memory: 1G
reservations:
cpus: '0.25'
memory: 256M
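The `KB_` prefix called out in the environment block above comes from pydantic-settings' `env_prefix` mechanism: each settings field is resolved from an environment variable named `KB_<FIELD>`, which is why the compose file must export `KB_HOST` and `KB_PORT` rather than bare `HOST`/`PORT`. A minimal stdlib-only sketch of that lookup (illustration of the assumed behavior, not the real `pydantic_settings.BaseSettings` implementation):

```python
import os


# Illustration only: mimics env_prefix-style resolution as assumed for the
# knowledge-base service; the actual service uses pydantic-settings.
def load_settings(prefix: str, defaults: dict[str, str]) -> dict[str, str]:
    """Resolve each field from <PREFIX><FIELD>, falling back to its default."""
    return {
        name: os.environ.get(f"{prefix}{name.upper()}", default)
        for name, default in defaults.items()
    }


os.environ["KB_HOST"] = "0.0.0.0"
settings = load_settings("KB_", {"host": "127.0.0.1", "port": "8002"})
# "host" is overridden by KB_HOST; "port" keeps its default
```

If the prefixed variable is absent, the field silently keeps its default, which is why a typo like `HOST=` instead of `KB_HOST=` would leave the service bound to its default address.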
backend:
# REPLACE THIS with your actual image from your container registry
# Examples:
@@ -64,11 +138,18 @@ services:
- DEBUG=false
- BACKEND_CORS_ORIGINS=${BACKEND_CORS_ORIGINS}
- REDIS_URL=redis://redis:6379/0
# MCP Server URLs
- LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
- KNOWLEDGE_BASE_URL=http://mcp-knowledge-base:8002
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
mcp-llm-gateway:
condition: service_healthy
mcp-knowledge-base:
condition: service_healthy
networks:
- app-network
restart: unless-stopped
@@ -92,11 +173,18 @@ services:
- DATABASE_URL=${DATABASE_URL}
- REDIS_URL=redis://redis:6379/0
- CELERY_QUEUE=agent
# MCP Server URLs (agents need access to MCP)
- LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
- KNOWLEDGE_BASE_URL=http://mcp-knowledge-base:8002
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
mcp-llm-gateway:
condition: service_healthy
mcp-knowledge-base:
condition: service_healthy
networks:
- app-network
restart: unless-stopped
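The MCP services' healthchecks shell out to `python -c "import httpx; httpx.get(...).raise_for_status()"`. A standalone probe with the same pass/fail semantics can be sketched with only the standard library (`urllib` stands in for httpx here so the snippet has no extra dependencies; the URL is the gateway's published port from the compose file):

```python
from urllib.request import urlopen
from urllib.error import URLError


def probe(url: str, timeout: float = 5.0) -> bool:
    """Return True when the endpoint answers with HTTP 2xx,
    mirroring raise_for_status() in the compose healthcheck."""
    try:
        with urlopen(url, timeout=timeout) as resp:
            return 200 <= resp.status < 300
    except (URLError, OSError, ValueError):
        return False


if __name__ == "__main__":
    # Same target the mcp-llm-gateway healthcheck hits
    print(probe("http://localhost:8001/health"))
```

Useful as a quick smoke test before running the integration suite, since `depends_on: condition: service_healthy` only gates container startup, not later outages.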

View File

@@ -32,6 +32,70 @@ services:
networks:
- app-network
# ==========================================================================
# MCP Servers - Model Context Protocol servers for AI agent capabilities
# ==========================================================================
mcp-llm-gateway:
build:
context: ./mcp-servers/llm-gateway
dockerfile: Dockerfile
ports:
- "8001:8001"
env_file:
- .env
environment:
- LLM_GATEWAY_HOST=0.0.0.0
- LLM_GATEWAY_PORT=8001
- REDIS_URL=redis://redis:6379/1
- ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY}
- OPENAI_API_KEY=${OPENAI_API_KEY}
- ENVIRONMENT=development
depends_on:
redis:
condition: service_healthy
healthcheck:
test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:8001/health').raise_for_status()"]
interval: 30s
timeout: 10s
retries: 3
start_period: 10s
networks:
- app-network
restart: unless-stopped
mcp-knowledge-base:
build:
context: ./mcp-servers/knowledge-base
dockerfile: Dockerfile
ports:
- "8002:8002"
env_file:
- .env
environment:
# KB_ prefix required by pydantic-settings config
- KB_HOST=0.0.0.0
- KB_PORT=8002
- KB_DATABASE_URL=postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@db:5432/${POSTGRES_DB}
- KB_REDIS_URL=redis://redis:6379/2
- KB_LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
- OPENAI_API_KEY=${OPENAI_API_KEY}
- ENVIRONMENT=development
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
healthcheck:
test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:8002/health').raise_for_status()"]
interval: 30s
timeout: 10s
retries: 3
start_period: 10s
networks:
- app-network
restart: unless-stopped
backend:
build:
context: ./backend
@@ -52,11 +116,18 @@ services:
- DEBUG=true
- BACKEND_CORS_ORIGINS=${BACKEND_CORS_ORIGINS}
- REDIS_URL=redis://redis:6379/0
# MCP Server URLs
- LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
- KNOWLEDGE_BASE_URL=http://mcp-knowledge-base:8002
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
mcp-llm-gateway:
condition: service_healthy
mcp-knowledge-base:
condition: service_healthy
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
interval: 10s
@@ -81,11 +152,18 @@ services:
- DATABASE_URL=${DATABASE_URL}
- REDIS_URL=redis://redis:6379/0
- CELERY_QUEUE=agent
# MCP Server URLs (agents need access to MCP)
- LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
- KNOWLEDGE_BASE_URL=http://mcp-knowledge-base:8002
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
mcp-llm-gateway:
condition: service_healthy
mcp-knowledge-base:
condition: service_healthy
networks:
- app-network
command: ["celery", "-A", "app.celery_app", "worker", "-Q", "agent", "-l", "info", "-c", "4"]
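Both the backend and the agent worker receive the MCP endpoints through `LLM_GATEWAY_URL` and `KNOWLEDGE_BASE_URL`. A minimal sketch of consuming them, with the compose-internal addresses as defaults (the settings helper itself is illustrative, not the project's actual config class):

```python
import os
from dataclasses import dataclass


@dataclass(frozen=True)
class MCPEndpoints:
    llm_gateway_url: str
    knowledge_base_url: str

    @classmethod
    def from_env(cls) -> "MCPEndpoints":
        # Defaults match the service names and ports in docker-compose
        return cls(
            llm_gateway_url=os.environ.get(
                "LLM_GATEWAY_URL", "http://mcp-llm-gateway:8001"),
            knowledge_base_url=os.environ.get(
                "KNOWLEDGE_BASE_URL", "http://mcp-knowledge-base:8002"),
        )


endpoints = MCPEndpoints.from_env()
print(endpoints.llm_gateway_url)
```

Because the defaults use compose service names, the same code works inside the network without extra configuration, while local runs can override via the environment.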

View File

@@ -32,6 +32,82 @@ services:
- app-network
restart: unless-stopped
# ==========================================================================
# MCP Servers - Model Context Protocol servers for AI agent capabilities
# ==========================================================================
mcp-llm-gateway:
build:
context: ./mcp-servers/llm-gateway
dockerfile: Dockerfile
env_file:
- .env
environment:
- LLM_GATEWAY_HOST=0.0.0.0
- LLM_GATEWAY_PORT=8001
- REDIS_URL=redis://redis:6379/1
- ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY}
- OPENAI_API_KEY=${OPENAI_API_KEY}
- ENVIRONMENT=production
depends_on:
redis:
condition: service_healthy
healthcheck:
test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:8001/health').raise_for_status()"]
interval: 30s
timeout: 10s
retries: 3
start_period: 10s
networks:
- app-network
restart: unless-stopped
deploy:
resources:
limits:
cpus: '2.0'
memory: 2G
reservations:
cpus: '0.5'
memory: 512M
mcp-knowledge-base:
build:
context: ./mcp-servers/knowledge-base
dockerfile: Dockerfile
env_file:
- .env
environment:
# KB_ prefix required by pydantic-settings config
- KB_HOST=0.0.0.0
- KB_PORT=8002
- KB_DATABASE_URL=postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@db:5432/${POSTGRES_DB}
- KB_REDIS_URL=redis://redis:6379/2
- KB_LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
- OPENAI_API_KEY=${OPENAI_API_KEY}
- ENVIRONMENT=production
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
healthcheck:
test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:8002/health').raise_for_status()"]
interval: 30s
timeout: 10s
retries: 3
start_period: 10s
networks:
- app-network
restart: unless-stopped
deploy:
resources:
limits:
cpus: '1.0'
memory: 1G
reservations:
cpus: '0.25'
memory: 256M
backend:
build:
context: ./backend
@@ -48,11 +124,18 @@ services:
- DEBUG=false
- BACKEND_CORS_ORIGINS=${BACKEND_CORS_ORIGINS}
- REDIS_URL=redis://redis:6379/0
# MCP Server URLs
- LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
- KNOWLEDGE_BASE_URL=http://mcp-knowledge-base:8002
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
mcp-llm-gateway:
condition: service_healthy
mcp-knowledge-base:
condition: service_healthy
networks:
- app-network
restart: unless-stopped
@@ -75,11 +158,18 @@ services:
- DATABASE_URL=${DATABASE_URL}
- REDIS_URL=redis://redis:6379/0
- CELERY_QUEUE=agent
# MCP Server URLs (agents need access to MCP)
- LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
- KNOWLEDGE_BASE_URL=http://mcp-knowledge-base:8002
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
mcp-llm-gateway:
condition: service_healthy
mcp-knowledge-base:
condition: service_healthy
networks:
- app-network
restart: unless-stopped

View File

@@ -205,6 +205,69 @@ test(frontend): add unit tests for ProjectDashboard
---
## Pre-Commit Hooks
The repository includes pre-commit hooks that enforce validation before commits on protected branches.
### Setup
Enable the hooks by configuring git to use the `.githooks` directory:
```bash
git config core.hooksPath .githooks
```
This only needs to be done once per clone.
### What the Hooks Do
When committing to **protected branches** (`main`, `dev`):
| Condition | Action |
|-----------|--------|
| Backend files changed | Runs `make validate` in `/backend` |
| Frontend files changed | Runs `npm run validate` in `/frontend` |
| No relevant changes | Skips validation |
If validation fails, the commit is blocked with an error message.
When committing to **feature branches**:
- Validation is skipped (allows WIP commits)
- A message reminds you to run validation manually if needed
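The actual hook is a script under `.githooks` and is not reproduced here; the decision rules in the table and list above can be sketched in Python (function and path prefixes are illustrative):

```python
PROTECTED_BRANCHES = {"main", "dev"}


def validations_for(branch: str, changed_paths: list[str]) -> list[str]:
    """Decide which validation commands a commit should trigger,
    per the rules described above."""
    if branch not in PROTECTED_BRANCHES:
        return []  # feature branches: skip, allow WIP commits
    commands = []
    if any(p.startswith("backend/") for p in changed_paths):
        commands.append("cd backend && make validate")
    if any(p.startswith("frontend/") for p in changed_paths):
        commands.append("cd frontend && npm run validate")
    return commands  # empty list: no relevant changes, skip validation
```

On `dev` with both backend and frontend changes, both commands run; on a feature branch, nothing runs regardless of what changed.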
### Why Protected Branches Only?
The hooks enforce validation only on `main` and `dev`, for several reasons:
1. **Feature branches are for iteration** - WIP commits, experimentation, and rapid prototyping shouldn't be blocked
2. **Flexibility during development** - You can commit broken code to your feature branch while debugging
3. **PRs catch issues** - The merge process ensures validation passes before reaching protected branches
4. **Manual control** - You can always run `make validate` or `npm run validate` yourself
### Manual Validation
Even on feature branches, you should validate before creating a PR:
```bash
# Backend
cd backend && make validate
# Frontend
cd frontend && npm run validate
```
### Bypassing Hooks (Emergency Only)
In rare cases where you need to bypass the hook:
```bash
git commit --no-verify -m "message"
```
**Use sparingly** - this defeats the purpose of the hooks.
---
## Documentation Updates
- Keep `docs/architecture/IMPLEMENTATION_ROADMAP.md` updated
@@ -314,8 +377,11 @@ Do NOT use parallel agents when:
| Action | Command/Location |
|--------|-----------------|
| Create branch | `git checkout -b feature/<issue>-<desc>` |
| Enable pre-commit hooks | `git config core.hooksPath .githooks` |
| Run backend tests | `IS_TEST=True uv run pytest` |
| Run frontend tests | `npm test` |
| Backend validation | `cd backend && make validate` |
| Frontend validation | `cd frontend && npm run validate` |
| Check types (backend) | `uv run mypy src/` |
| Check types (frontend) | `npm run type-check` |
| Lint (backend) | `uv run ruff check src/` |

View File

@@ -386,10 +386,24 @@ describe('ActivityFeed', () => {
});
it('shows event count in group header', () => {
// Create fresh "today" events to avoid timezone/day boundary issues
const todayEvents: ProjectEvent[] = [
createMockEvent({
id: 'today-event-1',
type: EventType.APPROVAL_REQUESTED,
timestamp: new Date().toISOString(),
}),
createMockEvent({
id: 'today-event-2',
type: EventType.AGENT_MESSAGE,
timestamp: new Date().toISOString(),
}),
];
render(<ActivityFeed {...defaultProps} events={todayEvents} />);
const todayGroup = screen.getByTestId('event-group-today');
// Today has 2 events
expect(within(todayGroup).getByText('2')).toBeInTheDocument();
});
});

View File

@@ -0,0 +1,79 @@
.PHONY: help install install-dev lint lint-fix format type-check test test-cov validate clean run
# Default target
help:
@echo "Knowledge Base MCP Server - Development Commands"
@echo ""
@echo "Setup:"
@echo " make install - Install production dependencies"
@echo " make install-dev - Install development dependencies"
@echo ""
@echo "Quality Checks:"
@echo " make lint - Run Ruff linter"
@echo " make lint-fix - Run Ruff linter with auto-fix"
@echo " make format - Format code with Ruff"
@echo " make type-check - Run mypy type checker"
@echo ""
@echo "Testing:"
@echo " make test - Run pytest"
@echo " make test-cov - Run pytest with coverage"
@echo ""
@echo "All-in-one:"
@echo " make validate - Run lint, type-check, and tests"
@echo ""
@echo "Running:"
@echo " make run - Run the server locally"
@echo ""
@echo "Cleanup:"
@echo " make clean - Remove cache and build artifacts"
# Setup
install:
@echo "Installing production dependencies..."
@uv pip install -e .
install-dev:
@echo "Installing development dependencies..."
@uv pip install -e ".[dev]"
# Quality checks
lint:
@echo "Running Ruff linter..."
@uv run ruff check .
lint-fix:
@echo "Running Ruff linter with auto-fix..."
@uv run ruff check --fix .
format:
@echo "Formatting code..."
@uv run ruff format .
type-check:
@echo "Running mypy..."
@uv run mypy . --ignore-missing-imports
# Testing
test:
@echo "Running tests..."
@uv run pytest tests/ -v
test-cov:
@echo "Running tests with coverage..."
@uv run pytest tests/ -v --cov=. --cov-report=term-missing --cov-report=html
# All-in-one validation
validate: lint type-check test
@echo "All validations passed!"
# Running
run:
@echo "Starting Knowledge Base server..."
@uv run python server.py
# Cleanup
clean:
@echo "Cleaning up..."
@rm -rf __pycache__ .pytest_cache .mypy_cache .ruff_cache .coverage htmlcov
@find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
@find . -type f -name "*.pyc" -delete 2>/dev/null || true

View File

@@ -328,7 +328,7 @@ class CollectionManager:
"source_path": chunk.source_path or source_path,
"start_line": chunk.start_line,
"end_line": chunk.end_line,
"file_type": effective_file_type.value if (effective_file_type := chunk.file_type or file_type) else None,
}
embeddings_data.append((
chunk.content,
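The one-line change above swaps a repeated `chunk.file_type or file_type` evaluation for a single walrus-bound name. The two forms are equivalent; a minimal illustration with stand-in values:

```python
from enum import Enum


class FileType(Enum):
    PYTHON = "python"
    MARKDOWN = "markdown"


def old_form(chunk_file_type, file_type):
    # Evaluates the fallback expression twice
    return (chunk_file_type or file_type).value if (chunk_file_type or file_type) else None


def new_form(chunk_file_type, file_type):
    # Walrus binds the fallback once, then reuses it
    return effective.value if (effective := chunk_file_type or file_type) else None


assert old_form(None, FileType.PYTHON) == new_form(None, FileType.PYTHON) == "python"
assert old_form(FileType.MARKDOWN, None) == new_form(FileType.MARKDOWN, None) == "markdown"
assert old_form(None, None) is new_form(None, None) is None
```

In a conditional expression the condition is evaluated first, so the walrus binding is available in the true branch.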

View File

@@ -284,41 +284,40 @@ class DatabaseManager:
)
try:
async with self.acquire() as conn, conn.transaction():
    # Wrap in transaction for all-or-nothing batch semantics
    for project_id, collection, content, embedding, chunk_type, metadata in embeddings:
        content_hash = self.compute_content_hash(content)
        source_path = metadata.get("source_path")
        start_line = metadata.get("start_line")
        end_line = metadata.get("end_line")
        file_type = metadata.get("file_type")
        embedding_id = await conn.fetchval(
            """
            INSERT INTO knowledge_embeddings
            (project_id, collection, content, embedding, chunk_type,
             source_path, start_line, end_line, file_type, metadata,
             content_hash, expires_at)
            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
            ON CONFLICT DO NOTHING
            RETURNING id
            """,
            project_id,
            collection,
            content,
            embedding,
            chunk_type.value,
            source_path,
            start_line,
            end_line,
            file_type,
            metadata,
            content_hash,
            expires_at,
        )
        if embedding_id:
            ids.append(str(embedding_id))
logger.info(f"Stored {len(ids)} embeddings in batch")
return ids
@@ -566,59 +565,58 @@ class DatabaseManager:
)
try:
async with self.acquire() as conn, conn.transaction():
    # Use transaction for atomic replace
    # First, delete existing embeddings for this source
    delete_result = await conn.execute(
        """
        DELETE FROM knowledge_embeddings
        WHERE project_id = $1 AND source_path = $2 AND collection = $3
        """,
        project_id,
        source_path,
        collection,
    )
    deleted_count = int(delete_result.split()[-1])
    # Then insert new embeddings
    new_ids = []
    for content, embedding, chunk_type, metadata in embeddings:
        content_hash = self.compute_content_hash(content)
        start_line = metadata.get("start_line")
        end_line = metadata.get("end_line")
        file_type = metadata.get("file_type")
        embedding_id = await conn.fetchval(
            """
            INSERT INTO knowledge_embeddings
            (project_id, collection, content, embedding, chunk_type,
             source_path, start_line, end_line, file_type, metadata,
             content_hash, expires_at)
            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
            RETURNING id
            """,
            project_id,
            collection,
            content,
            embedding,
            chunk_type.value,
            source_path,
            start_line,
            end_line,
            file_type,
            metadata,
            content_hash,
            expires_at,
        )
        if embedding_id:
            new_ids.append(str(embedding_id))
    logger.info(
        f"Replaced source {source_path}: deleted {deleted_count}, "
        f"inserted {len(new_ids)} embeddings"
    )
    return deleted_count, new_ids
except asyncpg.PostgresError as e:
logger.error(f"Replace source error: {e}")
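Both rewrites above fold `conn.transaction()` into the `acquire()` context manager, keeping the all-or-nothing batch semantics while removing a nesting level. The same guarantee, illustrated with stdlib `sqlite3` rather than asyncpg (the table and rows here are made up for the demo):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE embeddings (id INTEGER PRIMARY KEY, content TEXT NOT NULL)")

rows = [("chunk-a",), ("chunk-b",), (None,)]  # last row violates NOT NULL

try:
    with conn:  # opens a transaction; rolls back if the block raises
        conn.executemany("INSERT INTO embeddings (content) VALUES (?)", rows)
except sqlite3.IntegrityError:
    pass

# The failing third row rolled back the whole batch
count = conn.execute("SELECT COUNT(*) FROM embeddings").fetchone()[0]
print(count)  # → 0
```

Without the transaction, the first two rows would persist even though the batch failed, leaving a partial write.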

View File

@@ -193,7 +193,7 @@ async def health_check() -> dict[str, Any]:
# Check Redis cache (non-critical - degraded without it)
try:
if _embeddings and _embeddings._redis:
await _embeddings._redis.ping()  # type: ignore[misc]
status["dependencies"]["redis"] = "connected"
else:
status["dependencies"]["redis"] = "not initialized"

View File

@@ -1,8 +1,7 @@
"""Tests for server module and MCP tools."""
import json
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock
import pytest
from fastapi.testclient import TestClient

View File

@@ -1,39 +1,25 @@
# Syndarix LLM Gateway MCP Server
FROM python:3.12-slim
WORKDIR /app
# Install system dependencies (needed for tiktoken regex compilation)
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    && rm -rf /var/lib/apt/lists/*
# Install uv for fast package management
COPY --from=ghcr.io/astral-sh/uv:latest /uv /bin/uv
# Copy project files
COPY pyproject.toml ./
COPY *.py ./
# Install dependencies to system Python
RUN uv pip install --system --no-cache .
# Create non-root user for security
RUN groupadd --gid 1000 appgroup && \
    useradd --uid 1000 --gid appgroup --shell /bin/bash --create-home appuser
USER appuser
# Environment variables
@@ -47,7 +33,7 @@ EXPOSE 8001
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD python -c "import httpx; httpx.get('http://localhost:8001/health').raise_for_status()"
# Run the server
CMD ["python", "server.py"]

View File

@@ -0,0 +1,79 @@
.PHONY: help install install-dev lint lint-fix format type-check test test-cov validate clean run
# Default target
help:
@echo "LLM Gateway MCP Server - Development Commands"
@echo ""
@echo "Setup:"
@echo " make install - Install production dependencies"
@echo " make install-dev - Install development dependencies"
@echo ""
@echo "Quality Checks:"
@echo " make lint - Run Ruff linter"
@echo " make lint-fix - Run Ruff linter with auto-fix"
@echo " make format - Format code with Ruff"
@echo " make type-check - Run mypy type checker"
@echo ""
@echo "Testing:"
@echo " make test - Run pytest"
@echo " make test-cov - Run pytest with coverage"
@echo ""
@echo "All-in-one:"
@echo " make validate - Run lint, type-check, and tests"
@echo ""
@echo "Running:"
@echo " make run - Run the server locally"
@echo ""
@echo "Cleanup:"
@echo " make clean - Remove cache and build artifacts"
# Setup
install:
@echo "Installing production dependencies..."
@uv pip install -e .
install-dev:
@echo "Installing development dependencies..."
@uv pip install -e ".[dev]"
# Quality checks
lint:
@echo "Running Ruff linter..."
@uv run ruff check .
lint-fix:
@echo "Running Ruff linter with auto-fix..."
@uv run ruff check --fix .
format:
@echo "Formatting code..."
@uv run ruff format .
type-check:
@echo "Running mypy..."
@uv run mypy . --ignore-missing-imports
# Testing
test:
@echo "Running tests..."
@uv run pytest tests/ -v
test-cov:
@echo "Running tests with coverage..."
@uv run pytest tests/ -v --cov=. --cov-report=term-missing --cov-report=html
# All-in-one validation
validate: lint type-check test
@echo "All validations passed!"
# Running
run:
@echo "Starting LLM Gateway server..."
@uv run python server.py
# Cleanup
clean:
@echo "Cleaning up..."
@rm -rf __pycache__ .pytest_cache .mypy_cache .ruff_cache .coverage htmlcov
@find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
@find . -type f -name "*.pyc" -delete 2>/dev/null || true

View File

@@ -110,14 +110,13 @@ class CircuitBreaker:
"""
if self._state == CircuitState.OPEN:
time_in_open = time.time() - self._stats.state_changed_at
# Double-check state after time calculation (for thread safety)
if time_in_open >= self.recovery_timeout and self._state == CircuitState.OPEN:
    self._transition_to(CircuitState.HALF_OPEN)
    logger.info(
        f"Circuit {self.name} transitioned to HALF_OPEN "
        f"after {time_in_open:.1f}s"
    )
def _transition_to(self, new_state: CircuitState) -> None:
"""Transition to a new state."""
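The refactor above collapses the nested re-check into a single guard. A compressed, self-contained sketch of the OPEN → HALF_OPEN recovery transition (state names mirror the diff; locking and the full breaker API are omitted):

```python
import time
from enum import Enum


class CircuitState(Enum):
    CLOSED = "closed"
    OPEN = "open"
    HALF_OPEN = "half_open"


class CircuitBreaker:
    def __init__(self, recovery_timeout: float = 30.0) -> None:
        self.recovery_timeout = recovery_timeout
        self._state = CircuitState.CLOSED
        self._state_changed_at = time.time()

    def _transition_to(self, new_state: CircuitState) -> None:
        self._state = new_state
        self._state_changed_at = time.time()

    def check_recovery(self) -> CircuitState:
        """Move OPEN -> HALF_OPEN once the recovery timeout has elapsed."""
        if self._state == CircuitState.OPEN:
            time_in_open = time.time() - self._state_changed_at
            # Double-check state after the time calculation, as in the diff
            if time_in_open >= self.recovery_timeout and self._state == CircuitState.OPEN:
                self._transition_to(CircuitState.HALF_OPEN)
        return self._state


breaker = CircuitBreaker(recovery_timeout=0.01)
breaker._transition_to(CircuitState.OPEN)
time.sleep(0.02)
print(breaker.check_recovery())  # → CircuitState.HALF_OPEN
```

The double-check matters because another thread may have already moved the breaker out of OPEN between the time calculation and the transition; the combined condition keeps the transition idempotent.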