forked from cardosofelipe/pragma-stack
Compare commits
47 Commits
e5975fa5d0
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4b149b8a52 | ||
|
|
ad0c06851d | ||
|
|
49359b1416 | ||
|
|
911d950c15 | ||
|
|
b2a3ac60e0 | ||
|
|
dea092e1bb | ||
|
|
4154dd5268 | ||
|
|
db12937495 | ||
|
|
81e1456631 | ||
|
|
58e78d8700 | ||
|
|
5e80139afa | ||
|
|
60ebeaa582 | ||
|
|
758052dcff | ||
|
|
1628eacf2b | ||
|
|
2bea057fb1 | ||
|
|
9e54f16e56 | ||
|
|
96e6400bd8 | ||
|
|
6c7b72f130 | ||
|
|
027ebfc332 | ||
|
|
c2466ab401 | ||
|
|
7828d35e06 | ||
|
|
6b07e62f00 | ||
|
|
0d2005ddcb | ||
|
|
dfa75e682e | ||
|
|
22ecb5e989 | ||
|
|
2ab69f8561 | ||
|
|
95342cc94d | ||
|
|
f6194b3e19 | ||
|
|
6bb376a336 | ||
|
|
cd7a9ccbdf | ||
|
|
953af52d0e | ||
|
|
e6e98d4ed1 | ||
|
|
ca5f5e3383 | ||
|
|
d0fc7f37ff | ||
|
|
18d717e996 | ||
|
|
f482559e15 | ||
|
|
6e8b0b022a | ||
|
|
746fb7b181 | ||
|
|
caf283bed2 | ||
|
|
520c06175e | ||
|
|
065e43c5a9 | ||
|
|
c8b88dadc3 | ||
|
|
015f2de6c6 | ||
|
|
f36bfb3781 | ||
|
|
ef659cd72d | ||
|
|
728edd1453 | ||
|
|
498c0a0e94 |
61
.githooks/pre-commit
Executable file
61
.githooks/pre-commit
Executable file
@@ -0,0 +1,61 @@
|
||||
#!/bin/bash
# Pre-commit hook to enforce validation before commits on protected branches
# Install: git config core.hooksPath .githooks

set -e

# Get the current branch name
BRANCH=$(git rev-parse --abbrev-ref HEAD)

# Protected branches that require validation
PROTECTED_BRANCHES="main dev"

# Check if we're on a protected branch
is_protected() {
    for branch in $PROTECTED_BRANCHES; do
        if [ "$BRANCH" = "$branch" ]; then
            return 0
        fi
    done
    return 1
}

if is_protected; then
    echo "🔒 Committing to protected branch '$BRANCH' - running validation..."

    # Check if we have backend changes
    if git diff --cached --name-only | grep -q "^backend/"; then
        echo "📦 Backend changes detected - running make validate..."
        # Run in a subshell so a failure can never leave the hook stranded
        # in the wrong working directory.
        if ! (cd backend && make validate); then
            echo ""
            echo "❌ Backend validation failed!"
            echo " Please fix the issues and try again."
            echo " Run 'cd backend && make validate' to see errors."
            exit 1
        fi
        echo "✅ Backend validation passed!"
    fi

    # Check if we have frontend changes
    if git diff --cached --name-only | grep -q "^frontend/"; then
        echo "🎨 Frontend changes detected - running npm run validate..."
        # Do NOT redirect stderr to /dev/null here: the whole point of
        # failing the commit is to show the developer what broke (the old
        # '2>/dev/null' hid npm's error output, including "npm not found").
        if ! (cd frontend && npm run validate); then
            echo ""
            echo "❌ Frontend validation failed!"
            echo " Please fix the issues and try again."
            echo " Run 'cd frontend && npm run validate' to see errors."
            exit 1
        fi
        echo "✅ Frontend validation passed!"
    fi

    echo "🎉 All validations passed! Proceeding with commit..."
else
    echo "📝 Committing to feature branch '$BRANCH' - skipping validation (run manually if needed)"
fi

exit 0
|
||||
31
CLAUDE.md
31
CLAUDE.md
@@ -83,6 +83,37 @@ docs/
|
||||
3. **Testing Required**: All code must be tested, aim for >90% coverage
|
||||
4. **Code Review**: Must pass multi-agent review before merge
|
||||
5. **No Direct Commits**: Never commit directly to `main` or `dev`
|
||||
6. **Stack Verification**: ALWAYS run the full stack before considering work done (see below)
|
||||
|
||||
### CRITICAL: Stack Verification Before Merge
|
||||
|
||||
**This is NON-NEGOTIABLE. A feature with 100% test coverage that crashes on startup is WORTHLESS.**
|
||||
|
||||
Before considering ANY issue complete:
|
||||
|
||||
```bash
|
||||
# 1. Start the dev stack
|
||||
make dev
|
||||
|
||||
# 2. Wait for backend to be healthy, check logs
|
||||
docker compose -f docker-compose.dev.yml logs backend --tail=100
|
||||
|
||||
# 3. Start frontend
|
||||
cd frontend && npm run dev
|
||||
|
||||
# 4. Verify both are running without errors
|
||||
```
|
||||
|
||||
**The issue is NOT done if:**
|
||||
- Backend crashes on startup (import errors, missing dependencies)
|
||||
- Frontend fails to compile or render
|
||||
- Health checks fail
|
||||
- Any error appears in logs
|
||||
|
||||
**Why this matters:**
|
||||
- Tests run in isolation and may pass despite broken imports
|
||||
- Docker builds cache layers and may hide dependency issues
|
||||
- A single `ModuleNotFoundError` renders all test coverage meaningless
|
||||
|
||||
### Common Commands
|
||||
|
||||
|
||||
92
Makefile
92
Makefile
@@ -1,18 +1,31 @@
|
||||
.PHONY: help dev dev-full prod down logs logs-dev clean clean-slate drop-db reset-db push-images deploy
|
||||
.PHONY: test test-backend test-mcp test-frontend test-all test-cov test-integration validate validate-all
|
||||
|
||||
VERSION ?= latest
|
||||
REGISTRY ?= ghcr.io/cardosofelipe/pragma-stack
|
||||
|
||||
# Default target
|
||||
help:
|
||||
@echo "FastAPI + Next.js Full-Stack Template"
|
||||
@echo "Syndarix - AI-Powered Software Consulting Agency"
|
||||
@echo ""
|
||||
@echo "Development:"
|
||||
@echo " make dev - Start backend + db (frontend runs separately)"
|
||||
@echo " make dev - Start backend + db + MCP servers (frontend runs separately)"
|
||||
@echo " make dev-full - Start all services including frontend"
|
||||
@echo " make down - Stop all services"
|
||||
@echo " make logs-dev - Follow dev container logs"
|
||||
@echo ""
|
||||
@echo "Testing:"
|
||||
@echo " make test - Run all tests (backend + MCP servers)"
|
||||
@echo " make test-backend - Run backend tests only"
|
||||
@echo " make test-mcp - Run MCP server tests only"
|
||||
@echo " make test-frontend - Run frontend tests only"
|
||||
@echo " make test-cov - Run all tests with coverage reports"
|
||||
@echo " make test-integration - Run MCP integration tests (requires running stack)"
|
||||
@echo ""
|
||||
@echo "Validation:"
|
||||
@echo " make validate - Validate backend + MCP servers (lint, type-check, test)"
|
||||
@echo " make validate-all - Validate everything including frontend"
|
||||
@echo ""
|
||||
@echo "Database:"
|
||||
@echo " make drop-db - Drop and recreate empty database"
|
||||
@echo " make reset-db - Drop database and apply all migrations"
|
||||
@@ -28,8 +41,10 @@ help:
|
||||
@echo " make clean-slate - Stop containers AND delete volumes (DATA LOSS!)"
|
||||
@echo ""
|
||||
@echo "Subdirectory commands:"
|
||||
@echo " cd backend && make help - Backend-specific commands"
|
||||
@echo " cd frontend && npm run - Frontend-specific commands"
|
||||
@echo " cd backend && make help - Backend-specific commands"
|
||||
@echo " cd mcp-servers/llm-gateway && make - LLM Gateway commands"
|
||||
@echo " cd mcp-servers/knowledge-base && make - Knowledge Base commands"
|
||||
@echo " cd frontend && npm run - Frontend-specific commands"
|
||||
|
||||
# ============================================================================
|
||||
# Development
|
||||
@@ -99,3 +114,72 @@ clean:
|
||||
# WARNING! THIS REMOVES CONTAINERS AND VOLUMES AS WELL - DO NOT USE THIS UNLESS YOU WANT TO START OVER WITH DATA AND ALL
|
||||
clean-slate:
|
||||
docker compose -f docker-compose.dev.yml down -v --remove-orphans
|
||||
|
||||
# ============================================================================
|
||||
# Testing
|
||||
# ============================================================================
|
||||
|
||||
test: test-backend test-mcp
|
||||
@echo ""
|
||||
@echo "All tests passed!"
|
||||
|
||||
test-backend:
|
||||
@echo "Running backend tests..."
|
||||
@cd backend && IS_TEST=True uv run pytest tests/ -v
|
||||
|
||||
test-mcp:
|
||||
@echo "Running MCP server tests..."
|
||||
@echo ""
|
||||
@echo "=== LLM Gateway ==="
|
||||
@cd mcp-servers/llm-gateway && uv run pytest tests/ -v
|
||||
@echo ""
|
||||
@echo "=== Knowledge Base ==="
|
||||
@cd mcp-servers/knowledge-base && uv run pytest tests/ -v
|
||||
|
||||
test-frontend:
|
||||
@echo "Running frontend tests..."
|
||||
@cd frontend && npm test
|
||||
|
||||
test-all: test test-frontend
|
||||
@echo ""
|
||||
@echo "All tests (backend + MCP + frontend) passed!"
|
||||
|
||||
test-cov:
|
||||
@echo "Running all tests with coverage..."
|
||||
@echo ""
|
||||
@echo "=== Backend Coverage ==="
|
||||
@cd backend && IS_TEST=True uv run pytest tests/ -v --cov=app --cov-report=term-missing
|
||||
@echo ""
|
||||
@echo "=== LLM Gateway Coverage ==="
|
||||
@cd mcp-servers/llm-gateway && uv run pytest tests/ -v --cov=. --cov-report=term-missing
|
||||
@echo ""
|
||||
@echo "=== Knowledge Base Coverage ==="
|
||||
@cd mcp-servers/knowledge-base && uv run pytest tests/ -v --cov=. --cov-report=term-missing
|
||||
|
||||
test-integration:
|
||||
@echo "Running MCP integration tests..."
|
||||
@echo "Note: Requires running stack (make dev first)"
|
||||
@cd backend && RUN_INTEGRATION_TESTS=true IS_TEST=True uv run pytest tests/integration/ -v
|
||||
|
||||
# ============================================================================
|
||||
# Validation (lint + type-check + test)
|
||||
# ============================================================================
|
||||
|
||||
validate:
|
||||
@echo "Validating backend..."
|
||||
@cd backend && make validate
|
||||
@echo ""
|
||||
@echo "Validating LLM Gateway..."
|
||||
@cd mcp-servers/llm-gateway && make validate
|
||||
@echo ""
|
||||
@echo "Validating Knowledge Base..."
|
||||
@cd mcp-servers/knowledge-base && make validate
|
||||
@echo ""
|
||||
@echo "All validations passed!"
|
||||
|
||||
validate-all: validate
|
||||
@echo ""
|
||||
@echo "Validating frontend..."
|
||||
@cd frontend && npm run validate
|
||||
@echo ""
|
||||
@echo "Full validation passed!"
|
||||
|
||||
@@ -7,7 +7,10 @@ ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||
PYTHONPATH=/app \
|
||||
UV_COMPILE_BYTECODE=1 \
|
||||
UV_LINK_MODE=copy \
|
||||
UV_NO_CACHE=1
|
||||
UV_NO_CACHE=1 \
|
||||
UV_PROJECT_ENVIRONMENT=/opt/venv \
|
||||
VIRTUAL_ENV=/opt/venv \
|
||||
PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# Install system dependencies and uv
|
||||
RUN apt-get update && \
|
||||
@@ -20,7 +23,7 @@ RUN apt-get update && \
|
||||
# Copy dependency files
|
||||
COPY pyproject.toml uv.lock ./
|
||||
|
||||
# Install dependencies using uv (development mode with dev dependencies)
|
||||
# Install dependencies using uv into /opt/venv (outside /app to survive bind mounts)
|
||||
RUN uv sync --extra dev --frozen
|
||||
|
||||
# Copy application code
|
||||
@@ -45,7 +48,10 @@ ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||
PYTHONPATH=/app \
|
||||
UV_COMPILE_BYTECODE=1 \
|
||||
UV_LINK_MODE=copy \
|
||||
UV_NO_CACHE=1
|
||||
UV_NO_CACHE=1 \
|
||||
UV_PROJECT_ENVIRONMENT=/opt/venv \
|
||||
VIRTUAL_ENV=/opt/venv \
|
||||
PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# Install system dependencies and uv
|
||||
RUN apt-get update && \
|
||||
@@ -58,7 +64,7 @@ RUN apt-get update && \
|
||||
# Copy dependency files
|
||||
COPY pyproject.toml uv.lock ./
|
||||
|
||||
# Install only production dependencies using uv (no dev dependencies)
|
||||
# Install only production dependencies using uv into /opt/venv
|
||||
RUN uv sync --frozen --no-dev
|
||||
|
||||
# Copy application code
|
||||
@@ -67,7 +73,7 @@ COPY entrypoint.sh /usr/local/bin/
|
||||
RUN chmod +x /usr/local/bin/entrypoint.sh
|
||||
|
||||
# Set ownership to non-root user
|
||||
RUN chown -R appuser:appuser /app
|
||||
RUN chown -R appuser:appuser /app /opt/venv
|
||||
|
||||
# Switch to non-root user
|
||||
USER appuser
|
||||
@@ -77,4 +83,4 @@ HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
|
||||
CMD curl -f http://localhost:8000/health || exit 1
|
||||
|
||||
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
|
||||
CMD ["uv", "run", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
.PHONY: help lint lint-fix format format-check type-check test test-cov validate clean install-dev sync check-docker install-e2e test-e2e test-e2e-schema test-all
|
||||
.PHONY: help lint lint-fix format format-check type-check test test-cov validate clean install-dev sync check-docker install-e2e test-e2e test-e2e-schema test-all test-integration
|
||||
|
||||
# Default target
|
||||
help:
|
||||
@@ -22,6 +22,7 @@ help:
|
||||
@echo " make test-cov - Run pytest with coverage report"
|
||||
@echo " make test-e2e - Run E2E tests (PostgreSQL, requires Docker)"
|
||||
@echo " make test-e2e-schema - Run Schemathesis API schema tests"
|
||||
@echo " make test-integration - Run MCP integration tests (requires running stack)"
|
||||
@echo " make test-all - Run all tests (unit + E2E)"
|
||||
@echo " make check-docker - Check if Docker is available"
|
||||
@echo ""
|
||||
@@ -82,6 +83,15 @@ test-cov:
|
||||
@IS_TEST=True PYTHONPATH=. uv run pytest --cov=app --cov-report=term-missing --cov-report=html -n 16
|
||||
@echo "📊 Coverage report generated in htmlcov/index.html"
|
||||
|
||||
# ============================================================================
|
||||
# Integration Testing (requires running stack: make dev)
|
||||
# ============================================================================
|
||||
|
||||
test-integration:
|
||||
@echo "🧪 Running MCP integration tests..."
|
||||
@echo "Note: Requires running stack (make dev from project root)"
|
||||
@RUN_INTEGRATION_TESTS=true IS_TEST=True PYTHONPATH=. uv run pytest tests/integration/ -v
|
||||
|
||||
# ============================================================================
|
||||
# E2E Testing (requires Docker)
|
||||
# ============================================================================
|
||||
|
||||
@@ -5,6 +5,7 @@ from app.api.routes import (
|
||||
agent_types,
|
||||
agents,
|
||||
auth,
|
||||
context,
|
||||
events,
|
||||
issues,
|
||||
mcp,
|
||||
@@ -35,6 +36,9 @@ api_router.include_router(events.router, tags=["Events"])
|
||||
# MCP (Model Context Protocol) router
|
||||
api_router.include_router(mcp.router, prefix="/mcp", tags=["MCP"])
|
||||
|
||||
# Context Management Engine router
|
||||
api_router.include_router(context.router, prefix="/context", tags=["Context"])
|
||||
|
||||
# Syndarix domain routers
|
||||
api_router.include_router(projects.router, prefix="/projects", tags=["Projects"])
|
||||
api_router.include_router(
|
||||
|
||||
411
backend/app/api/routes/context.py
Normal file
411
backend/app/api/routes/context.py
Normal file
@@ -0,0 +1,411 @@
|
||||
"""
|
||||
Context Management API Endpoints.
|
||||
|
||||
Provides REST endpoints for context assembly and optimization
|
||||
for LLM requests using the ContextEngine.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.api.dependencies.permissions import require_superuser
|
||||
from app.models.user import User
|
||||
from app.services.context import (
|
||||
AssemblyTimeoutError,
|
||||
BudgetExceededError,
|
||||
ContextEngine,
|
||||
ContextSettings,
|
||||
create_context_engine,
|
||||
get_context_settings,
|
||||
)
|
||||
from app.services.mcp import MCPClientManager, get_mcp_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Singleton Engine Management
|
||||
# ============================================================================
|
||||
|
||||
# Module-level singleton; created lazily on first request.
_context_engine: ContextEngine | None = None


def _get_or_create_engine(
    mcp: MCPClientManager,
    settings: ContextSettings | None = None,
) -> ContextEngine:
    """Get or create the singleton ContextEngine."""
    global _context_engine

    if _context_engine is not None:
        # Reuse the existing engine, but refresh its MCP manager so it does
        # not keep a stale reference across requests.
        _context_engine.set_mcp_manager(mcp)
        return _context_engine

    effective_settings = settings if settings is not None else get_context_settings()
    _context_engine = create_context_engine(
        mcp_manager=mcp,
        redis=None,  # Optional: add Redis caching later
        settings=effective_settings,
    )
    logger.info("ContextEngine initialized")
    return _context_engine


async def get_context_engine(
    mcp: MCPClientManager = Depends(get_mcp_client),
) -> ContextEngine:
    """FastAPI dependency to get the ContextEngine."""
    return _get_or_create_engine(mcp)
||||
|
||||
|
||||
# ============================================================================
|
||||
# Request/Response Schemas
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class ConversationTurn(BaseModel):
    """A single conversation turn."""

    # Speaker of this turn.
    role: str = Field(default=..., description="Role: 'user' or 'assistant'")
    # Text of the message.
    content: str = Field(default=..., description="Message content")
|
||||
|
||||
|
||||
class ToolResult(BaseModel):
    """A tool execution result."""

    # Identifier of the tool that produced this result.
    tool_name: str = Field(default=..., description="Name of the tool")
    # Raw result payload; may be plain text or structured data.
    content: str | dict[str, Any] = Field(default=..., description="Tool result content")
    # Outcome of the execution ("success" unless the caller says otherwise).
    status: str = Field(default="success", description="Execution status")
|
||||
|
||||
|
||||
class AssembleContextRequest(BaseModel):
    """Request to assemble context for an LLM request."""

    # --- required identifiers -------------------------------------------
    project_id: str = Field(default=..., description="Project identifier")
    agent_id: str = Field(default=..., description="Agent identifier")
    query: str = Field(default=..., description="User's query or current request")

    # --- model / budget --------------------------------------------------
    model: str = Field(
        default="claude-3-sonnet",
        description="Target model name",
    )
    max_tokens: int | None = Field(
        default=None,
        description="Maximum context tokens (uses model default if None)",
    )

    # --- optional context sources ---------------------------------------
    system_prompt: str | None = Field(
        default=None,
        description="System prompt/instructions",
    )
    task_description: str | None = Field(
        default=None,
        description="Current task description",
    )
    knowledge_query: str | None = Field(
        default=None,
        description="Query for knowledge base search",
    )
    knowledge_limit: int = Field(
        default=10,
        ge=1,
        le=50,
        description="Max number of knowledge results",
    )
    conversation_history: list[ConversationTurn] | None = Field(
        default=None,
        description="Previous conversation turns",
    )
    tool_results: list[ToolResult] | None = Field(
        default=None,
        description="Tool execution results to include",
    )

    # --- behavior toggles ------------------------------------------------
    compress: bool = Field(
        default=True,
        description="Whether to apply compression",
    )
    use_cache: bool = Field(
        default=True,
        description="Whether to use caching",
    )
|
||||
|
||||
|
||||
class AssembledContextResponse(BaseModel):
    """Response containing assembled context."""

    # The fully formatted context string to feed to the model.
    content: str = Field(default=..., description="Assembled context content")
    total_tokens: int = Field(default=..., description="Total token count")
    context_count: int = Field(default=..., description="Number of context items included")
    compressed: bool = Field(default=..., description="Whether compression was applied")
    budget_used_percent: float = Field(
        default=...,
        description="Percentage of token budget used",
    )
    # Free-form extras (model, query, timings, ...).
    metadata: dict[str, Any] = Field(
        default_factory=dict,
        description="Additional metadata",
    )
|
||||
|
||||
|
||||
class TokenCountRequest(BaseModel):
    """Request to count tokens in content."""

    # Text whose tokens should be counted.
    content: str = Field(default=..., description="Content to count tokens in")
    # Optional model name to select a model-specific tokenizer.
    model: str | None = Field(
        default=None,
        description="Model for model-specific tokenization",
    )
|
||||
|
||||
|
||||
class TokenCountResponse(BaseModel):
    """Response containing token count."""

    # Result of the count.
    token_count: int = Field(default=..., description="Number of tokens")
    # Echoes the tokenizer model used, if any.
    model: str | None = Field(default=None, description="Model used for counting")
|
||||
|
||||
|
||||
class BudgetInfoResponse(BaseModel):
    """Response containing budget information for a model."""

    model: str = Field(default=..., description="Model name")
    # Overall budget, followed by the per-category allocations.
    total_tokens: int = Field(default=..., description="Total token budget")
    system_tokens: int = Field(default=..., description="Tokens reserved for system")
    knowledge_tokens: int = Field(default=..., description="Tokens for knowledge")
    conversation_tokens: int = Field(default=..., description="Tokens for conversation")
    tool_tokens: int = Field(default=..., description="Tokens for tool results")
    response_reserve: int = Field(default=..., description="Tokens reserved for response")
|
||||
|
||||
|
||||
class ContextEngineStatsResponse(BaseModel):
    """Response containing engine statistics."""

    # Raw cache statistics as reported by the engine.
    cache: dict[str, Any] = Field(default=..., description="Cache statistics")
    # Current engine configuration snapshot.
    settings: dict[str, Any] = Field(default=..., description="Current settings")
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
    """Health check response."""

    status: str = Field(default=..., description="Health status")
    mcp_connected: bool = Field(default=..., description="Whether MCP is connected")
    cache_enabled: bool = Field(default=..., description="Whether caching is enabled")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Endpoints
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get(
    "/health",
    response_model=HealthResponse,
    summary="Context Engine Health",
    description="Check health status of the context engine.",
)
async def health_check(
    engine: ContextEngine = Depends(get_context_engine),
) -> HealthResponse:
    """Check context engine health."""
    stats = await engine.get_stats()
    engine_settings = stats.get("settings", {})
    # NOTE(review): this reads the engine's private `_mcp` attribute;
    # consider exposing a public `is_connected` property on ContextEngine.
    return HealthResponse(
        status="healthy",
        mcp_connected=engine._mcp is not None,
        cache_enabled=engine_settings.get("cache_enabled", False),
    )
|
||||
|
||||
|
||||
@router.post(
    "/assemble",
    response_model=AssembledContextResponse,
    summary="Assemble Context",
    description="Assemble optimized context for an LLM request.",
)
async def assemble_context(
    request: AssembleContextRequest,
    current_user: User = Depends(require_superuser),
    engine: ContextEngine = Depends(get_context_engine),
) -> AssembledContextResponse:
    """
    Assemble optimized context for an LLM request.

    This endpoint gathers context from various sources, scores and ranks them,
    compresses if needed, and formats for the target model.

    Raises:
        HTTPException: 504 on assembly timeout, 413 when the token budget is
            exceeded, 500 for any other assembly failure.
    """
    logger.info(
        "Context assembly for project=%s agent=%s by user=%s",
        request.project_id,
        request.agent_id,
        current_user.id,
    )

    # Convert conversation history to the plain-dict format the engine expects
    conversation_history = None
    if request.conversation_history:
        conversation_history = [
            {"role": turn.role, "content": turn.content}
            for turn in request.conversation_history
        ]

    # Convert tool results to dict format
    tool_results = None
    if request.tool_results:
        tool_results = [
            {
                "tool_name": tr.tool_name,
                "content": tr.content,
                "status": tr.status,
            }
            for tr in request.tool_results
        ]

    try:
        result = await engine.assemble_context(
            project_id=request.project_id,
            agent_id=request.agent_id,
            query=request.query,
            model=request.model,
            max_tokens=request.max_tokens,
            system_prompt=request.system_prompt,
            task_description=request.task_description,
            knowledge_query=request.knowledge_query,
            knowledge_limit=request.knowledge_limit,
            conversation_history=conversation_history,
            tool_results=tool_results,
            compress=request.compress,
            use_cache=request.use_cache,
        )

        # Calculate budget usage percentage. Guard against a zero total so a
        # misconfigured model budget cannot crash the endpoint with a
        # ZeroDivisionError.
        budget = await engine.get_budget_for_model(request.model, request.max_tokens)
        budget_used_percent = (
            (result.total_tokens / budget.total) * 100 if budget.total else 0.0
        )

        # Check if compression was applied (from metadata if available)
        was_compressed = result.metadata.get("compressed_contexts", 0) > 0

        return AssembledContextResponse(
            content=result.content,
            total_tokens=result.total_tokens,
            context_count=result.context_count,
            compressed=was_compressed,
            budget_used_percent=round(budget_used_percent, 2),
            metadata={
                "model": request.model,
                "query": request.query,
                "knowledge_included": bool(request.knowledge_query),
                "conversation_turns": len(request.conversation_history or []),
                "excluded_count": result.excluded_count,
                "assembly_time_ms": result.assembly_time_ms,
            },
        )

    except AssemblyTimeoutError as e:
        raise HTTPException(
            status_code=status.HTTP_504_GATEWAY_TIMEOUT,
            detail=f"Context assembly timed out: {e}",
        ) from e
    except BudgetExceededError as e:
        raise HTTPException(
            status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
            detail=f"Token budget exceeded: {e}",
        ) from e
    except Exception as e:
        # Boundary handler: log the full traceback, surface a 500 to the client.
        logger.exception("Context assembly failed")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Context assembly failed: {e}",
        ) from e
|
||||
|
||||
|
||||
@router.post(
    "/count-tokens",
    response_model=TokenCountResponse,
    summary="Count Tokens",
    description="Count tokens in content using the LLM Gateway.",
)
async def count_tokens(
    request: TokenCountRequest,
    engine: ContextEngine = Depends(get_context_engine),
) -> TokenCountResponse:
    """Count tokens in content.

    Raises:
        HTTPException: 500 when the underlying token counter fails.
    """
    try:
        count = await engine.count_tokens(
            content=request.content,
            model=request.model,
        )
        return TokenCountResponse(
            token_count=count,
            model=request.model,
        )
    except Exception as e:
        # Lazy %-style args (not an f-string) to match the file's logging
        # convention and avoid formatting when the level is disabled.
        logger.warning("Token counting failed: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Token counting failed: {e}",
        ) from e
|
||||
|
||||
|
||||
@router.get(
    "/budget/{model}",
    response_model=BudgetInfoResponse,
    summary="Get Token Budget",
    description="Get token budget allocation for a specific model.",
)
async def get_budget(
    model: str,
    max_tokens: Annotated[int | None, Query(description="Custom max tokens")] = None,
    engine: ContextEngine = Depends(get_context_engine),
) -> BudgetInfoResponse:
    """Get token budget information for a model."""
    allocation = await engine.get_budget_for_model(model, max_tokens)
    # Map the engine's budget object onto the API response schema.
    return BudgetInfoResponse(
        model=model,
        total_tokens=allocation.total,
        system_tokens=allocation.system,
        knowledge_tokens=allocation.knowledge,
        conversation_tokens=allocation.conversation,
        tool_tokens=allocation.tools,
        response_reserve=allocation.response_reserve,
    )
|
||||
|
||||
|
||||
@router.get(
    "/stats",
    response_model=ContextEngineStatsResponse,
    summary="Engine Statistics",
    description="Get context engine statistics and configuration.",
)
async def get_stats(
    current_user: User = Depends(require_superuser),
    engine: ContextEngine = Depends(get_context_engine),
) -> ContextEngineStatsResponse:
    """Get engine statistics."""
    raw = await engine.get_stats()
    cache_stats = raw.get("cache", {})
    settings_stats = raw.get("settings", {})
    return ContextEngineStatsResponse(cache=cache_stats, settings=settings_stats)
|
||||
|
||||
|
||||
@router.post(
    "/cache/invalidate",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Invalidate Cache (Admin Only)",
    description="Invalidate context cache entries.",
)
async def invalidate_cache(
    project_id: Annotated[
        str | None, Query(description="Project to invalidate")
    ] = None,
    pattern: Annotated[str | None, Query(description="Pattern to match")] = None,
    current_user: User = Depends(require_superuser),
    engine: ContextEngine = Depends(get_context_engine),
) -> None:
    """Invalidate cache entries."""
    # Audit who invalidated what before touching the cache.
    logger.info(
        "Cache invalidation by user %s: project=%s pattern=%s",
        current_user.id,
        project_id,
        pattern,
    )
    await engine.invalidate_cache(project_id=project_id, pattern=pattern)
|
||||
@@ -74,7 +74,9 @@ class ToolInfoResponse(BaseModel):
|
||||
name: str = Field(..., description="Tool name")
|
||||
description: str | None = Field(None, description="Tool description")
|
||||
server_name: str | None = Field(None, description="Server providing the tool")
|
||||
input_schema: dict[str, Any] | None = Field(None, description="JSON schema for input")
|
||||
input_schema: dict[str, Any] | None = Field(
|
||||
None, description="JSON schema for input"
|
||||
)
|
||||
|
||||
|
||||
class ToolListResponse(BaseModel):
|
||||
|
||||
178
backend/app/services/context/__init__.py
Normal file
178
backend/app/services/context/__init__.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""
|
||||
Context Management Engine
|
||||
|
||||
Sophisticated context assembly and optimization for LLM requests.
|
||||
Provides intelligent context selection, token budget management,
|
||||
and model-specific formatting.
|
||||
|
||||
Usage:
|
||||
from app.services.context import (
|
||||
ContextSettings,
|
||||
get_context_settings,
|
||||
SystemContext,
|
||||
KnowledgeContext,
|
||||
ConversationContext,
|
||||
TaskContext,
|
||||
ToolContext,
|
||||
TokenBudget,
|
||||
BudgetAllocator,
|
||||
TokenCalculator,
|
||||
)
|
||||
|
||||
# Get settings
|
||||
settings = get_context_settings()
|
||||
|
||||
# Create budget for a model
|
||||
allocator = BudgetAllocator(settings)
|
||||
budget = allocator.create_budget_for_model("claude-3-sonnet")
|
||||
|
||||
# Create context instances
|
||||
system_ctx = SystemContext.create_persona(
|
||||
name="Code Assistant",
|
||||
description="You are a helpful code assistant.",
|
||||
capabilities=["Write code", "Debug issues"],
|
||||
)
|
||||
"""
|
||||
|
||||
# Budget Management
|
||||
# Adapters
|
||||
from .adapters import (
|
||||
ClaudeAdapter,
|
||||
DefaultAdapter,
|
||||
ModelAdapter,
|
||||
OpenAIAdapter,
|
||||
get_adapter,
|
||||
)
|
||||
|
||||
# Assembly
|
||||
from .assembly import (
|
||||
ContextPipeline,
|
||||
PipelineMetrics,
|
||||
)
|
||||
from .budget import (
|
||||
BudgetAllocator,
|
||||
TokenBudget,
|
||||
TokenCalculator,
|
||||
)
|
||||
|
||||
# Cache
|
||||
from .cache import ContextCache
|
||||
|
||||
# Compression
|
||||
from .compression import (
|
||||
ContextCompressor,
|
||||
TruncationResult,
|
||||
TruncationStrategy,
|
||||
)
|
||||
|
||||
# Configuration
|
||||
from .config import (
|
||||
ContextSettings,
|
||||
get_context_settings,
|
||||
get_default_settings,
|
||||
reset_context_settings,
|
||||
)
|
||||
|
||||
# Engine
|
||||
from .engine import ContextEngine, create_context_engine
|
||||
|
||||
# Exceptions
|
||||
from .exceptions import (
|
||||
AssemblyTimeoutError,
|
||||
BudgetExceededError,
|
||||
CacheError,
|
||||
CompressionError,
|
||||
ContextError,
|
||||
ContextNotFoundError,
|
||||
FormattingError,
|
||||
InvalidContextError,
|
||||
ScoringError,
|
||||
TokenCountError,
|
||||
)
|
||||
|
||||
# Prioritization
|
||||
from .prioritization import (
|
||||
ContextRanker,
|
||||
RankingResult,
|
||||
)
|
||||
|
||||
# Scoring
|
||||
from .scoring import (
|
||||
BaseScorer,
|
||||
CompositeScorer,
|
||||
PriorityScorer,
|
||||
RecencyScorer,
|
||||
RelevanceScorer,
|
||||
ScoredContext,
|
||||
)
|
||||
|
||||
# Types
|
||||
from .types import (
|
||||
AssembledContext,
|
||||
BaseContext,
|
||||
ContextPriority,
|
||||
ContextType,
|
||||
ConversationContext,
|
||||
KnowledgeContext,
|
||||
MessageRole,
|
||||
SystemContext,
|
||||
TaskComplexity,
|
||||
TaskContext,
|
||||
TaskStatus,
|
||||
ToolContext,
|
||||
ToolResultStatus,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AssembledContext",
|
||||
"AssemblyTimeoutError",
|
||||
"BaseContext",
|
||||
"BaseScorer",
|
||||
"BudgetAllocator",
|
||||
"BudgetExceededError",
|
||||
"CacheError",
|
||||
"ClaudeAdapter",
|
||||
"CompositeScorer",
|
||||
"CompressionError",
|
||||
"ContextCache",
|
||||
"ContextCompressor",
|
||||
"ContextEngine",
|
||||
"ContextError",
|
||||
"ContextNotFoundError",
|
||||
"ContextPipeline",
|
||||
"ContextPriority",
|
||||
"ContextRanker",
|
||||
"ContextSettings",
|
||||
"ContextType",
|
||||
"ConversationContext",
|
||||
"DefaultAdapter",
|
||||
"FormattingError",
|
||||
"InvalidContextError",
|
||||
"KnowledgeContext",
|
||||
"MessageRole",
|
||||
"ModelAdapter",
|
||||
"OpenAIAdapter",
|
||||
"PipelineMetrics",
|
||||
"PriorityScorer",
|
||||
"RankingResult",
|
||||
"RecencyScorer",
|
||||
"RelevanceScorer",
|
||||
"ScoredContext",
|
||||
"ScoringError",
|
||||
"SystemContext",
|
||||
"TaskComplexity",
|
||||
"TaskContext",
|
||||
"TaskStatus",
|
||||
"TokenBudget",
|
||||
"TokenCalculator",
|
||||
"TokenCountError",
|
||||
"ToolContext",
|
||||
"ToolResultStatus",
|
||||
"TruncationResult",
|
||||
"TruncationStrategy",
|
||||
"create_context_engine",
|
||||
"get_adapter",
|
||||
"get_context_settings",
|
||||
"get_default_settings",
|
||||
"reset_context_settings",
|
||||
]
|
||||
35
backend/app/services/context/adapters/__init__.py
Normal file
35
backend/app/services/context/adapters/__init__.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
Model Adapters Module.
|
||||
|
||||
Provides model-specific context formatting adapters.
|
||||
"""
|
||||
|
||||
from .base import DefaultAdapter, ModelAdapter
|
||||
from .claude import ClaudeAdapter
|
||||
from .openai import OpenAIAdapter
|
||||
|
||||
|
||||
def get_adapter(model: str) -> ModelAdapter:
    """
    Get the appropriate adapter for a model.

    Adapters are probed in priority order; the first whose model
    patterns match wins, with DefaultAdapter as the catch-all.

    Args:
        model: Model name

    Returns:
        Adapter instance for the model
    """
    for adapter_cls in (ClaudeAdapter, OpenAIAdapter):
        if adapter_cls.matches_model(model):
            return adapter_cls()
    return DefaultAdapter()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ClaudeAdapter",
|
||||
"DefaultAdapter",
|
||||
"ModelAdapter",
|
||||
"OpenAIAdapter",
|
||||
"get_adapter",
|
||||
]
|
||||
178
backend/app/services/context/adapters/base.py
Normal file
178
backend/app/services/context/adapters/base.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""
|
||||
Base Model Adapter.
|
||||
|
||||
Abstract base class for model-specific context formatting.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from ..types import BaseContext, ContextType
|
||||
|
||||
|
||||
class ModelAdapter(ABC):
    """
    Abstract base adapter for model-specific context formatting.

    Each adapter knows how to format contexts for optimal
    understanding by a specific LLM family (Claude, OpenAI, etc.).
    """

    # Lower-case substrings of model names this adapter is responsible for.
    MODEL_PATTERNS: ClassVar[list[str]] = []

    @classmethod
    def matches_model(cls, model: str) -> bool:
        """
        Check if this adapter handles the given model.

        Args:
            model: Model name to check

        Returns:
            True if this adapter handles the model
        """
        lowered = model.lower()
        for pattern in cls.MODEL_PATTERNS:
            if pattern in lowered:
                return True
        return False

    @abstractmethod
    def format(
        self,
        contexts: list[BaseContext],
        **kwargs: Any,
    ) -> str:
        """
        Format contexts for the target model.

        Args:
            contexts: List of contexts to format
            **kwargs: Additional formatting options

        Returns:
            Formatted context string
        """
        ...

    @abstractmethod
    def format_type(
        self,
        contexts: list[BaseContext],
        context_type: ContextType,
        **kwargs: Any,
    ) -> str:
        """
        Format contexts of a specific type.

        Args:
            contexts: List of contexts of the same type
            context_type: The type of contexts
            **kwargs: Additional formatting options

        Returns:
            Formatted string for this context type
        """
        ...

    def get_type_order(self) -> list[ContextType]:
        """
        Get the preferred order of context types.

        Returns:
            List of context types in preferred order
        """
        return [
            ContextType.SYSTEM,
            ContextType.TASK,
            ContextType.KNOWLEDGE,
            ContextType.CONVERSATION,
            ContextType.TOOL,
        ]

    def group_by_type(
        self, contexts: list[BaseContext]
    ) -> dict[ContextType, list[BaseContext]]:
        """
        Group contexts by their type.

        Args:
            contexts: List of contexts to group

        Returns:
            Dictionary mapping context type to list of contexts
        """
        grouped: dict[ContextType, list[BaseContext]] = {}
        for ctx in contexts:
            grouped.setdefault(ctx.get_type(), []).append(ctx)
        return grouped

    def get_separator(self) -> str:
        """
        Get the separator between context sections.

        Returns:
            Separator string
        """
        return "\n\n"
|
||||
|
||||
|
||||
class DefaultAdapter(ModelAdapter):
    """
    Default adapter for unknown models.

    Uses simple plain-text formatting with minimal structure.
    """

    MODEL_PATTERNS: ClassVar[list[str]] = []  # Fallback adapter

    @classmethod
    def matches_model(cls, model: str) -> bool:
        """Always returns True as fallback."""
        return True

    def format(
        self,
        contexts: list[BaseContext],
        **kwargs: Any,
    ) -> str:
        """Format contexts as plain text, one section per context type."""
        if not contexts:
            return ""

        grouped = self.group_by_type(contexts)
        # Keep only non-empty rendered sections, in canonical type order.
        sections = [
            text
            for ct in self.get_type_order()
            if ct in grouped and (text := self.format_type(grouped[ct], ct, **kwargs))
        ]
        return self.get_separator().join(sections)

    def format_type(
        self,
        contexts: list[BaseContext],
        context_type: ContextType,
        **kwargs: Any,
    ) -> str:
        """Format contexts of a type as plain text with a simple heading."""
        if not contexts:
            return ""

        body = "\n\n".join(c.content for c in contexts)

        # SYSTEM (and any unknown type) gets no heading; the rest are labelled.
        headings = {
            ContextType.TASK: "Task:\n",
            ContextType.KNOWLEDGE: "Reference Information:\n",
            ContextType.CONVERSATION: "Previous Conversation:\n",
            ContextType.TOOL: "Tool Results:\n",
        }
        return headings.get(context_type, "") + body
|
||||
212
backend/app/services/context/adapters/claude.py
Normal file
212
backend/app/services/context/adapters/claude.py
Normal file
@@ -0,0 +1,212 @@
|
||||
"""
|
||||
Claude Model Adapter.
|
||||
|
||||
Provides Claude-specific context formatting using XML tags
|
||||
which Claude models understand natively.
|
||||
"""
|
||||
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from ..types import BaseContext, ContextType
|
||||
from .base import ModelAdapter
|
||||
|
||||
|
||||
class ClaudeAdapter(ModelAdapter):
    """
    Claude-specific context formatting adapter.

    Claude models have native understanding of XML structure,
    so we use XML tags for clear delineation of context types.

    Features:
    - XML tags for each context type
    - Document structure for knowledge contexts
    - Role-based message formatting for conversations
    - Tool result wrapping with tool names
    """

    MODEL_PATTERNS: ClassVar[list[str]] = ["claude", "anthropic"]

    def format(
        self,
        contexts: list[BaseContext],
        **kwargs: Any,
    ) -> str:
        """
        Format contexts for Claude models.

        Uses XML tags for structured content that Claude
        understands natively.

        Args:
            contexts: List of contexts to format
            **kwargs: Additional formatting options

        Returns:
            XML-structured context string
        """
        if not contexts:
            return ""

        by_type = self.group_by_type(contexts)
        parts: list[str] = []

        for ct in self.get_type_order():
            if ct in by_type:
                formatted = self.format_type(by_type[ct], ct, **kwargs)
                if formatted:
                    parts.append(formatted)

        return self.get_separator().join(parts)

    def format_type(
        self,
        contexts: list[BaseContext],
        context_type: ContextType,
        **kwargs: Any,
    ) -> str:
        """
        Format contexts of a specific type for Claude.

        Args:
            contexts: List of contexts of the same type
            context_type: The type of contexts
            **kwargs: Additional formatting options

        Returns:
            XML-formatted string for this context type
        """
        if not contexts:
            return ""

        if context_type == ContextType.SYSTEM:
            return self._format_system(contexts)
        elif context_type == ContextType.TASK:
            return self._format_task(contexts)
        elif context_type == ContextType.KNOWLEDGE:
            return self._format_knowledge(contexts)
        elif context_type == ContextType.CONVERSATION:
            return self._format_conversation(contexts)
        elif context_type == ContextType.TOOL:
            return self._format_tool(contexts)

        # Fallback for any unhandled context types - still escape content
        # to prevent XML injection if new types are added without updating adapter
        return "\n".join(self._escape_xml_content(c.content) for c in contexts)

    def _format_system(self, contexts: list[BaseContext]) -> str:
        """Format system contexts."""
        # System prompts are typically admin-controlled, but escape for safety
        content = "\n\n".join(self._escape_xml_content(c.content) for c in contexts)
        return f"<system_instructions>\n{content}\n</system_instructions>"

    def _format_task(self, contexts: list[BaseContext]) -> str:
        """Format task contexts."""
        content = "\n\n".join(self._escape_xml_content(c.content) for c in contexts)
        return f"<current_task>\n{content}\n</current_task>"

    def _format_knowledge(self, contexts: list[BaseContext]) -> str:
        """
        Format knowledge contexts as structured documents.

        Each knowledge context becomes a document with source attribution.
        All content is XML-escaped to prevent injection attacks.
        """
        parts = ["<reference_documents>"]

        for ctx in contexts:
            source = self._escape_xml(ctx.source)
            # Escape content to prevent XML injection
            content = self._escape_xml_content(ctx.content)
            score = ctx.metadata.get("score", ctx.metadata.get("relevance_score", ""))

            if score:
                # Escape score to prevent XML injection via metadata
                escaped_score = self._escape_xml(str(score))
                parts.append(
                    f'<document source="{source}" relevance="{escaped_score}">'
                )
            else:
                parts.append(f'<document source="{source}">')

            parts.append(content)
            parts.append("</document>")

        parts.append("</reference_documents>")
        return "\n".join(parts)

    def _format_conversation(self, contexts: list[BaseContext]) -> str:
        """
        Format conversation contexts as message history.

        Uses role-based message tags for clear turn delineation.
        All content is XML-escaped to prevent prompt injection.
        """
        parts = ["<conversation_history>"]

        for ctx in contexts:
            role = self._escape_xml(ctx.metadata.get("role", "user"))
            # Escape content to prevent prompt injection via fake XML tags
            content = self._escape_xml_content(ctx.content)
            parts.append(f'<message role="{role}">')
            parts.append(content)
            parts.append("</message>")

        parts.append("</conversation_history>")
        return "\n".join(parts)

    def _format_tool(self, contexts: list[BaseContext]) -> str:
        """
        Format tool contexts as tool results.

        Each tool result is wrapped with the tool name.
        All content is XML-escaped to prevent injection.
        """
        parts = ["<tool_results>"]

        for ctx in contexts:
            tool_name = self._escape_xml(ctx.metadata.get("tool_name", "unknown"))
            status = ctx.metadata.get("status", "")

            if status:
                parts.append(
                    f'<tool_result name="{tool_name}" status="{self._escape_xml(status)}">'
                )
            else:
                parts.append(f'<tool_result name="{tool_name}">')

            # Escape content to prevent injection
            parts.append(self._escape_xml_content(ctx.content))
            parts.append("</tool_result>")

        parts.append("</tool_results>")
        return "\n".join(parts)

    @staticmethod
    def _escape_xml(text: str) -> str:
        """
        Escape XML special characters in attribute values.

        Replaces the five characters with meaning in XML attribute
        context with their standard entities. Ampersand is replaced
        first so the entities introduced by the later replacements
        are not themselves re-escaped.
        """
        return (
            text.replace("&", "&amp;")
            .replace("<", "&lt;")
            .replace(">", "&gt;")
            .replace('"', "&quot;")
            .replace("'", "&#39;")
        )

    @staticmethod
    def _escape_xml_content(text: str) -> str:
        """
        Escape XML special characters in element content.

        This prevents XML injection attacks where malicious content
        could break out of XML tags or inject fake tags for prompt injection.

        Only escapes &, <, > since quotes don't need escaping in content.
        Ampersand is replaced first to avoid double-escaping.

        Args:
            text: Content text to escape

        Returns:
            XML-safe content string
        """
        return text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
|
||||
160
backend/app/services/context/adapters/openai.py
Normal file
160
backend/app/services/context/adapters/openai.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""
|
||||
OpenAI Model Adapter.
|
||||
|
||||
Provides OpenAI-specific context formatting using markdown
|
||||
which GPT models understand well.
|
||||
"""
|
||||
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from ..types import BaseContext, ContextType
|
||||
from .base import ModelAdapter
|
||||
|
||||
|
||||
class OpenAIAdapter(ModelAdapter):
    """
    OpenAI-specific context formatting adapter.

    GPT models work well with markdown formatting,
    so we use headers and structured markdown for clarity.

    Features:
    - Markdown headers for each context type
    - Bulleted lists for document sources
    - Bold role labels for conversations
    - Code blocks for tool outputs
    """

    MODEL_PATTERNS: ClassVar[list[str]] = ["gpt", "openai", "o1", "o3"]

    def format(
        self,
        contexts: list[BaseContext],
        **kwargs: Any,
    ) -> str:
        """
        Format contexts for OpenAI models.

        Uses markdown formatting for structured content.

        Args:
            contexts: List of contexts to format
            **kwargs: Additional formatting options

        Returns:
            Markdown-structured context string
        """
        if not contexts:
            return ""

        grouped = self.group_by_type(contexts)
        sections: list[str] = []

        for ct in self.get_type_order():
            group = grouped.get(ct)
            if group:
                rendered = self.format_type(group, ct, **kwargs)
                if rendered:
                    sections.append(rendered)

        return self.get_separator().join(sections)

    def format_type(
        self,
        contexts: list[BaseContext],
        context_type: ContextType,
        **kwargs: Any,
    ) -> str:
        """
        Format contexts of a specific type for OpenAI.

        Args:
            contexts: List of contexts of the same type
            context_type: The type of contexts
            **kwargs: Additional formatting options

        Returns:
            Markdown-formatted string for this context type
        """
        if not contexts:
            return ""

        # Dispatch table instead of an if/elif chain; unknown types
        # fall through to a plain newline join.
        renderers = {
            ContextType.SYSTEM: self._format_system,
            ContextType.TASK: self._format_task,
            ContextType.KNOWLEDGE: self._format_knowledge,
            ContextType.CONVERSATION: self._format_conversation,
            ContextType.TOOL: self._format_tool,
        }
        renderer = renderers.get(context_type)
        if renderer is not None:
            return renderer(contexts)

        return "\n".join(c.content for c in contexts)

    def _format_system(self, contexts: list[BaseContext]) -> str:
        """Format system contexts as plain joined text (no header)."""
        return "\n\n".join(c.content for c in contexts)

    def _format_task(self, contexts: list[BaseContext]) -> str:
        """Format task contexts under a markdown header."""
        joined = "\n\n".join(c.content for c in contexts)
        return f"## Current Task\n\n{joined}"

    def _format_knowledge(self, contexts: list[BaseContext]) -> str:
        """
        Format knowledge contexts as structured documents.

        Each knowledge context becomes a section with source attribution.
        """
        lines = ["## Reference Documents\n"]

        for ctx in contexts:
            source = ctx.source
            relevance = ctx.metadata.get("score", ctx.metadata.get("relevance_score", ""))

            if relevance:
                lines.append(f"### Source: {source} (relevance: {relevance})\n")
            else:
                lines.append(f"### Source: {source}\n")

            lines.append(ctx.content)
            lines.append("")

        return "\n".join(lines)

    def _format_conversation(self, contexts: list[BaseContext]) -> str:
        """
        Format conversation contexts as message history.

        Uses bold role labels for clear turn delineation.
        """
        turns = [
            f"**{ctx.metadata.get('role', 'user').upper()}**: {ctx.content}"
            for ctx in contexts
        ]
        return "\n\n".join(turns)

    def _format_tool(self, contexts: list[BaseContext]) -> str:
        """
        Format tool contexts as tool results.

        Each tool result is in a code block with the tool name.
        """
        lines = ["## Recent Tool Results\n"]

        for ctx in contexts:
            name = ctx.metadata.get("tool_name", "unknown")
            state = ctx.metadata.get("status", "")
            suffix = f" ({state})" if state else ""

            lines.append(f"### Tool: {name}{suffix}\n")
            lines.append(f"```\n{ctx.content}\n```")
            lines.append("")

        return "\n".join(lines)
|
||||
12
backend/app/services/context/assembly/__init__.py
Normal file
12
backend/app/services/context/assembly/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Context Assembly Module.
|
||||
|
||||
Provides the assembly pipeline and formatting.
|
||||
"""
|
||||
|
||||
from .pipeline import ContextPipeline, PipelineMetrics
|
||||
|
||||
__all__ = [
|
||||
"ContextPipeline",
|
||||
"PipelineMetrics",
|
||||
]
|
||||
362
backend/app/services/context/assembly/pipeline.py
Normal file
362
backend/app/services/context/assembly/pipeline.py
Normal file
@@ -0,0 +1,362 @@
|
||||
"""
|
||||
Context Assembly Pipeline.
|
||||
|
||||
Orchestrates the full context assembly workflow:
|
||||
Gather → Count → Score → Rank → Compress → Format
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from ..adapters import get_adapter
|
||||
from ..budget import BudgetAllocator, TokenBudget, TokenCalculator
|
||||
from ..compression.truncation import ContextCompressor
|
||||
from ..config import ContextSettings, get_context_settings
|
||||
from ..exceptions import AssemblyTimeoutError
|
||||
from ..prioritization import ContextRanker
|
||||
from ..scoring import CompositeScorer
|
||||
from ..types import AssembledContext, BaseContext, ContextType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.mcp.client_manager import MCPClientManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineMetrics:
|
||||
"""Metrics from pipeline execution."""
|
||||
|
||||
start_time: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||
end_time: datetime | None = None
|
||||
total_contexts: int = 0
|
||||
selected_contexts: int = 0
|
||||
excluded_contexts: int = 0
|
||||
compressed_contexts: int = 0
|
||||
total_tokens: int = 0
|
||||
assembly_time_ms: float = 0.0
|
||||
scoring_time_ms: float = 0.0
|
||||
compression_time_ms: float = 0.0
|
||||
formatting_time_ms: float = 0.0
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"start_time": self.start_time.isoformat(),
|
||||
"end_time": self.end_time.isoformat() if self.end_time else None,
|
||||
"total_contexts": self.total_contexts,
|
||||
"selected_contexts": self.selected_contexts,
|
||||
"excluded_contexts": self.excluded_contexts,
|
||||
"compressed_contexts": self.compressed_contexts,
|
||||
"total_tokens": self.total_tokens,
|
||||
"assembly_time_ms": round(self.assembly_time_ms, 2),
|
||||
"scoring_time_ms": round(self.scoring_time_ms, 2),
|
||||
"compression_time_ms": round(self.compression_time_ms, 2),
|
||||
"formatting_time_ms": round(self.formatting_time_ms, 2),
|
||||
}
|
||||
|
||||
|
||||
class ContextPipeline:
    """
    Context assembly pipeline.

    Orchestrates the full workflow of context assembly:
    1. Validate and count tokens for all contexts
    2. Score contexts based on relevance, recency, and priority
    3. Rank and select contexts within budget
    4. Compress if needed to fit remaining budget
    5. Format for the target model
    """

    def __init__(
        self,
        mcp_manager: "MCPClientManager | None" = None,
        settings: ContextSettings | None = None,
        calculator: TokenCalculator | None = None,
        scorer: CompositeScorer | None = None,
        ranker: ContextRanker | None = None,
        compressor: ContextCompressor | None = None,
    ) -> None:
        """
        Initialize the context pipeline.

        Every component can be injected (e.g. for testing); anything not
        supplied is built here and wired to the shared settings/MCP manager.

        Args:
            mcp_manager: MCP client manager for LLM Gateway integration
            settings: Context settings
            calculator: Token calculator
            scorer: Context scorer
            ranker: Context ranker
            compressor: Context compressor
        """
        self._settings = settings or get_context_settings()
        self._mcp = mcp_manager

        # Initialize components. Order matters: ranker reuses the scorer and
        # calculator, compressor reuses the calculator, so defaults share state.
        self._calculator = calculator or TokenCalculator(mcp_manager=mcp_manager)
        self._scorer = scorer or CompositeScorer(
            mcp_manager=mcp_manager, settings=self._settings
        )
        self._ranker = ranker or ContextRanker(
            scorer=self._scorer, calculator=self._calculator
        )
        self._compressor = compressor or ContextCompressor(calculator=self._calculator)
        self._allocator = BudgetAllocator(self._settings)

    def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
        """Set MCP manager for all components."""
        # Only calculator and scorer talk to MCP directly; ranker/compressor
        # hold references to these same objects, so they pick the change up.
        self._mcp = mcp_manager
        self._calculator.set_mcp_manager(mcp_manager)
        self._scorer.set_mcp_manager(mcp_manager)

    async def assemble(
        self,
        contexts: list[BaseContext],
        query: str,
        model: str,
        max_tokens: int | None = None,
        custom_budget: TokenBudget | None = None,
        compress: bool = True,
        format_output: bool = True,
        timeout_ms: int | None = None,
    ) -> AssembledContext:
        """
        Assemble context for an LLM request.

        This is the main entry point for context assembly.

        The timeout is enforced cooperatively: each awaited phase runs under
        asyncio.wait_for with the time remaining from the shared deadline,
        and elapsed time is re-checked between phases.

        Args:
            contexts: List of contexts to assemble
            query: Query to optimize for
            model: Target model name
            max_tokens: Maximum total tokens (uses model default if None)
            custom_budget: Optional pre-configured budget
            compress: Whether to compress oversized contexts
            format_output: Whether to format the final output
            timeout_ms: Maximum assembly time in milliseconds

        Returns:
            AssembledContext with optimized content

        Raises:
            AssemblyTimeoutError: If assembly exceeds timeout
        """
        timeout = timeout_ms or self._settings.max_assembly_time_ms
        start = time.perf_counter()
        metrics = PipelineMetrics(total_contexts=len(contexts))

        try:
            # Create or use budget: explicit budget > explicit token cap >
            # per-model default.
            if custom_budget:
                budget = custom_budget
            elif max_tokens:
                budget = self._allocator.create_budget(max_tokens)
            else:
                budget = self._allocator.create_budget_for_model(model)

            # 1. Count tokens for all contexts (with timeout enforcement)
            try:
                await asyncio.wait_for(
                    self._ensure_token_counts(contexts, model),
                    timeout=self._remaining_timeout(start, timeout),
                )
            except TimeoutError:
                # Re-raise as the domain error; the TimeoutError stays
                # attached as __context__ via implicit chaining.
                elapsed_ms = (time.perf_counter() - start) * 1000
                raise AssemblyTimeoutError(
                    message="Context assembly timed out during token counting",
                    elapsed_ms=elapsed_ms,
                    timeout_ms=timeout,
                )

            # Check timeout (handles edge case where operation finished just at limit)
            self._check_timeout(start, timeout, "token counting")

            # 2. Score and rank contexts (with timeout enforcement)
            scoring_start = time.perf_counter()
            try:
                ranking_result = await asyncio.wait_for(
                    self._ranker.rank(
                        contexts=contexts,
                        query=query,
                        budget=budget,
                        model=model,
                    ),
                    timeout=self._remaining_timeout(start, timeout),
                )
            except TimeoutError:
                elapsed_ms = (time.perf_counter() - start) * 1000
                raise AssemblyTimeoutError(
                    message="Context assembly timed out during scoring/ranking",
                    elapsed_ms=elapsed_ms,
                    timeout_ms=timeout,
                )
            metrics.scoring_time_ms = (time.perf_counter() - scoring_start) * 1000

            selected_contexts = ranking_result.selected_contexts
            metrics.selected_contexts = len(selected_contexts)
            metrics.excluded_contexts = len(ranking_result.excluded)

            # Check timeout
            self._check_timeout(start, timeout, "scoring")

            # 3. Compress if needed and enabled (with timeout enforcement)
            if compress and self._needs_compression(selected_contexts, budget):
                compression_start = time.perf_counter()
                try:
                    selected_contexts = await asyncio.wait_for(
                        self._compressor.compress_contexts(
                            selected_contexts, budget, model
                        ),
                        timeout=self._remaining_timeout(start, timeout),
                    )
                except TimeoutError:
                    elapsed_ms = (time.perf_counter() - start) * 1000
                    raise AssemblyTimeoutError(
                        message="Context assembly timed out during compression",
                        elapsed_ms=elapsed_ms,
                        timeout_ms=timeout,
                    )
                metrics.compression_time_ms = (
                    time.perf_counter() - compression_start
                ) * 1000
                # The compressor marks truncated contexts via metadata;
                # count those for reporting.
                metrics.compressed_contexts = sum(
                    1 for c in selected_contexts if c.metadata.get("truncated", False)
                )

                # Check timeout
                self._check_timeout(start, timeout, "compression")

            # 4. Format output (formatting is synchronous, so it runs to
            # completion even if the deadline passes during it)
            formatting_start = time.perf_counter()
            if format_output:
                formatted_content = self._format_contexts(selected_contexts, model)
            else:
                formatted_content = "\n\n".join(c.content for c in selected_contexts)
            metrics.formatting_time_ms = (time.perf_counter() - formatting_start) * 1000

            # Calculate final metrics
            total_tokens = sum(c.token_count or 0 for c in selected_contexts)
            metrics.total_tokens = total_tokens
            metrics.assembly_time_ms = (time.perf_counter() - start) * 1000
            metrics.end_time = datetime.now(UTC)

            return AssembledContext(
                content=formatted_content,
                total_tokens=total_tokens,
                context_count=len(selected_contexts),
                assembly_time_ms=metrics.assembly_time_ms,
                model=model,
                contexts=selected_contexts,
                excluded_count=metrics.excluded_contexts,
                metadata={
                    "metrics": metrics.to_dict(),
                    "query": query,
                    "budget": budget.to_dict(),
                },
            )

        except AssemblyTimeoutError:
            # Timeouts were already shaped above; pass them through untouched.
            raise
        except Exception as e:
            # Log with traceback, then re-raise so callers see the original error.
            logger.error(f"Context assembly failed: {e}", exc_info=True)
            raise

    async def _ensure_token_counts(
        self,
        contexts: list[BaseContext],
        model: str | None = None,
    ) -> None:
        """Ensure all contexts have token counts.

        Only contexts with token_count=None are counted; counting runs
        concurrently via asyncio.gather.
        """
        tasks = []
        for context in contexts:
            if context.token_count is None:
                tasks.append(self._count_and_set(context, model))

        if tasks:
            await asyncio.gather(*tasks)

    async def _count_and_set(
        self,
        context: BaseContext,
        model: str | None = None,
    ) -> None:
        """Count tokens and set on context (mutates context.token_count)."""
        count = await self._calculator.count_tokens(context.content, model)
        context.token_count = count

    def _needs_compression(
        self,
        contexts: list[BaseContext],
        budget: TokenBudget,
    ) -> bool:
        """Check if any contexts exceed their type budget."""
        # Group by type and check totals against each type's allocation.
        by_type: dict[ContextType, int] = {}
        for context in contexts:
            ct = context.get_type()
            by_type[ct] = by_type.get(ct, 0) + (context.token_count or 0)

        for ct, total in by_type.items():
            if total > budget.get_allocation(ct):
                return True

        # Also check if utilization exceeds threshold
        # NOTE(review): assumes budget.utilization() reflects tokens already
        # recorded against the budget (e.g. during ranking) — confirm in TokenBudget.
        return budget.utilization() > self._settings.compression_threshold

    def _format_contexts(
        self,
        contexts: list[BaseContext],
        model: str,
    ) -> str:
        """
        Format contexts for the target model.

        Uses model-specific adapters (ClaudeAdapter, OpenAIAdapter, etc.)
        to format contexts optimally for each model family.

        Args:
            contexts: Contexts to format
            model: Target model name

        Returns:
            Formatted context string
        """
        adapter = get_adapter(model)
        return adapter.format(contexts)

    def _check_timeout(
        self,
        start: float,
        timeout_ms: int,
        phase: str,
    ) -> None:
        """Check if timeout exceeded and raise if so.

        Used between phases to catch the case where the previous awaited
        phase completed exactly at (or past) the deadline without wait_for
        firing.
        """
        elapsed_ms = (time.perf_counter() - start) * 1000
        if elapsed_ms >= timeout_ms:
            raise AssemblyTimeoutError(
                message=f"Context assembly timed out during {phase}",
                elapsed_ms=elapsed_ms,
                timeout_ms=timeout_ms,
            )

    def _remaining_timeout(self, start: float, timeout_ms: int) -> float:
        """
        Calculate remaining timeout in seconds for asyncio.wait_for.

        Returns at least a small positive value to avoid immediate timeout
        edge cases with wait_for.

        Args:
            start: Start time from time.perf_counter()
            timeout_ms: Total timeout in milliseconds

        Returns:
            Remaining timeout in seconds (minimum 0.001)
        """
        elapsed_ms = (time.perf_counter() - start) * 1000
        remaining_ms = timeout_ms - elapsed_ms
        # Return at least 1ms to avoid zero/negative timeout edge cases
        return max(remaining_ms / 1000.0, 0.001)
|
||||
14
backend/app/services/context/budget/__init__.py
Normal file
14
backend/app/services/context/budget/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
Token Budget Management Module.
|
||||
|
||||
Provides token counting and budget allocation.
|
||||
"""
|
||||
|
||||
from .allocator import BudgetAllocator, TokenBudget
|
||||
from .calculator import TokenCalculator
|
||||
|
||||
__all__ = [
|
||||
"BudgetAllocator",
|
||||
"TokenBudget",
|
||||
"TokenCalculator",
|
||||
]
|
||||
433
backend/app/services/context/budget/allocator.py
Normal file
433
backend/app/services/context/budget/allocator.py
Normal file
@@ -0,0 +1,433 @@
|
||||
"""
|
||||
Token Budget Allocator for Context Management.
|
||||
|
||||
Manages token budget allocation across context types.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..config import ContextSettings, get_context_settings
|
||||
from ..exceptions import BudgetExceededError
|
||||
from ..types import ContextType
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenBudget:
|
||||
"""
|
||||
Token budget allocation and tracking.
|
||||
|
||||
Tracks allocated tokens per context type and
|
||||
monitors usage to prevent overflows.
|
||||
"""
|
||||
|
||||
# Total budget
|
||||
total: int
|
||||
|
||||
# Allocated per type
|
||||
system: int = 0
|
||||
task: int = 0
|
||||
knowledge: int = 0
|
||||
conversation: int = 0
|
||||
tools: int = 0
|
||||
response_reserve: int = 0
|
||||
buffer: int = 0
|
||||
|
||||
# Usage tracking
|
||||
used: dict[str, int] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Initialize usage tracking."""
|
||||
if not self.used:
|
||||
self.used = {ct.value: 0 for ct in ContextType}
|
||||
|
||||
def get_allocation(self, context_type: ContextType | str) -> int:
|
||||
"""
|
||||
Get allocated tokens for a context type.
|
||||
|
||||
Args:
|
||||
context_type: Context type to get allocation for
|
||||
|
||||
Returns:
|
||||
Allocated token count
|
||||
"""
|
||||
if isinstance(context_type, ContextType):
|
||||
context_type = context_type.value
|
||||
|
||||
allocation_map = {
|
||||
"system": self.system,
|
||||
"task": self.task,
|
||||
"knowledge": self.knowledge,
|
||||
"conversation": self.conversation,
|
||||
"tool": self.tools,
|
||||
}
|
||||
return allocation_map.get(context_type, 0)
|
||||
|
||||
def get_used(self, context_type: ContextType | str) -> int:
|
||||
"""
|
||||
Get used tokens for a context type.
|
||||
|
||||
Args:
|
||||
context_type: Context type to check
|
||||
|
||||
Returns:
|
||||
Used token count
|
||||
"""
|
||||
if isinstance(context_type, ContextType):
|
||||
context_type = context_type.value
|
||||
return self.used.get(context_type, 0)
|
||||
|
||||
def remaining(self, context_type: ContextType | str) -> int:
|
||||
"""
|
||||
Get remaining tokens for a context type.
|
||||
|
||||
Args:
|
||||
context_type: Context type to check
|
||||
|
||||
Returns:
|
||||
Remaining token count
|
||||
"""
|
||||
allocated = self.get_allocation(context_type)
|
||||
used = self.get_used(context_type)
|
||||
return max(0, allocated - used)
|
||||
|
||||
def total_remaining(self) -> int:
|
||||
"""
|
||||
Get total remaining tokens across all types.
|
||||
|
||||
Returns:
|
||||
Total remaining tokens
|
||||
"""
|
||||
total_used = sum(self.used.values())
|
||||
usable = self.total - self.response_reserve - self.buffer
|
||||
return max(0, usable - total_used)
|
||||
|
||||
def total_used(self) -> int:
|
||||
"""
|
||||
Get total used tokens.
|
||||
|
||||
Returns:
|
||||
Total used tokens
|
||||
"""
|
||||
return sum(self.used.values())
|
||||
|
||||
def can_fit(self, context_type: ContextType | str, tokens: int) -> bool:
|
||||
"""
|
||||
Check if tokens fit within budget for a type.
|
||||
|
||||
Args:
|
||||
context_type: Context type to check
|
||||
tokens: Number of tokens to fit
|
||||
|
||||
Returns:
|
||||
True if tokens fit within remaining budget
|
||||
"""
|
||||
return tokens <= self.remaining(context_type)
|
||||
|
||||
def allocate(
|
||||
self,
|
||||
context_type: ContextType | str,
|
||||
tokens: int,
|
||||
force: bool = False,
|
||||
) -> bool:
|
||||
"""
|
||||
Allocate (use) tokens from a context type's budget.
|
||||
|
||||
Args:
|
||||
context_type: Context type to allocate from
|
||||
tokens: Number of tokens to allocate
|
||||
force: If True, allow exceeding budget
|
||||
|
||||
Returns:
|
||||
True if allocation succeeded
|
||||
|
||||
Raises:
|
||||
BudgetExceededError: If tokens exceed budget and force=False
|
||||
"""
|
||||
if isinstance(context_type, ContextType):
|
||||
context_type = context_type.value
|
||||
|
||||
if not force and not self.can_fit(context_type, tokens):
|
||||
raise BudgetExceededError(
|
||||
message=f"Token budget exceeded for {context_type}",
|
||||
allocated=self.get_allocation(context_type),
|
||||
requested=self.get_used(context_type) + tokens,
|
||||
context_type=context_type,
|
||||
)
|
||||
|
||||
self.used[context_type] = self.used.get(context_type, 0) + tokens
|
||||
return True
|
||||
|
||||
def deallocate(
|
||||
self,
|
||||
context_type: ContextType | str,
|
||||
tokens: int,
|
||||
) -> None:
|
||||
"""
|
||||
Deallocate (return) tokens to a context type's budget.
|
||||
|
||||
Args:
|
||||
context_type: Context type to return to
|
||||
tokens: Number of tokens to return
|
||||
"""
|
||||
if isinstance(context_type, ContextType):
|
||||
context_type = context_type.value
|
||||
|
||||
current = self.used.get(context_type, 0)
|
||||
self.used[context_type] = max(0, current - tokens)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset all usage tracking."""
|
||||
self.used = {ct.value: 0 for ct in ContextType}
|
||||
|
||||
def utilization(self, context_type: ContextType | str | None = None) -> float:
|
||||
"""
|
||||
Get budget utilization percentage.
|
||||
|
||||
Args:
|
||||
context_type: Specific type or None for total
|
||||
|
||||
Returns:
|
||||
Utilization as a fraction (0.0 to 1.0+)
|
||||
"""
|
||||
if context_type is None:
|
||||
usable = self.total - self.response_reserve - self.buffer
|
||||
if usable <= 0:
|
||||
return 0.0
|
||||
return self.total_used() / usable
|
||||
|
||||
allocated = self.get_allocation(context_type)
|
||||
if allocated <= 0:
|
||||
return 0.0
|
||||
return self.get_used(context_type) / allocated
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert budget to dictionary."""
|
||||
return {
|
||||
"total": self.total,
|
||||
"allocations": {
|
||||
"system": self.system,
|
||||
"task": self.task,
|
||||
"knowledge": self.knowledge,
|
||||
"conversation": self.conversation,
|
||||
"tools": self.tools,
|
||||
"response_reserve": self.response_reserve,
|
||||
"buffer": self.buffer,
|
||||
},
|
||||
"used": dict(self.used),
|
||||
"remaining": {ct.value: self.remaining(ct) for ct in ContextType},
|
||||
"total_used": self.total_used(),
|
||||
"total_remaining": self.total_remaining(),
|
||||
"utilization": round(self.utilization(), 3),
|
||||
}
|
||||
|
||||
|
||||
class BudgetAllocator:
    """
    Budget allocator for context management.

    Creates token budgets based on configuration and
    model context window sizes.
    """

    def __init__(self, settings: ContextSettings | None = None) -> None:
        """
        Initialize budget allocator.

        Args:
            settings: Context settings (uses default if None)
        """
        self._settings = settings or get_context_settings()

    def create_budget(
        self,
        total_tokens: int,
        custom_allocations: dict[str, float] | None = None,
    ) -> TokenBudget:
        """
        Create a token budget with per-type allocations.

        Args:
            total_tokens: Total available tokens
            custom_allocations: Optional custom allocation percentages
                (fractions of total_tokens, keyed by type name)

        Returns:
            TokenBudget with allocations set
        """
        # Use custom or default allocations
        if custom_allocations:
            alloc = custom_allocations
        else:
            alloc = self._settings.get_budget_allocation()

        return TokenBudget(
            total=total_tokens,
            system=int(total_tokens * alloc.get("system", 0.05)),
            task=int(total_tokens * alloc.get("task", 0.10)),
            knowledge=int(total_tokens * alloc.get("knowledge", 0.40)),
            conversation=int(total_tokens * alloc.get("conversation", 0.20)),
            tools=int(total_tokens * alloc.get("tools", 0.05)),
            response_reserve=int(total_tokens * alloc.get("response", 0.15)),
            buffer=int(total_tokens * alloc.get("buffer", 0.05)),
        )

    def adjust_budget(
        self,
        budget: TokenBudget,
        context_type: ContextType | str,
        adjustment: int,
    ) -> TokenBudget:
        """
        Adjust a specific allocation in a budget.

        Positive adjustments take tokens from the buffer; negative
        adjustments return tokens to the buffer.

        Args:
            budget: Budget to adjust (mutated in place)
            context_type: Type to adjust
            adjustment: Positive to increase, negative to decrease

        Returns:
            Adjusted budget
        """
        if isinstance(context_type, ContextType):
            context_type = context_type.value

        # Calculate adjustment (limited by buffer for increases, by current allocation for decreases)
        if adjustment > 0:
            # Taking from buffer - limited by available buffer
            actual_adjustment = min(adjustment, budget.buffer)
            budget.buffer -= actual_adjustment
        else:
            # Returning to buffer - limited by current allocation of target type
            current_allocation = budget.get_allocation(context_type)
            # Can't return more than current allocation
            actual_adjustment = max(adjustment, -current_allocation)
            # Add returned tokens back to buffer (adjustment is negative, so subtract)
            budget.buffer -= actual_adjustment

        # Apply to target type
        if context_type == "system":
            budget.system = max(0, budget.system + actual_adjustment)
        elif context_type == "task":
            budget.task = max(0, budget.task + actual_adjustment)
        elif context_type == "knowledge":
            budget.knowledge = max(0, budget.knowledge + actual_adjustment)
        elif context_type == "conversation":
            budget.conversation = max(0, budget.conversation + actual_adjustment)
        elif context_type == "tool":
            budget.tools = max(0, budget.tools + actual_adjustment)

        return budget

    def rebalance_budget(
        self,
        budget: TokenBudget,
        prioritize: list[ContextType] | None = None,
    ) -> TokenBudget:
        """
        Rebalance budget based on actual usage.

        Moves unused allocations (via the buffer) to prioritized types
        that are near capacity.

        Args:
            budget: Budget to rebalance (mutated in place)
            prioritize: Types to prioritize (in order)

        Returns:
            Rebalanced budget
        """
        if prioritize is None:
            prioritize = [ContextType.KNOWLEDGE, ContextType.TASK, ContextType.SYSTEM]

        # Calculate unused tokens per type
        unused: dict[str, int] = {}
        for ct in ContextType:
            remaining = budget.remaining(ct)
            if remaining > 0:
                unused[ct.value] = remaining

        # Calculate total reclaimable (excluding prioritized types)
        prioritize_values = {ct.value for ct in prioritize}
        reclaimable = sum(
            tokens for ct, tokens in unused.items() if ct not in prioritize_values
        )

        # Redistribute to prioritized types that are near capacity
        for ct in prioritize:
            utilization = budget.utilization(ct)

            if utilization > 0.8:  # Near capacity
                # Give more tokens from reclaimable pool
                bonus = min(reclaimable, budget.get_allocation(ct) // 2)
                self.adjust_budget(budget, ct, bonus)
                reclaimable -= bonus

            if reclaimable <= 0:
                break

        return budget

    def get_model_context_size(self, model: str) -> int:
        """
        Get context window size for a model.

        Args:
            model: Model name (case-insensitive; dated variants such as
                "claude-3-5-sonnet-20241022" resolve via prefix match)

        Returns:
            Context window size in tokens (8192 when unknown)
        """
        # Common model context sizes
        context_sizes = {
            "claude-3-opus": 200000,
            "claude-3-sonnet": 200000,
            "claude-3-haiku": 200000,
            "claude-3-5-sonnet": 200000,
            "claude-3-5-haiku": 200000,
            "claude-opus-4": 200000,
            "gpt-4-turbo": 128000,
            "gpt-4": 8192,
            "gpt-4-32k": 32768,
            "gpt-4o": 128000,
            "gpt-4o-mini": 128000,
            "gpt-3.5-turbo": 16385,
            "gemini-1.5-pro": 2000000,
            "gemini-1.5-flash": 1000000,
            "gemini-2.0-flash": 1000000,
            "qwen-plus": 32000,
            "qwen-turbo": 8000,
            "deepseek-chat": 64000,
            "deepseek-reasoner": 64000,
        }

        # Check exact match first
        model_lower = model.lower()
        if model_lower in context_sizes:
            return context_sizes[model_lower]

        # Prefix match: prefer the LONGEST matching prefix. A first-match
        # scan in dict order would resolve e.g. "gpt-4o-mini-2024-07-18"
        # to "gpt-4" (8k) instead of "gpt-4o-mini" (128k).
        prefix_matches = [
            name for name in context_sizes if model_lower.startswith(name)
        ]
        if prefix_matches:
            return context_sizes[max(prefix_matches, key=len)]

        # Default fallback
        return 8192

    def create_budget_for_model(
        self,
        model: str,
        custom_allocations: dict[str, float] | None = None,
    ) -> TokenBudget:
        """
        Create a budget based on a model's context window.

        Args:
            model: Model name
            custom_allocations: Optional custom allocation percentages

        Returns:
            TokenBudget sized for the model
        """
        context_size = self.get_model_context_size(model)
        return self.create_budget(context_size, custom_allocations)
|
||||
285
backend/app/services/context/budget/calculator.py
Normal file
285
backend/app/services/context/budget/calculator.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""
|
||||
Token Calculator for Context Management.
|
||||
|
||||
Provides token counting with caching and fallback estimation.
|
||||
Integrates with LLM Gateway for accurate counts.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Protocol
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.mcp.client_manager import MCPClientManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TokenCounterProtocol(Protocol):
    """Structural interface for async token counters.

    Any object exposing a matching ``count_tokens`` coroutine satisfies
    this protocol — no inheritance required.
    """

    async def count_tokens(
        self,
        text: str,
        model: str | None = None,
    ) -> int:
        """Count tokens in *text*, optionally using model-specific rules."""
        ...
|
||||
|
||||
|
||||
class TokenCalculator:
    """
    Token calculator backed by the LLM Gateway.

    Counts tokens via the gateway's ``count_tokens`` tool when an MCP
    manager is available, memoizing results in a small in-process cache
    to avoid repeated calls for identical content. When the gateway is
    missing or fails, falls back to a fast characters-per-token estimate.
    """

    # Fallback characters-per-token ratio used when no model family matches.
    DEFAULT_CHARS_PER_TOKEN: ClassVar[float] = 4.0

    # Per-model-family ratios for a slightly better estimate.
    MODEL_CHAR_RATIOS: ClassVar[dict[str, float]] = {
        "claude": 3.5,
        "gpt-4": 4.0,
        "gpt-3.5": 4.0,
        "gemini": 4.0,
    }

    def __init__(
        self,
        mcp_manager: "MCPClientManager | None" = None,
        project_id: str = "system",
        agent_id: str = "context-engine",
        cache_enabled: bool = True,
        cache_max_size: int = 10000,
    ) -> None:
        """
        Initialize token calculator.

        Args:
            mcp_manager: MCP client manager for LLM Gateway calls
            project_id: Project ID for LLM Gateway calls
            agent_id: Agent ID for LLM Gateway calls
            cache_enabled: Whether to enable in-memory caching
            cache_max_size: Maximum cache entries
        """
        self._mcp = mcp_manager
        self._project_id = project_id
        self._agent_id = agent_id
        self._cache_enabled = cache_enabled
        self._cache_max_size = cache_max_size

        # Maps sha256("model:text") prefixes to token counts.
        self._cache: dict[str, int] = {}
        self._cache_hits = 0
        self._cache_misses = 0

    def _get_cache_key(self, text: str, model: str | None) -> str:
        """Derive a compact, deterministic cache key from model + text."""
        payload = f"{model or 'default'}:{text}".encode()
        return hashlib.sha256(payload).hexdigest()[:32]

    def _check_cache(self, cache_key: str) -> int | None:
        """Look up a cached count, tracking hit/miss statistics."""
        if not self._cache_enabled:
            return None

        try:
            value = self._cache[cache_key]
        except KeyError:
            self._cache_misses += 1
            return None
        self._cache_hits += 1
        return value

    def _store_cache(self, cache_key: str, count: int) -> None:
        """Insert a count, evicting the oldest ~10% when the cache is full."""
        if not self._cache_enabled:
            return

        if len(self._cache) >= self._cache_max_size:
            # Dicts preserve insertion order, so the front is the oldest.
            for stale in list(self._cache)[: self._cache_max_size // 10]:
                del self._cache[stale]

        self._cache[cache_key] = count

    def estimate_tokens(self, text: str, model: str | None = None) -> int:
        """
        Estimate the token count from the character count.

        Fast fallback used when the LLM Gateway is unavailable.

        Args:
            text: Text to count
            model: Optional model name for a family-specific ratio

        Returns:
            Estimated token count (at least 1 for non-empty text)
        """
        if not text:
            return 0

        ratio = self.DEFAULT_CHARS_PER_TOKEN
        if model:
            lowered = model.lower()
            # First matching family substring wins.
            ratio = next(
                (r for prefix, r in self.MODEL_CHAR_RATIOS.items() if prefix in lowered),
                ratio,
            )

        return max(1, int(len(text) / ratio))

    async def count_tokens(
        self,
        text: str,
        model: str | None = None,
    ) -> int:
        """
        Count tokens in *text*.

        Tries the cache, then the LLM Gateway's count_tokens tool, then
        the character-based estimate. Gateway failures are logged and
        degraded to the estimate rather than raised.

        Args:
            text: Text to count
            model: Optional model for accurate counting

        Returns:
            Token count
        """
        if not text:
            return 0

        key = self._get_cache_key(text, model)
        cached = self._check_cache(key)
        if cached is not None:
            return cached

        if self._mcp is not None:
            try:
                result = await self._mcp.call_tool(
                    server="llm-gateway",
                    tool="count_tokens",
                    args={
                        "project_id": self._project_id,
                        "agent_id": self._agent_id,
                        "text": text,
                        "model": model,
                    },
                )
                if result.success and result.data:
                    parsed = self._parse_token_count(result.data)
                    if parsed is not None:
                        self._store_cache(key, parsed)
                        return parsed
            except Exception as e:
                logger.warning(f"LLM Gateway token count failed, using estimation: {e}")

        estimated = self.estimate_tokens(text, model)
        self._store_cache(key, estimated)
        return estimated

    def _parse_token_count(self, data: Any) -> int | None:
        """Extract an integer token count from a gateway response payload."""
        if isinstance(data, dict):
            for field_name in ("token_count", "tokens", "count"):
                if field_name in data:
                    return int(data[field_name])
            return None

        if isinstance(data, int):
            return data

        if isinstance(data, str):
            # Handle {"token_count": 123} or just "123"
            import json

            try:
                decoded = json.loads(data)
                if isinstance(decoded, dict) and "token_count" in decoded:
                    return int(decoded["token_count"])
                if isinstance(decoded, int):
                    return decoded
            except (json.JSONDecodeError, ValueError):
                try:
                    return int(data)
                except ValueError:
                    pass

        return None

    async def count_tokens_batch(
        self,
        texts: list[str],
        model: str | None = None,
    ) -> list[int]:
        """
        Count tokens for several texts concurrently.

        Args:
            texts: Texts to count
            model: Optional model for accurate counting

        Returns:
            Token counts, in input order
        """
        import asyncio

        if not texts:
            return []

        # Run all counts in parallel; gather preserves input order.
        return await asyncio.gather(*(self.count_tokens(t, model) for t in texts))

    def clear_cache(self) -> None:
        """Drop all cached counts and reset hit/miss statistics."""
        self._cache.clear()
        self._cache_hits = 0
        self._cache_misses = 0

    def get_cache_stats(self) -> dict[str, Any]:
        """Return cache configuration and hit/miss statistics."""
        lookups = self._cache_hits + self._cache_misses
        return {
            "enabled": self._cache_enabled,
            "size": len(self._cache),
            "max_size": self._cache_max_size,
            "hits": self._cache_hits,
            "misses": self._cache_misses,
            "hit_rate": round(self._cache_hits / lookups, 3) if lookups else 0.0,
        }

    def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
        """
        Attach the MCP manager after construction (lazy initialization).

        Args:
            mcp_manager: MCP client manager instance
        """
        self._mcp = mcp_manager
|
||||
11
backend/app/services/context/cache/__init__.py
vendored
Normal file
11
backend/app/services/context/cache/__init__.py
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
Context Cache Module.
|
||||
|
||||
Provides Redis-based caching for assembled contexts.
|
||||
"""
|
||||
|
||||
from .context_cache import ContextCache
|
||||
|
||||
__all__ = [
|
||||
"ContextCache",
|
||||
]
|
||||
434
backend/app/services/context/cache/context_cache.py
vendored
Normal file
434
backend/app/services/context/cache/context_cache.py
vendored
Normal file
@@ -0,0 +1,434 @@
|
||||
"""
|
||||
Context Cache Implementation.
|
||||
|
||||
Provides Redis-based caching for context operations including
|
||||
assembled contexts, token counts, and scoring results.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from ..config import ContextSettings, get_context_settings
|
||||
from ..exceptions import CacheError
|
||||
from ..types import AssembledContext, BaseContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.asyncio import Redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContextCache:
|
||||
"""
|
||||
Redis-based caching for context operations.
|
||||
|
||||
Provides caching for:
|
||||
- Assembled contexts (fingerprint-based)
|
||||
- Token counts (content hash-based)
|
||||
- Scoring results (context + query hash-based)
|
||||
|
||||
Cache keys use a hierarchical structure:
|
||||
- ctx:assembled:{fingerprint}
|
||||
- ctx:tokens:{model}:{content_hash}
|
||||
- ctx:score:{scorer}:{context_hash}:{query_hash}
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis: "Redis | None" = None,
|
||||
settings: ContextSettings | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the context cache.
|
||||
|
||||
Args:
|
||||
redis: Redis connection (optional for testing)
|
||||
settings: Cache settings
|
||||
"""
|
||||
self._redis = redis
|
||||
self._settings = settings or get_context_settings()
|
||||
self._prefix = self._settings.cache_prefix
|
||||
self._ttl = self._settings.cache_ttl_seconds
|
||||
|
||||
# In-memory fallback cache when Redis unavailable
|
||||
self._memory_cache: dict[str, tuple[str, float]] = {}
|
||||
self._max_memory_items = self._settings.cache_memory_max_items
|
||||
|
||||
def set_redis(self, redis: "Redis") -> None:
    """Attach a Redis connection after construction (lazy wiring)."""
    self._redis = redis
|
||||
|
||||
@property
def is_enabled(self) -> bool:
    """True when caching is switched on in settings AND Redis is attached."""
    return self._settings.cache_enabled and self._redis is not None
|
||||
|
||||
def _cache_key(self, *parts: str) -> str:
|
||||
"""
|
||||
Build a cache key from parts.
|
||||
|
||||
Args:
|
||||
*parts: Key components
|
||||
|
||||
Returns:
|
||||
Colon-separated cache key
|
||||
"""
|
||||
return f"{self._prefix}:{':'.join(parts)}"
|
||||
|
||||
@staticmethod
|
||||
def _hash_content(content: str) -> str:
|
||||
"""
|
||||
Compute hash of content for cache key.
|
||||
|
||||
Args:
|
||||
content: Content to hash
|
||||
|
||||
Returns:
|
||||
32-character hex hash
|
||||
"""
|
||||
return hashlib.sha256(content.encode()).hexdigest()[:32]
|
||||
|
||||
def compute_fingerprint(
    self,
    contexts: list[BaseContext],
    query: str,
    model: str,
    project_id: str | None = None,
    agent_id: str | None = None,
) -> str:
    """
    Fingerprint a context assembly request for cache lookup.

    Combines tenant identifiers, per-context metadata (type, content
    hash, source, priority), the query, and the target model into a
    deterministic sorted-key JSON document and hashes it. Content is
    hashed rather than embedded so large bodies stay cheap.

    SECURITY: project_id and agent_id MUST be part of the fingerprint
    so one tenant can never receive another tenant's cached context
    for the same query.

    Args:
        contexts: List of contexts
        query: Query string
        model: Model name
        project_id: Project ID for tenant isolation
        agent_id: Agent ID for tenant isolation

    Returns:
        32-character hex fingerprint
    """
    per_context = [
        {
            "type": ctx.get_type().value,
            "content_hash": self._hash_content(ctx.content),
            "source": ctx.source,
            "priority": ctx.priority,
        }
        for ctx in contexts
    ]

    document = json.dumps(
        {
            # CRITICAL: tenant identifiers keep caches isolated.
            "project_id": project_id or "",
            "agent_id": agent_id or "",
            "contexts": per_context,
            "query": query,
            "model": model,
        },
        sort_keys=True,
    )
    return self._hash_content(document)
|
||||
|
||||
async def get_assembled(
    self,
    fingerprint: str,
) -> AssembledContext | None:
    """
    Look up a cached assembled context by fingerprint.

    Args:
        fingerprint: Assembly fingerprint

    Returns:
        The cached AssembledContext (marked as a cache hit), or None
        when caching is disabled or the entry is absent

    Raises:
        CacheError: when the Redis read (or deserialization) fails
    """
    if not self.is_enabled:
        return None

    key = self._cache_key("assembled", fingerprint)

    try:
        payload = await self._redis.get(key)  # type: ignore
        if payload:
            logger.debug(f"Cache hit for assembled context: {fingerprint}")
            assembled = AssembledContext.from_json(payload)
            assembled.cache_hit = True
            assembled.cache_key = fingerprint
            return assembled
    except Exception as e:
        logger.warning(f"Cache get error: {e}")
        raise CacheError(f"Failed to get assembled context: {e}") from e

    return None
|
||||
|
||||
async def set_assembled(
    self,
    fingerprint: str,
    context: AssembledContext,
    ttl: int | None = None,
) -> None:
    """
    Store an assembled context in the cache.

    Args:
        fingerprint: Assembly fingerprint
        context: Assembled context to cache
        ttl: Optional TTL override in seconds

    Raises:
        CacheError: when the Redis write fails
    """
    if not self.is_enabled:
        return

    key = self._cache_key("assembled", fingerprint)
    expire = ttl or self._ttl

    try:
        await self._redis.setex(key, expire, context.to_json())  # type: ignore
        logger.debug(f"Cached assembled context: {fingerprint}")
    except Exception as e:
        logger.warning(f"Cache set error: {e}")
        raise CacheError(f"Failed to cache assembled context: {e}") from e
|
||||
|
||||
async def get_token_count(
    self,
    content: str,
    model: str | None = None,
) -> int | None:
    """
    Look up a cached token count.

    Checks the in-memory tier first, then Redis; Redis hits are copied
    into memory for faster subsequent access. Redis errors are logged
    and treated as misses (token counts are cheap to recompute).

    Args:
        content: Content to look up
        model: Model name for model-specific tokenization

    Returns:
        Cached token count, or None on a miss
    """
    key = self._cache_key(
        "tokens", model or "default", self._hash_content(content)
    )

    # In-memory tier: values are (str_value, timestamp) tuples.
    in_memory = self._memory_cache.get(key)
    if in_memory is not None:
        return int(in_memory[0])

    if not self.is_enabled:
        return None

    try:
        payload = await self._redis.get(key)  # type: ignore
        if payload:
            count = int(payload)
            # Promote to the in-memory tier.
            self._set_memory(key, str(count))
            return count
    except Exception as e:
        logger.warning(f"Cache get error for tokens: {e}")

    return None
|
||||
|
||||
async def set_token_count(
    self,
    content: str,
    count: int,
    model: str | None = None,
    ttl: int | None = None,
) -> None:
    """
    Cache a token count in memory and (when available) Redis.

    Redis write failures are logged but never raised — token counts
    are cheap to recompute.

    Args:
        content: Content that was counted
        count: Token count
        model: Model name
        ttl: Optional TTL override in seconds
    """
    key = self._cache_key(
        "tokens", model or "default", self._hash_content(content)
    )
    expire = ttl or self._ttl

    # The in-memory tier is always updated, even without Redis.
    self._set_memory(key, str(count))

    if not self.is_enabled:
        return

    try:
        await self._redis.setex(key, expire, str(count))  # type: ignore
    except Exception as e:
        logger.warning(f"Cache set error for tokens: {e}")
|
||||
|
||||
async def get_score(
|
||||
self,
|
||||
scorer_name: str,
|
||||
context_id: str,
|
||||
query: str,
|
||||
) -> float | None:
|
||||
"""
|
||||
Get cached score.
|
||||
|
||||
Args:
|
||||
scorer_name: Name of the scorer
|
||||
context_id: Context identifier
|
||||
query: Query string
|
||||
|
||||
Returns:
|
||||
Cached score or None if not found
|
||||
"""
|
||||
query_hash = self._hash_content(query)[:16]
|
||||
key = self._cache_key("score", scorer_name, context_id, query_hash)
|
||||
|
||||
# Try in-memory first
|
||||
if key in self._memory_cache:
|
||||
return float(self._memory_cache[key][0])
|
||||
|
||||
if not self.is_enabled:
|
||||
return None
|
||||
|
||||
try:
|
||||
data = await self._redis.get(key) # type: ignore
|
||||
if data:
|
||||
score = float(data)
|
||||
self._set_memory(key, str(score))
|
||||
return score
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache get error for score: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def set_score(
|
||||
self,
|
||||
scorer_name: str,
|
||||
context_id: str,
|
||||
query: str,
|
||||
score: float,
|
||||
ttl: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Cache a score.
|
||||
|
||||
Args:
|
||||
scorer_name: Name of the scorer
|
||||
context_id: Context identifier
|
||||
query: Query string
|
||||
score: Score value
|
||||
ttl: Optional TTL override in seconds
|
||||
"""
|
||||
query_hash = self._hash_content(query)[:16]
|
||||
key = self._cache_key("score", scorer_name, context_id, query_hash)
|
||||
expire = ttl or self._ttl
|
||||
|
||||
# Always store in memory
|
||||
self._set_memory(key, str(score))
|
||||
|
||||
if not self.is_enabled:
|
||||
return
|
||||
|
||||
try:
|
||||
await self._redis.setex(key, expire, str(score)) # type: ignore
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache set error for score: {e}")
|
||||
|
||||
async def invalidate(self, pattern: str) -> int:
    """
    Delete every cached entry whose key matches *pattern*.

    Keys are discovered with a non-blocking SCAN and removed one by one.

    Args:
        pattern: Key pattern (supports * wildcard)

    Returns:
        Number of keys deleted (0 when caching is disabled)

    Raises:
        CacheError: If the Redis scan/delete fails
    """
    if not self.is_enabled:
        return 0

    removed = 0
    try:
        async for matched in self._redis.scan_iter(match=self._cache_key(pattern)):  # type: ignore
            await self._redis.delete(matched)  # type: ignore
            removed += 1

        logger.info(f"Invalidated {removed} cache entries matching {pattern}")
    except Exception as e:
        logger.warning(f"Cache invalidation error: {e}")
        raise CacheError(f"Failed to invalidate cache: {e}") from e

    return removed
|
||||
|
||||
async def clear_all(self) -> int:
    """
    Drop every context cache entry, in memory and in Redis.

    Returns:
        Number of Redis keys deleted
    """
    self._memory_cache.clear()
    # Redis side is handled by the wildcard invalidation.
    return await self.invalidate("*")
|
||||
|
||||
def _set_memory(self, key: str, value: str) -> None:
    """
    Store *value* in the in-process cache under *key*.

    When the cache has reached its size limit, the older half of the
    entries (by insertion timestamp) is evicted in one pass, LRU-style.

    Args:
        key: Cache key
        value: Value to store
    """
    import time

    if len(self._memory_cache) >= self._max_memory_items:
        # Drop the oldest half of the entries in one batch.
        by_age = sorted(
            self._memory_cache,
            key=lambda k: self._memory_cache[k][1],
        )
        for stale in by_age[: len(by_age) // 2]:
            del self._memory_cache[stale]

    # Entries are (value, timestamp) pairs; the timestamp drives eviction.
    self._memory_cache[key] = (value, time.time())
|
||||
|
||||
async def get_stats(self) -> dict[str, Any]:
    """
    Collect cache statistics for monitoring.

    Always reports configuration and memory-cache size; Redis memory
    usage is added best-effort when Redis is enabled.

    Returns:
        Dictionary with cache stats
    """
    stats: dict[str, Any] = {
        "enabled": self._settings.cache_enabled,
        "redis_available": self._redis is not None,
        "memory_items": len(self._memory_cache),
        "ttl_seconds": self._ttl,
    }

    if self.is_enabled:
        try:
            # Redis INFO for the memory section only.
            memory_info = await self._redis.info("memory")  # type: ignore
            stats["redis_memory_used"] = memory_info.get("used_memory_human", "unknown")
        except Exception as e:
            logger.debug(f"Failed to get Redis stats: {e}")

    return stats
|
||||
13
backend/app/services/context/compression/__init__.py
Normal file
13
backend/app/services/context/compression/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
Context Compression Module.
|
||||
|
||||
Provides truncation and compression strategies.
|
||||
"""
|
||||
|
||||
from .truncation import ContextCompressor, TruncationResult, TruncationStrategy
|
||||
|
||||
__all__ = [
|
||||
"ContextCompressor",
|
||||
"TruncationResult",
|
||||
"TruncationStrategy",
|
||||
]
|
||||
453
backend/app/services/context/compression/truncation.py
Normal file
453
backend/app/services/context/compression/truncation.py
Normal file
@@ -0,0 +1,453 @@
|
||||
"""
|
||||
Smart Truncation for Context Compression.
|
||||
|
||||
Provides intelligent truncation strategies to reduce context size
|
||||
while preserving the most important information.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..config import ContextSettings, get_context_settings
|
||||
from ..types import BaseContext, ContextType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..budget import TokenBudget, TokenCalculator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _estimate_tokens(text: str, model: str | None = None) -> int:
|
||||
"""
|
||||
Estimate token count using model-specific character ratios.
|
||||
|
||||
Module-level function for reuse across classes. Uses the same ratios
|
||||
as TokenCalculator for consistency.
|
||||
|
||||
Args:
|
||||
text: Text to estimate tokens for
|
||||
model: Optional model name for model-specific ratios
|
||||
|
||||
Returns:
|
||||
Estimated token count (minimum 1)
|
||||
"""
|
||||
# Model-specific character ratios (chars per token)
|
||||
model_ratios = {
|
||||
"claude": 3.5,
|
||||
"gpt-4": 4.0,
|
||||
"gpt-3.5": 4.0,
|
||||
"gemini": 4.0,
|
||||
}
|
||||
default_ratio = 4.0
|
||||
|
||||
ratio = default_ratio
|
||||
if model:
|
||||
model_lower = model.lower()
|
||||
for model_prefix, model_ratio in model_ratios.items():
|
||||
if model_prefix in model_lower:
|
||||
ratio = model_ratio
|
||||
break
|
||||
|
||||
return max(1, int(len(text) / ratio))
|
||||
|
||||
|
||||
@dataclass
class TruncationResult:
    """Outcome of a single truncation operation."""

    original_tokens: int
    truncated_tokens: int
    content: str
    truncated: bool
    truncation_ratio: float  # 0.0 = no truncation, 1.0 = completely removed

    @property
    def tokens_saved(self) -> int:
        """Number of tokens removed by the truncation."""
        return self.original_tokens - self.truncated_tokens
|
||||
|
||||
|
||||
class TruncationStrategy:
    """
    Smart truncation strategies for context compression.

    Supported strategies:
      * end truncation      — cut from the end (knowledge/docs)
      * middle truncation   — keep start and end (code)
      * sentence truncation — cut only at sentence boundaries
    """

    def __init__(
        self,
        calculator: "TokenCalculator | None" = None,
        preserve_ratio_start: float | None = None,
        min_content_length: int | None = None,
        settings: ContextSettings | None = None,
    ) -> None:
        """
        Initialize truncation strategy.

        Args:
            calculator: Token calculator for accurate counting
            preserve_ratio_start: Ratio of content to keep from start (overrides settings)
            min_content_length: Minimum characters to preserve (overrides settings)
            settings: Context settings (global settings when None)
        """
        self._settings = settings or get_context_settings()
        self._calculator = calculator

        # Explicit constructor arguments win over configured defaults.
        if preserve_ratio_start is None:
            preserve_ratio_start = self._settings.truncation_preserve_ratio
        self._preserve_ratio_start = preserve_ratio_start

        if min_content_length is None:
            min_content_length = self._settings.truncation_min_content_length
        self._min_content_length = min_content_length

    @property
    def truncation_marker(self) -> str:
        """Marker text inserted where content was removed (from settings)."""
        return self._settings.truncation_marker

    def set_calculator(self, calculator: "TokenCalculator") -> None:
        """Attach a token calculator for accurate counting."""
        self._calculator = calculator

    async def truncate_to_tokens(
        self,
        content: str,
        max_tokens: int,
        strategy: str = "end",
        model: str | None = None,
    ) -> TruncationResult:
        """
        Truncate *content* so it fits within *max_tokens*.

        Args:
            content: Content to truncate
            max_tokens: Maximum tokens allowed
            strategy: One of 'end', 'middle', 'sentence' (unknown values fall back to 'end')
            model: Model for token counting

        Returns:
            TruncationResult describing the (possibly unchanged) content
        """
        if not content:
            return TruncationResult(
                original_tokens=0,
                truncated_tokens=0,
                content="",
                truncated=False,
                truncation_ratio=0.0,
            )

        before = await self._count_tokens(content, model)

        # Already within budget — return the content untouched.
        if before <= max_tokens:
            return TruncationResult(
                original_tokens=before,
                truncated_tokens=before,
                content=content,
                truncated=False,
                truncation_ratio=0.0,
            )

        # Dispatch to the requested strategy; default is end truncation.
        handlers = {
            "middle": self._truncate_middle,
            "sentence": self._truncate_sentence,
        }
        shortened = await handlers.get(strategy, self._truncate_end)(
            content, max_tokens, model
        )

        after = await self._count_tokens(shortened, model)

        return TruncationResult(
            original_tokens=before,
            truncated_tokens=after,
            content=shortened,
            truncated=True,
            truncation_ratio=0.0 if before == 0 else 1 - (after / before),
        )

    async def _truncate_end(
        self,
        content: str,
        max_tokens: int,
        model: str | None = None,
    ) -> str:
        """
        Cut from the end of *content* and append the truncation marker.

        Uses a character-ratio estimate refined by a short binary search.
        """
        marker_tokens = await self._count_tokens(self.truncation_marker, model)
        available = max(0, max_tokens - marker_tokens)

        # No tokens left for content: the marker alone is the result.
        if available <= 0:
            return self.truncation_marker

        total_tokens = await self._count_tokens(content, model)
        # Guard against division by zero when estimating chars per token.
        if total_tokens == 0:
            return content + self.truncation_marker
        chars_per_token = len(content) / total_tokens

        # Initial cut point estimated from the character ratio.
        guess = content[: int(available * chars_per_token)]

        # Refine with up to 5 binary-search iterations.
        lo, hi = len(guess) // 2, len(guess)
        best = guess
        for _ in range(5):
            mid = (lo + hi) // 2
            candidate = content[:mid]
            if await self._count_tokens(candidate, model) <= available:
                best = candidate
                lo = mid + 1
            else:
                hi = mid - 1

        return best + self.truncation_marker

    async def _truncate_middle(
        self,
        content: str,
        max_tokens: int,
        model: str | None = None,
    ) -> str:
        """
        Keep the start and end of *content*, removing the middle.

        Good for code or content where context at the boundaries matters.
        """
        marker_tokens = await self._count_tokens(self.truncation_marker, model)
        available = max_tokens - marker_tokens

        # Split the budget between head and tail per the configured ratio.
        head_budget = int(available * self._preserve_ratio_start)
        tail_budget = available - head_budget

        head = await self._get_content_for_tokens(
            content, head_budget, from_start=True, model=model
        )
        tail = await self._get_content_for_tokens(
            content, tail_budget, from_start=False, model=model
        )

        return head + self.truncation_marker + tail

    async def _truncate_sentence(
        self,
        content: str,
        max_tokens: int,
        model: str | None = None,
    ) -> str:
        """
        Keep whole sentences from the start until the budget is spent.

        Produces cleaner output by never cutting mid-sentence.
        """
        # Split on whitespace following sentence-ending punctuation.
        sentences = re.split(r"(?<=[.!?])\s+", content)

        marker_tokens = await self._count_tokens(self.truncation_marker, model)
        budget = max_tokens - marker_tokens

        kept: list[str] = []
        used = 0
        for sentence in sentences:
            cost = await self._count_tokens(sentence, model)
            if used + cost > budget:
                break
            kept.append(sentence)
            used += cost

        joined = " ".join(kept)
        if len(kept) < len(sentences):
            return joined + self.truncation_marker
        return joined

    async def _get_content_for_tokens(
        self,
        content: str,
        target_tokens: int,
        from_start: bool = True,
        model: str | None = None,
    ) -> str:
        """Return the slice of *content* that fits within *target_tokens*."""
        if target_tokens <= 0:
            return ""

        current = await self._count_tokens(content, model)
        if current <= target_tokens:
            return content

        # Defensive guard against division by zero (current > target >= 1 here).
        if current == 0:
            return content
        chars_per_token = len(content) / current
        estimated = int(target_tokens * chars_per_token)

        return content[:estimated] if from_start else content[-estimated:]

    async def _count_tokens(self, text: str, model: str | None = None) -> int:
        """Count tokens with the calculator when set, else estimate by characters."""
        if self._calculator is None:
            # Fallback estimation with model-specific ratios.
            return _estimate_tokens(text, model)
        return await self._calculator.count_tokens(text, model)
|
||||
|
||||
|
||||
class ContextCompressor:
    """
    Compresses contexts to fit within budget constraints.

    Applies type-appropriate truncation strategies so that the most
    important information survives compression.
    """

    def __init__(
        self,
        truncation: TruncationStrategy | None = None,
        calculator: "TokenCalculator | None" = None,
    ) -> None:
        """
        Initialize context compressor.

        Args:
            truncation: Truncation strategy to use (a default one is built when None)
            calculator: Token calculator for counting
        """
        self._truncation = truncation or TruncationStrategy(calculator)
        self._calculator = calculator

        # Keep the truncation strategy's calculator in sync with ours.
        if calculator:
            self._truncation.set_calculator(calculator)

    def set_calculator(self, calculator: "TokenCalculator") -> None:
        """Attach a token calculator to this compressor and its strategy."""
        self._calculator = calculator
        self._truncation.set_calculator(calculator)

    async def compress_context(
        self,
        context: BaseContext,
        max_tokens: int,
        model: str | None = None,
    ) -> BaseContext:
        """
        Compress a single context to fit *max_tokens*.

        The context object is mutated in place when truncation applies
        (content, token_count, and truncation metadata are updated).

        Args:
            context: Context to compress
            max_tokens: Maximum tokens allowed
            model: Model for token counting

        Returns:
            The (possibly truncated) context object
        """
        size = context.token_count or await self._count_tokens(
            context.content, model
        )
        if size <= max_tokens:
            return context

        # Pick a truncation strategy suited to this context type.
        outcome = await self._truncation.truncate_to_tokens(
            content=context.content,
            max_tokens=max_tokens,
            strategy=self._get_strategy_for_type(context.get_type()),
            model=model,
        )

        # Record the truncation on the context itself.
        context.content = outcome.content
        context.token_count = outcome.truncated_tokens
        context.metadata["truncated"] = True
        context.metadata["original_tokens"] = outcome.original_tokens

        return context

    async def compress_contexts(
        self,
        contexts: list[BaseContext],
        budget: "TokenBudget",
        model: str | None = None,
    ) -> list[BaseContext]:
        """
        Compress each context that exceeds its type's remaining budget.

        Args:
            contexts: Contexts to potentially compress
            budget: Token budget constraints
            model: Model for token counting

        Returns:
            List of contexts (compressed as needed)
        """
        out: list[BaseContext] = []

        for ctx in contexts:
            ctx_type = ctx.get_type()
            allowance = budget.remaining(ctx_type)
            size = ctx.token_count or await self._count_tokens(
                ctx.content, model
            )

            if size <= allowance:
                out.append(ctx)
                continue

            # Over budget for this type — compress down to the allowance.
            squeezed = await self.compress_context(ctx, allowance, model)
            out.append(squeezed)
            logger.debug(
                f"Compressed {ctx_type.value} context from "
                f"{size} to {squeezed.token_count} tokens"
            )

        return out

    def _get_strategy_for_type(self, context_type: ContextType) -> str:
        """Map a context type to its preferred truncation strategy."""
        by_type = {
            ContextType.SYSTEM: "end",  # keep instructions at start
            ContextType.TASK: "end",  # keep task description start
            ContextType.KNOWLEDGE: "sentence",  # clean sentence boundaries
            ContextType.CONVERSATION: "end",  # keep recent conversation
            ContextType.TOOL: "middle",  # keep command and result summary
        }
        return by_type.get(context_type, "end")

    async def _count_tokens(self, text: str, model: str | None = None) -> int:
        """Count tokens with the calculator when set, else estimate by characters."""
        if self._calculator is None:
            # Model-specific estimation keeps counts consistent across modules.
            return _estimate_tokens(text, model)
        return await self._calculator.count_tokens(text, model)
|
||||
380
backend/app/services/context/config.py
Normal file
380
backend/app/services/context/config.py
Normal file
@@ -0,0 +1,380 @@
|
||||
"""
|
||||
Context Management Engine Configuration.
|
||||
|
||||
Provides Pydantic settings for context assembly,
|
||||
token budget allocation, and caching.
|
||||
"""
|
||||
|
||||
import threading
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, field_validator, model_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class ContextSettings(BaseSettings):
    """
    Configuration for the Context Management Engine.

    All settings can be overridden via environment variables
    with the CTX_ prefix (see ``model_config`` at the bottom of the class).
    Two post-init validators enforce that the budget percentages and the
    scoring weights each sum to 1.0.
    """

    # Budget allocation percentages (must sum to 1.0)
    budget_system: float = Field(
        default=0.05,
        ge=0.0,
        le=1.0,
        description="Percentage of budget for system prompts (5%)",
    )
    budget_task: float = Field(
        default=0.10,
        ge=0.0,
        le=1.0,
        description="Percentage of budget for task context (10%)",
    )
    budget_knowledge: float = Field(
        default=0.40,
        ge=0.0,
        le=1.0,
        description="Percentage of budget for RAG/knowledge (40%)",
    )
    budget_conversation: float = Field(
        default=0.20,
        ge=0.0,
        le=1.0,
        description="Percentage of budget for conversation history (20%)",
    )
    budget_tools: float = Field(
        default=0.05,
        ge=0.0,
        le=1.0,
        description="Percentage of budget for tool descriptions (5%)",
    )
    budget_response: float = Field(
        default=0.15,
        ge=0.0,
        le=1.0,
        description="Percentage reserved for response (15%)",
    )
    budget_buffer: float = Field(
        default=0.05,
        ge=0.0,
        le=1.0,
        description="Percentage buffer for safety margin (5%)",
    )

    # Scoring weights (must sum to 1.0; see validate_scoring_weights)
    scoring_relevance_weight: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Weight for relevance scoring",
    )
    scoring_recency_weight: float = Field(
        default=0.3,
        ge=0.0,
        le=1.0,
        description="Weight for recency scoring",
    )
    scoring_priority_weight: float = Field(
        default=0.2,
        ge=0.0,
        le=1.0,
        description="Weight for priority scoring",
    )

    # Recency decay settings
    recency_decay_hours: float = Field(
        default=24.0,
        gt=0.0,
        description="Hours until recency score decays to 50%",
    )
    recency_max_age_hours: float = Field(
        default=168.0,
        gt=0.0,
        description="Hours until context is considered stale (7 days)",
    )

    # Compression settings
    compression_threshold: float = Field(
        default=0.8,
        ge=0.0,
        le=1.0,
        description="Compress when budget usage exceeds this percentage",
    )
    truncation_marker: str = Field(
        default="\n\n[...content truncated...]\n\n",
        description="Marker text to insert where content was truncated",
    )
    truncation_preserve_ratio: float = Field(
        default=0.7,
        ge=0.1,
        le=0.9,
        description="Ratio of content to preserve from start in middle truncation (0.7 = 70% start, 30% end)",
    )
    truncation_min_content_length: int = Field(
        default=100,
        ge=10,
        le=1000,
        description="Minimum content length in characters before truncation applies",
    )
    summary_model_group: str = Field(
        default="fast",
        description="Model group to use for summarization",
    )

    # Caching settings
    cache_enabled: bool = Field(
        default=True,
        description="Enable Redis caching for assembled contexts",
    )
    cache_ttl_seconds: int = Field(
        default=3600,
        ge=60,
        le=86400,
        description="Cache TTL in seconds (1 hour default, max 24 hours)",
    )
    cache_prefix: str = Field(
        default="ctx",
        description="Redis key prefix for context cache",
    )
    cache_memory_max_items: int = Field(
        default=1000,
        ge=100,
        le=100000,
        description="Maximum items in memory fallback cache when Redis unavailable",
    )

    # Performance settings
    max_assembly_time_ms: int = Field(
        default=2000,
        ge=10,
        le=30000,
        description="Maximum time for context assembly in milliseconds. "
        "Should be high enough to accommodate MCP calls for knowledge retrieval.",
    )
    parallel_scoring: bool = Field(
        default=True,
        description="Score contexts in parallel for better performance",
    )
    max_parallel_scores: int = Field(
        default=10,
        ge=1,
        le=50,
        description="Maximum number of contexts to score in parallel",
    )

    # Knowledge retrieval settings
    knowledge_search_type: str = Field(
        default="hybrid",
        description="Default search type for knowledge retrieval",
    )
    knowledge_max_results: int = Field(
        default=10,
        ge=1,
        le=50,
        description="Maximum knowledge chunks to retrieve",
    )
    knowledge_min_score: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Minimum relevance score for knowledge",
    )

    # Relevance scoring settings
    relevance_keyword_fallback_weight: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Maximum score for keyword-based fallback scoring (when semantic unavailable)",
    )
    relevance_semantic_max_chars: int = Field(
        default=2000,
        ge=100,
        le=10000,
        description="Maximum content length in chars for semantic similarity computation",
    )

    # Diversity/ranking settings
    diversity_max_per_source: int = Field(
        default=3,
        ge=1,
        le=20,
        description="Maximum contexts from the same source in diversity reranking",
    )

    # Conversation history settings
    conversation_max_turns: int = Field(
        default=20,
        ge=1,
        le=100,
        description="Maximum conversation turns to include",
    )
    conversation_recent_priority: bool = Field(
        default=True,
        description="Prioritize recent conversation turns",
    )

    @field_validator("knowledge_search_type")
    @classmethod
    def validate_search_type(cls, v: str) -> str:
        """Validate search type is one of 'semantic', 'keyword', or 'hybrid'."""
        valid_types = {"semantic", "keyword", "hybrid"}
        if v not in valid_types:
            raise ValueError(f"search_type must be one of: {valid_types}")
        return v

    @model_validator(mode="after")
    def validate_budget_allocation(self) -> "ContextSettings":
        """Validate that budget percentages sum to 1.0."""
        total = (
            self.budget_system
            + self.budget_task
            + self.budget_knowledge
            + self.budget_conversation
            + self.budget_tools
            + self.budget_response
            + self.budget_buffer
        )
        # Allow small floating point error
        if abs(total - 1.0) > 0.001:
            raise ValueError(
                f"Budget percentages must sum to 1.0, got {total:.3f}. "
                f"Current allocation: system={self.budget_system}, task={self.budget_task}, "
                f"knowledge={self.budget_knowledge}, conversation={self.budget_conversation}, "
                f"tools={self.budget_tools}, response={self.budget_response}, buffer={self.budget_buffer}"
            )
        return self

    @model_validator(mode="after")
    def validate_scoring_weights(self) -> "ContextSettings":
        """Validate that scoring weights sum to 1.0."""
        total = (
            self.scoring_relevance_weight
            + self.scoring_recency_weight
            + self.scoring_priority_weight
        )
        # Allow small floating point error
        if abs(total - 1.0) > 0.001:
            raise ValueError(
                f"Scoring weights must sum to 1.0, got {total:.3f}. "
                f"Current weights: relevance={self.scoring_relevance_weight}, "
                f"recency={self.scoring_recency_weight}, priority={self.scoring_priority_weight}"
            )
        return self

    def get_budget_allocation(self) -> dict[str, float]:
        """Get budget allocation as a dictionary."""
        return {
            "system": self.budget_system,
            "task": self.budget_task,
            "knowledge": self.budget_knowledge,
            "conversation": self.budget_conversation,
            "tools": self.budget_tools,
            "response": self.budget_response,
            "buffer": self.budget_buffer,
        }

    def get_scoring_weights(self) -> dict[str, float]:
        """Get scoring weights as a dictionary."""
        return {
            "relevance": self.scoring_relevance_weight,
            "recency": self.scoring_recency_weight,
            "priority": self.scoring_priority_weight,
        }

    def to_dict(self) -> dict[str, Any]:
        """Convert settings to a nested dictionary for logging/debugging."""
        return {
            "budget": self.get_budget_allocation(),
            "scoring": self.get_scoring_weights(),
            "compression": {
                "threshold": self.compression_threshold,
                "summary_model_group": self.summary_model_group,
                "truncation_marker": self.truncation_marker,
                "truncation_preserve_ratio": self.truncation_preserve_ratio,
                "truncation_min_content_length": self.truncation_min_content_length,
            },
            "cache": {
                "enabled": self.cache_enabled,
                "ttl_seconds": self.cache_ttl_seconds,
                "prefix": self.cache_prefix,
                "memory_max_items": self.cache_memory_max_items,
            },
            "performance": {
                "max_assembly_time_ms": self.max_assembly_time_ms,
                "parallel_scoring": self.parallel_scoring,
                "max_parallel_scores": self.max_parallel_scores,
            },
            "knowledge": {
                "search_type": self.knowledge_search_type,
                "max_results": self.knowledge_max_results,
                "min_score": self.knowledge_min_score,
            },
            "relevance": {
                "keyword_fallback_weight": self.relevance_keyword_fallback_weight,
                "semantic_max_chars": self.relevance_semantic_max_chars,
            },
            "diversity": {
                "max_per_source": self.diversity_max_per_source,
            },
            "conversation": {
                "max_turns": self.conversation_max_turns,
                "recent_priority": self.conversation_recent_priority,
            },
        }

    # Pydantic settings configuration: CTX_-prefixed env vars, optional
    # ../.env file, case-insensitive names, unknown keys ignored.
    model_config = {
        "env_prefix": "CTX_",
        "env_file": "../.env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore",
    }
|
||||
|
||||
|
||||
# Thread-safe singleton pattern
_settings: ContextSettings | None = None
_settings_lock = threading.Lock()


def get_context_settings() -> ContextSettings:
    """
    Return the process-wide ContextSettings singleton.

    Uses double-checked locking so the common (already-initialized)
    path avoids taking the lock.

    Returns:
        ContextSettings instance
    """
    global _settings
    if _settings is None:
        with _settings_lock:
            # Re-check under the lock: another thread may have won the race.
            if _settings is None:
                _settings = ContextSettings()
    return _settings
|
||||
|
||||
def reset_context_settings() -> None:
    """
    Discard the global settings singleton.

    The next get_context_settings() call builds a fresh instance.
    Primarily used for testing.
    """
    global _settings
    with _settings_lock:
        _settings = None
|
||||
|
||||
@lru_cache(maxsize=1)
def get_default_settings() -> ContextSettings:
    """
    Return a cached ContextSettings built once from defaults/environment.

    Use this for read-only access to defaults.
    For mutable access, use get_context_settings().
    """
    return ContextSettings()
|
||||
485
backend/app/services/context/engine.py
Normal file
485
backend/app/services/context/engine.py
Normal file
@@ -0,0 +1,485 @@
|
||||
"""
|
||||
Context Management Engine.
|
||||
|
||||
Main orchestration layer for context assembly and optimization.
|
||||
Provides a high-level API for assembling optimized context for LLM requests.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from .assembly import ContextPipeline
|
||||
from .budget import BudgetAllocator, TokenBudget, TokenCalculator
|
||||
from .cache import ContextCache
|
||||
from .compression import ContextCompressor
|
||||
from .config import ContextSettings, get_context_settings
|
||||
from .prioritization import ContextRanker
|
||||
from .scoring import CompositeScorer
|
||||
from .types import (
|
||||
AssembledContext,
|
||||
BaseContext,
|
||||
ConversationContext,
|
||||
KnowledgeContext,
|
||||
MessageRole,
|
||||
SystemContext,
|
||||
TaskContext,
|
||||
ToolContext,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from app.services.mcp.client_manager import MCPClientManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContextEngine:
    """
    Main context management engine.

    Provides high-level API for context assembly and optimization.
    Integrates all components: scoring, ranking, compression, formatting, and caching.

    Usage:
        engine = ContextEngine(mcp_manager=mcp, redis=redis)

        # Assemble context for an LLM request
        result = await engine.assemble_context(
            project_id="proj-123",
            agent_id="agent-456",
            query="implement user authentication",
            model="claude-3-sonnet",
            system_prompt="You are an expert developer.",
            knowledge_query="authentication best practices",
        )

        # Use the assembled context
        print(result.content)
        print(f"Tokens: {result.total_tokens}")
    """

    def __init__(
        self,
        mcp_manager: "MCPClientManager | None" = None,
        redis: "Redis | None" = None,
        settings: ContextSettings | None = None,
    ) -> None:
        """
        Initialize the context engine.

        Args:
            mcp_manager: MCP client manager for LLM Gateway/Knowledge Base
            redis: Redis connection for caching
            settings: Context settings
        """
        self._mcp = mcp_manager
        self._settings = settings or get_context_settings()

        # Initialize components; they share the calculator/scorer so token
        # counts and scores stay consistent across the pipeline.
        self._calculator = TokenCalculator(mcp_manager=mcp_manager)
        self._scorer = CompositeScorer(mcp_manager=mcp_manager, settings=self._settings)
        self._ranker = ContextRanker(scorer=self._scorer, calculator=self._calculator)
        self._compressor = ContextCompressor(calculator=self._calculator)
        self._allocator = BudgetAllocator(self._settings)
        self._cache = ContextCache(redis=redis, settings=self._settings)

        # Pipeline for assembly
        self._pipeline = ContextPipeline(
            mcp_manager=mcp_manager,
            settings=self._settings,
            calculator=self._calculator,
            scorer=self._scorer,
            ranker=self._ranker,
            compressor=self._compressor,
        )

    def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
        """
        Set MCP manager for all components.

        Args:
            mcp_manager: MCP client manager
        """
        self._mcp = mcp_manager
        self._calculator.set_mcp_manager(mcp_manager)
        self._scorer.set_mcp_manager(mcp_manager)
        self._pipeline.set_mcp_manager(mcp_manager)

    def set_redis(self, redis: "Redis") -> None:
        """
        Set Redis connection for caching.

        Args:
            redis: Redis connection
        """
        self._cache.set_redis(redis)

    async def assemble_context(
        self,
        project_id: str,
        agent_id: str,
        query: str,
        model: str,
        max_tokens: int | None = None,
        system_prompt: str | None = None,
        task_description: str | None = None,
        knowledge_query: str | None = None,
        knowledge_limit: int = 10,
        conversation_history: list[dict[str, str]] | None = None,
        tool_results: list[dict[str, Any]] | None = None,
        custom_contexts: list[BaseContext] | None = None,
        custom_budget: TokenBudget | None = None,
        compress: bool = True,
        format_output: bool = True,
        use_cache: bool = True,
    ) -> AssembledContext:
        """
        Assemble optimized context for an LLM request.

        This is the main entry point for context management.
        It gathers context from various sources, scores and ranks them,
        compresses if needed, and formats for the target model.

        Args:
            project_id: Project identifier
            agent_id: Agent identifier
            query: User's query or current request
            model: Target model name
            max_tokens: Maximum context tokens (uses model default if None)
            system_prompt: System prompt/instructions
            task_description: Current task description
            knowledge_query: Query for knowledge base search
            knowledge_limit: Max number of knowledge results
            conversation_history: List of {"role": str, "content": str}
            tool_results: List of tool results to include
            custom_contexts: Additional custom contexts
            custom_budget: Custom token budget
            compress: Whether to apply compression
            format_output: Whether to format for the model
            use_cache: Whether to use caching

        Returns:
            AssembledContext with optimized content

        Raises:
            AssemblyTimeoutError: If assembly exceeds timeout
            BudgetExceededError: If context exceeds budget
        """
        # Gather all contexts
        contexts: list[BaseContext] = []

        # 1. System context
        if system_prompt:
            contexts.append(
                SystemContext(
                    content=system_prompt,
                    source="system_prompt",
                )
            )

        # 2. Task context
        if task_description:
            contexts.append(
                TaskContext(
                    content=task_description,
                    source=f"task:{project_id}:{agent_id}",
                )
            )

        # 3. Knowledge context from Knowledge Base (requires MCP)
        if knowledge_query and self._mcp:
            knowledge_contexts = await self._fetch_knowledge(
                project_id=project_id,
                agent_id=agent_id,
                query=knowledge_query,
                limit=knowledge_limit,
            )
            contexts.extend(knowledge_contexts)

        # 4. Conversation history
        if conversation_history:
            contexts.extend(self._convert_conversation(conversation_history))

        # 5. Tool results
        if tool_results:
            contexts.extend(self._convert_tool_results(tool_results))

        # 6. Custom contexts
        if custom_contexts:
            contexts.extend(custom_contexts)

        # Check cache if enabled
        fingerprint: str | None = None
        if use_cache and self._cache.is_enabled:
            # Include project_id and agent_id for tenant isolation
            fingerprint = self._cache.compute_fingerprint(
                contexts, query, model, project_id=project_id, agent_id=agent_id
            )
            cached = await self._cache.get_assembled(fingerprint)
            if cached:
                # Lazy %-style args avoid formatting when DEBUG is disabled.
                logger.debug("Cache hit for context assembly: %s", fingerprint)
                return cached

        # Run assembly pipeline
        result = await self._pipeline.assemble(
            contexts=contexts,
            query=query,
            model=model,
            max_tokens=max_tokens,
            custom_budget=custom_budget,
            compress=compress,
            format_output=format_output,
        )

        # Cache result if enabled (reuse fingerprint computed above)
        if use_cache and self._cache.is_enabled and fingerprint is not None:
            await self._cache.set_assembled(fingerprint, result)

        return result

    async def _fetch_knowledge(
        self,
        project_id: str,
        agent_id: str,
        query: str,
        limit: int = 10,
    ) -> list[KnowledgeContext]:
        """
        Fetch relevant knowledge from Knowledge Base via MCP.

        Best-effort: any failure is logged and an empty list is returned so
        assembly can proceed without knowledge context.

        Args:
            project_id: Project identifier
            agent_id: Agent identifier
            query: Search query
            limit: Maximum results

        Returns:
            List of KnowledgeContext instances
        """
        if not self._mcp:
            return []

        try:
            result = await self._mcp.call_tool(
                "knowledge-base",
                "search_knowledge",
                {
                    "project_id": project_id,
                    "agent_id": agent_id,
                    "query": query,
                    "search_type": "hybrid",
                    "limit": limit,
                },
            )

            # Check both ToolResult.success AND response success
            if not result.success:
                logger.warning("Knowledge search failed: %s", result.error)
                return []

            if not isinstance(result.data, dict) or not result.data.get(
                "success", True
            ):
                logger.warning("Knowledge search returned unsuccessful response")
                return []

            contexts = []
            results = result.data.get("results", [])
            for chunk in results:
                contexts.append(
                    KnowledgeContext(
                        content=chunk.get("content", ""),
                        source=chunk.get("source_path", "unknown"),
                        relevance_score=chunk.get("score", 0.0),
                        metadata={
                            "chunk_id": chunk.get(
                                "id"
                            ),  # Server returns 'id' not 'chunk_id'
                            "document_id": chunk.get("document_id"),
                        },
                    )
                )

            logger.debug(
                "Fetched %d knowledge chunks for query: %s", len(contexts), query
            )
            return contexts

        except Exception as e:
            # Deliberate broad catch: knowledge is optional, never fatal.
            logger.warning("Failed to fetch knowledge: %s", e)
            return []

    def _convert_conversation(
        self,
        history: list[dict[str, str]],
    ) -> list[ConversationContext]:
        """
        Convert conversation history to ConversationContext instances.

        Args:
            history: List of {"role": str, "content": str}

        Returns:
            List of ConversationContext instances
        """
        contexts = []
        for i, turn in enumerate(history):
            role_str = turn.get("role", "user").lower()
            # Anything that is not explicitly "assistant" is treated as user.
            role = (
                MessageRole.ASSISTANT if role_str == "assistant" else MessageRole.USER
            )

            contexts.append(
                ConversationContext(
                    content=turn.get("content", ""),
                    source=f"conversation:{i}",
                    role=role,
                    metadata={"role": role_str, "turn": i},
                )
            )

        return contexts

    def _convert_tool_results(
        self,
        results: list[dict[str, Any]],
    ) -> list[ToolContext]:
        """
        Convert tool results to ToolContext instances.

        Args:
            results: List of tool result dictionaries

        Returns:
            List of ToolContext instances
        """
        # Local import hoisted out of the loop (was re-imported per dict
        # result); kept function-scoped to leave the module imports as-is.
        import json

        contexts = []
        for result in results:
            tool_name = result.get("tool_name", "unknown")
            content = result.get("content", result.get("result", ""))

            # Serialize dict payloads so downstream formatting sees text
            if isinstance(content, dict):
                content = json.dumps(content, indent=2)

            contexts.append(
                ToolContext(
                    content=str(content),
                    source=f"tool:{tool_name}",
                    metadata={
                        "tool_name": tool_name,
                        "status": result.get("status", "success"),
                    },
                )
            )

        return contexts

    async def get_budget_for_model(
        self,
        model: str,
        max_tokens: int | None = None,
    ) -> TokenBudget:
        """
        Get the token budget for a specific model.

        Args:
            model: Model name
            max_tokens: Optional max tokens override

        Returns:
            TokenBudget instance
        """
        # Explicit None check (was truthiness): a caller-supplied override
        # must win even when falsy instead of silently using the model default.
        if max_tokens is not None:
            return self._allocator.create_budget(max_tokens)
        return self._allocator.create_budget_for_model(model)

    async def count_tokens(
        self,
        content: str,
        model: str | None = None,
    ) -> int:
        """
        Count tokens in content.

        Args:
            content: Content to count
            model: Model for model-specific tokenization

        Returns:
            Token count
        """
        # Check cache first
        cached = await self._cache.get_token_count(content, model)
        if cached is not None:
            return cached

        count = await self._calculator.count_tokens(content, model)

        # Cache the result
        await self._cache.set_token_count(content, count, model)

        return count

    async def invalidate_cache(
        self,
        project_id: str | None = None,
        pattern: str | None = None,
    ) -> int:
        """
        Invalidate cache entries.

        Args:
            project_id: Invalidate all cache for a project
            pattern: Custom pattern to match

        Returns:
            Number of entries invalidated
        """
        if pattern:
            return await self._cache.invalidate(pattern)
        elif project_id:
            return await self._cache.invalidate(f"*{project_id}*")
        else:
            return await self._cache.clear_all()

    async def get_stats(self) -> dict[str, Any]:
        """
        Get engine statistics.

        Returns:
            Dictionary with engine stats
        """
        return {
            "cache": await self._cache.get_stats(),
            "settings": {
                "compression_threshold": self._settings.compression_threshold,
                "max_assembly_time_ms": self._settings.max_assembly_time_ms,
                "cache_enabled": self._settings.cache_enabled,
            },
        }
|
||||
|
||||
|
||||
# Convenience factory function
|
||||
def create_context_engine(
    mcp_manager: "MCPClientManager | None" = None,
    redis: "Redis | None" = None,
    settings: ContextSettings | None = None,
) -> ContextEngine:
    """
    Factory helper that builds a ready-to-use ContextEngine.

    Args:
        mcp_manager: MCP client manager for gateway / knowledge-base calls
        redis: Redis connection used by the context cache
        settings: Context settings (global settings are used when None)

    Returns:
        Configured ContextEngine instance
    """
    return ContextEngine(mcp_manager=mcp_manager, redis=redis, settings=settings)
|
||||
354
backend/app/services/context/exceptions.py
Normal file
354
backend/app/services/context/exceptions.py
Normal file
@@ -0,0 +1,354 @@
|
||||
"""
|
||||
Context Management Engine Exceptions.
|
||||
|
||||
Provides a hierarchy of exceptions for context assembly,
|
||||
token budget management, and related operations.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ContextError(Exception):
|
||||
"""
|
||||
Base exception for all context management errors.
|
||||
|
||||
All context-related exceptions should inherit from this class
|
||||
to allow for catch-all handling when needed.
|
||||
"""
|
||||
|
||||
def __init__(self, message: str, details: dict[str, Any] | None = None) -> None:
|
||||
"""
|
||||
Initialize context error.
|
||||
|
||||
Args:
|
||||
message: Human-readable error message
|
||||
details: Optional dict with additional error context
|
||||
"""
|
||||
self.message = message
|
||||
self.details = details or {}
|
||||
super().__init__(message)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert exception to dictionary for logging/serialization."""
|
||||
return {
|
||||
"error_type": self.__class__.__name__,
|
||||
"message": self.message,
|
||||
"details": self.details,
|
||||
}
|
||||
|
||||
|
||||
class BudgetExceededError(ContextError):
    """
    Raised when token budget is exceeded.

    Signals that assembled context would overflow the tokens allocated
    for a specific context type or for the request as a whole.
    """

    def __init__(
        self,
        message: str = "Token budget exceeded",
        allocated: int = 0,
        requested: int = 0,
        context_type: str | None = None,
    ) -> None:
        """
        Initialize budget exceeded error.

        Args:
            message: Error message
            allocated: Tokens allocated for this context type
            requested: Tokens requested
            context_type: Type of context that exceeded budget
        """
        self.allocated = allocated
        self.requested = requested
        self.context_type = context_type

        # Structured payload for logging; overage = how far over budget.
        details: dict[str, Any] = {
            "allocated": allocated,
            "requested": requested,
            "overage": requested - allocated,
        }
        if context_type:
            details["context_type"] = context_type

        super().__init__(message, details)
|
||||
|
||||
|
||||
class TokenCountError(ContextError):
    """
    Raised when token counting fails.

    Typically means the LLM Gateway token-counting service was
    unavailable or returned an error.
    """

    def __init__(
        self,
        message: str = "Failed to count tokens",
        model: str | None = None,
        text_length: int | None = None,
    ) -> None:
        """
        Initialize token count error.

        Args:
            message: Error message
            model: Model for which counting was attempted
            text_length: Length of text that failed to count
        """
        self.model = model
        self.text_length = text_length

        # Only record fields that were actually provided.
        details: dict[str, Any] = {}
        if model:
            details["model"] = model
        if text_length is not None:
            details["text_length"] = text_length

        super().__init__(message, details)
|
||||
|
||||
|
||||
class CompressionError(ContextError):
    """
    Raised when context compression fails.

    Occurs when summarization or truncation cannot shrink content
    far enough to fit within the token budget.
    """

    def __init__(
        self,
        message: str = "Failed to compress context",
        original_tokens: int | None = None,
        target_tokens: int | None = None,
        achieved_tokens: int | None = None,
    ) -> None:
        """
        Initialize compression error.

        Args:
            message: Error message
            original_tokens: Tokens before compression
            target_tokens: Target token count
            achieved_tokens: Tokens achieved after compression attempt
        """
        self.original_tokens = original_tokens
        self.target_tokens = target_tokens
        self.achieved_tokens = achieved_tokens

        # Drop counters that were not supplied so the payload stays compact.
        candidates = {
            "original_tokens": original_tokens,
            "target_tokens": target_tokens,
            "achieved_tokens": achieved_tokens,
        }
        details: dict[str, Any] = {
            key: value for key, value in candidates.items() if value is not None
        }

        super().__init__(message, details)
|
||||
|
||||
|
||||
class AssemblyTimeoutError(ContextError):
    """
    Raised when context assembly exceeds time limit.

    Assembly must finish within a configurable deadline to keep the
    overall request responsive.
    """

    def __init__(
        self,
        message: str = "Context assembly timed out",
        timeout_ms: int = 0,
        elapsed_ms: float = 0.0,
        stage: str | None = None,
    ) -> None:
        """
        Initialize assembly timeout error.

        Args:
            message: Error message
            timeout_ms: Configured timeout in milliseconds
            elapsed_ms: Actual elapsed time in milliseconds
            stage: Pipeline stage where timeout occurred
        """
        self.timeout_ms = timeout_ms
        self.elapsed_ms = elapsed_ms
        self.stage = stage

        details: dict[str, Any] = {
            "timeout_ms": timeout_ms,
            # Rounded for log readability; raw value stays on the attribute.
            "elapsed_ms": round(elapsed_ms, 2),
        }
        if stage:
            details["stage"] = stage

        super().__init__(message, details)
|
||||
|
||||
|
||||
class ScoringError(ContextError):
    """
    Raised when context scoring fails.

    Occurs when relevance, recency, or priority scoring hits an error.
    """

    def __init__(
        self,
        message: str = "Failed to score context",
        scorer_type: str | None = None,
        context_id: str | None = None,
    ) -> None:
        """
        Initialize scoring error.

        Args:
            message: Error message
            scorer_type: Type of scorer that failed
            context_id: ID of context being scored
        """
        self.scorer_type = scorer_type
        self.context_id = context_id

        # Only record fields that were actually provided.
        details: dict[str, Any] = {}
        for key, value in (("scorer_type", scorer_type), ("context_id", context_id)):
            if value:
                details[key] = value

        super().__init__(message, details)
|
||||
|
||||
|
||||
class FormattingError(ContextError):
    """
    Raised when context formatting fails.

    Occurs when converting assembled context to a model-specific
    format cannot be completed.
    """

    def __init__(
        self,
        message: str = "Failed to format context",
        model: str | None = None,
        adapter: str | None = None,
    ) -> None:
        """
        Initialize formatting error.

        Args:
            message: Error message
            model: Target model
            adapter: Adapter that failed
        """
        self.model = model
        self.adapter = adapter

        # Only record fields that were actually provided.
        details: dict[str, Any] = {}
        for key, value in (("model", model), ("adapter", adapter)):
            if value:
                details[key] = value

        super().__init__(message, details)
|
||||
|
||||
|
||||
class CacheError(ContextError):
    """
    Raised when cache operations fail.

    Typically non-fatal: callers should handle it by falling back to
    recomputation rather than failing the request.
    """

    def __init__(
        self,
        message: str = "Cache operation failed",
        operation: str | None = None,
        cache_key: str | None = None,
    ) -> None:
        """
        Initialize cache error.

        Args:
            message: Error message
            operation: Cache operation that failed (get, set, delete)
            cache_key: Key involved in the failed operation
        """
        self.operation = operation
        self.cache_key = cache_key

        # Only record fields that were actually provided.
        details: dict[str, Any] = {}
        for key, value in (("operation", operation), ("cache_key", cache_key)):
            if value:
                details[key] = value

        super().__init__(message, details)
|
||||
|
||||
|
||||
class ContextNotFoundError(ContextError):
    """
    Raised when expected context is not found.

    Occurs when a required context source returns no results or is
    unavailable.
    """

    def __init__(
        self,
        message: str = "Required context not found",
        source: str | None = None,
        query: str | None = None,
    ) -> None:
        """
        Initialize context not found error.

        Args:
            message: Error message
            source: Source that returned no results
            query: Query used to search
        """
        self.source = source
        self.query = query

        # Only record fields that were actually provided.
        details: dict[str, Any] = {}
        for key, value in (("source", source), ("query", query)):
            if value:
                details[key] = value

        super().__init__(message, details)
|
||||
|
||||
|
||||
class InvalidContextError(ContextError):
    """
    Raised when context data is invalid.

    Occurs when context content or metadata fails validation.
    """

    def __init__(
        self,
        message: str = "Invalid context data",
        field: str | None = None,
        value: Any | None = None,
        reason: str | None = None,
    ) -> None:
        """
        Initialize invalid context error.

        Args:
            message: Error message
            field: Field that is invalid
            value: Invalid value (may be redacted for security)
            reason: Reason for invalidity
        """
        self.field = field
        self.value = value
        self.reason = reason

        details: dict[str, Any] = {}
        if field:
            details["field"] = field
        if value is not None:
            # Record only the type name to avoid logging sensitive values.
            details["value_type"] = type(value).__name__
        if reason:
            details["reason"] = reason

        super().__init__(message, details)
|
||||
12
backend/app/services/context/prioritization/__init__.py
Normal file
12
backend/app/services/context/prioritization/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Context Prioritization Module.
|
||||
|
||||
Provides context ranking and selection.
|
||||
"""
|
||||
|
||||
from .ranker import ContextRanker, RankingResult
|
||||
|
||||
__all__ = [
|
||||
"ContextRanker",
|
||||
"RankingResult",
|
||||
]
|
||||
374
backend/app/services/context/prioritization/ranker.py
Normal file
374
backend/app/services/context/prioritization/ranker.py
Normal file
@@ -0,0 +1,374 @@
|
||||
"""
|
||||
Context Ranker for Context Management.
|
||||
|
||||
Ranks and selects contexts based on scores and budget constraints.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from ..budget import TokenBudget, TokenCalculator
|
||||
from ..config import ContextSettings, get_context_settings
|
||||
from ..exceptions import BudgetExceededError
|
||||
from ..scoring.composite import CompositeScorer, ScoredContext
|
||||
from ..types import BaseContext, ContextPriority
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class RankingResult:
    """Result of context ranking and selection."""

    # Contexts that fit within the budget, wrapped with their scores.
    selected: list[ScoredContext]
    # Contexts that were scored but did not fit within the budget.
    excluded: list[ScoredContext]
    # Sum of token counts across all selected contexts.
    total_tokens: int
    # Free-form selection metrics (counts, per-type breakdown, etc.).
    selection_stats: dict[str, Any] = field(default_factory=dict)

    @property
    def selected_contexts(self) -> list[BaseContext]:
        """Get just the context objects (not scored wrappers)."""
        return [s.context for s in self.selected]
|
||||
|
||||
|
||||
class ContextRanker:
|
||||
"""
|
||||
Ranks and selects contexts within budget constraints.
|
||||
|
||||
Uses greedy selection to maximize total score
|
||||
while respecting token budgets per context type.
|
||||
"""
|
||||
|
||||
    def __init__(
        self,
        scorer: CompositeScorer | None = None,
        calculator: TokenCalculator | None = None,
        settings: ContextSettings | None = None,
    ) -> None:
        """
        Initialize context ranker.

        Args:
            scorer: Composite scorer for scoring contexts
            calculator: Token calculator for counting tokens
            settings: Context settings (uses global if None)
        """
        # Fall back to shared/default collaborators so the ranker works with
        # zero configuration; injected instances take precedence when given.
        self._settings = settings or get_context_settings()
        self._scorer = scorer or CompositeScorer()
        self._calculator = calculator or TokenCalculator()
|
||||
|
||||
    def set_scorer(self, scorer: CompositeScorer) -> None:
        """Set the scorer.

        Replaces the scorer used by all subsequent rank() calls.
        """
        self._scorer = scorer
|
||||
|
||||
    def set_calculator(self, calculator: TokenCalculator) -> None:
        """Set the token calculator.

        Replaces the calculator used by all subsequent rank() calls.
        """
        self._calculator = calculator
|
||||
|
||||
    async def rank(
        self,
        contexts: list[BaseContext],
        query: str,
        budget: TokenBudget,
        model: str | None = None,
        ensure_required: bool = True,
        **kwargs: Any,
    ) -> RankingResult:
        """
        Rank and select contexts within budget.

        Note: `budget` is mutated in place via budget.allocate() as contexts
        are selected; callers should pass a fresh budget per ranking run.

        Args:
            contexts: Contexts to rank
            query: Query to rank against
            budget: Token budget constraints
            model: Model for token counting
            ensure_required: If True, always include CRITICAL priority contexts
            **kwargs: Additional scoring parameters

        Returns:
            RankingResult with selected and excluded contexts

        Raises:
            BudgetExceededError: If the budget has no usable tokens, or if
                CRITICAL contexts alone exceed the usable budget
        """
        if not contexts:
            return RankingResult(
                selected=[],
                excluded=[],
                total_tokens=0,
                selection_stats={"total_contexts": 0},
            )

        # 1. Ensure all contexts have token counts
        await self._ensure_token_counts(contexts, model)

        # 2. Score all contexts
        scored_contexts = await self._scorer.score_batch(contexts, query, **kwargs)

        # 3. Separate required (CRITICAL priority) from optional
        required: list[ScoredContext] = []
        optional: list[ScoredContext] = []

        if ensure_required:
            for sc in scored_contexts:
                # CRITICAL priority (150) contexts are always included
                if sc.context.priority >= ContextPriority.CRITICAL.value:
                    required.append(sc)
                else:
                    optional.append(sc)
        else:
            optional = list(scored_contexts)

        # 4. Sort optional by score (highest first)
        # Relies on ScoredContext defining its own ordering.
        optional.sort(reverse=True)

        # 5. Greedy selection (approximate knapsack: best score first)
        selected: list[ScoredContext] = []
        excluded: list[ScoredContext] = []
        total_tokens = 0

        # Calculate the usable budget (total minus reserved portions)
        usable_budget = budget.total - budget.response_reserve - budget.buffer

        # Guard against invalid budget configuration
        if usable_budget <= 0:
            raise BudgetExceededError(
                message=(
                    f"Invalid budget configuration: no usable tokens available. "
                    f"total={budget.total}, response_reserve={budget.response_reserve}, "
                    f"buffer={budget.buffer}"
                ),
                allocated=budget.total,
                requested=0,
                context_type="CONFIGURATION_ERROR",
            )

        # First, try to fit required contexts
        for sc in required:
            token_count = self._get_valid_token_count(sc.context)
            context_type = sc.context.get_type()

            if budget.can_fit(context_type, token_count):
                budget.allocate(context_type, token_count)
                selected.append(sc)
                total_tokens += token_count
            else:
                # Force-fit CRITICAL contexts if needed, but check total budget first
                if total_tokens + token_count > usable_budget:
                    # Even CRITICAL contexts cannot exceed total model context window
                    raise BudgetExceededError(
                        message=(
                            f"CRITICAL contexts exceed total budget. "
                            f"Context '{sc.context.source}' ({token_count} tokens) "
                            f"would exceed usable budget of {usable_budget} tokens."
                        ),
                        allocated=usable_budget,
                        requested=total_tokens + token_count,
                        context_type="CRITICAL_OVERFLOW",
                    )

                # force=True bypasses the per-type limit but stays within
                # the overall usable budget (checked just above).
                budget.allocate(context_type, token_count, force=True)
                selected.append(sc)
                total_tokens += token_count
                logger.warning(
                    f"Force-fitted CRITICAL context: {sc.context.source} "
                    f"({token_count} tokens)"
                )

        # Then, greedily add optional contexts
        for sc in optional:
            token_count = self._get_valid_token_count(sc.context)
            context_type = sc.context.get_type()

            if budget.can_fit(context_type, token_count):
                budget.allocate(context_type, token_count)
                selected.append(sc)
                total_tokens += token_count
            else:
                excluded.append(sc)

        # Build stats
        stats = {
            "total_contexts": len(contexts),
            "required_count": len(required),
            "selected_count": len(selected),
            "excluded_count": len(excluded),
            "total_tokens": total_tokens,
            "by_type": self._count_by_type(selected),
        }

        return RankingResult(
            selected=selected,
            excluded=excluded,
            total_tokens=total_tokens,
            selection_stats=stats,
        )
|
||||
|
||||
async def rank_simple(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
query: str,
|
||||
max_tokens: int,
|
||||
model: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> list[BaseContext]:
|
||||
"""
|
||||
Simple ranking without budget per type.
|
||||
|
||||
Selects top contexts by score until max tokens reached.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to rank
|
||||
query: Query to rank against
|
||||
max_tokens: Maximum total tokens
|
||||
model: Model for token counting
|
||||
**kwargs: Additional scoring parameters
|
||||
|
||||
Returns:
|
||||
Selected contexts (in score order)
|
||||
"""
|
||||
if not contexts:
|
||||
return []
|
||||
|
||||
# Ensure token counts
|
||||
await self._ensure_token_counts(contexts, model)
|
||||
|
||||
# Score all contexts
|
||||
scored_contexts = await self._scorer.score_batch(contexts, query, **kwargs)
|
||||
|
||||
# Sort by score (highest first)
|
||||
scored_contexts.sort(reverse=True)
|
||||
|
||||
# Greedy selection
|
||||
selected: list[BaseContext] = []
|
||||
total_tokens = 0
|
||||
|
||||
for sc in scored_contexts:
|
||||
token_count = self._get_valid_token_count(sc.context)
|
||||
if total_tokens + token_count <= max_tokens:
|
||||
selected.append(sc.context)
|
||||
total_tokens += token_count
|
||||
|
||||
return selected
|
||||
|
||||
def _get_valid_token_count(self, context: BaseContext) -> int:
|
||||
"""
|
||||
Get validated token count from a context.
|
||||
|
||||
Ensures token_count is set (not None) and non-negative to prevent
|
||||
budget bypass attacks where:
|
||||
- None would be treated as 0 (allowing huge contexts to slip through)
|
||||
- Negative values would corrupt budget tracking
|
||||
|
||||
Args:
|
||||
context: Context to get token count from
|
||||
|
||||
Returns:
|
||||
Valid non-negative token count
|
||||
|
||||
Raises:
|
||||
ValueError: If token_count is None or negative
|
||||
"""
|
||||
if context.token_count is None:
|
||||
raise ValueError(
|
||||
f"Context '{context.source}' has no token count. "
|
||||
"Ensure _ensure_token_counts() is called before ranking."
|
||||
)
|
||||
if context.token_count < 0:
|
||||
raise ValueError(
|
||||
f"Context '{context.source}' has invalid negative token count: "
|
||||
f"{context.token_count}"
|
||||
)
|
||||
return context.token_count
|
||||
|
||||
async def _ensure_token_counts(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
model: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Ensure all contexts have token counts.
|
||||
|
||||
Counts tokens in parallel for contexts that don't have counts.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to check
|
||||
model: Model for token counting
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
# Find contexts needing counts
|
||||
contexts_needing_counts = [ctx for ctx in contexts if ctx.token_count is None]
|
||||
|
||||
if not contexts_needing_counts:
|
||||
return
|
||||
|
||||
# Count all in parallel
|
||||
tasks = [
|
||||
self._calculator.count_tokens(ctx.content, model)
|
||||
for ctx in contexts_needing_counts
|
||||
]
|
||||
counts = await asyncio.gather(*tasks)
|
||||
|
||||
# Assign counts back
|
||||
for ctx, count in zip(contexts_needing_counts, counts, strict=True):
|
||||
ctx.token_count = count
|
||||
|
||||
def _count_by_type(
|
||||
self, scored_contexts: list[ScoredContext]
|
||||
) -> dict[str, dict[str, int]]:
|
||||
"""Count selected contexts by type."""
|
||||
by_type: dict[str, dict[str, int]] = {}
|
||||
|
||||
for sc in scored_contexts:
|
||||
type_name = sc.context.get_type().value
|
||||
if type_name not in by_type:
|
||||
by_type[type_name] = {"count": 0, "tokens": 0}
|
||||
by_type[type_name]["count"] += 1
|
||||
# Use validated token count (already validated during ranking)
|
||||
by_type[type_name]["tokens"] += sc.context.token_count or 0
|
||||
|
||||
return by_type
|
||||
|
||||
async def rerank_for_diversity(
|
||||
self,
|
||||
scored_contexts: list[ScoredContext],
|
||||
max_per_source: int | None = None,
|
||||
) -> list[ScoredContext]:
|
||||
"""
|
||||
Rerank to ensure source diversity.
|
||||
|
||||
Prevents too many items from the same source.
|
||||
|
||||
Args:
|
||||
scored_contexts: Already scored contexts
|
||||
max_per_source: Maximum items per source (uses settings if None)
|
||||
|
||||
Returns:
|
||||
Reranked contexts
|
||||
"""
|
||||
# Use provided value or fall back to settings
|
||||
effective_max = (
|
||||
max_per_source
|
||||
if max_per_source is not None
|
||||
else self._settings.diversity_max_per_source
|
||||
)
|
||||
|
||||
source_counts: dict[str, int] = {}
|
||||
result: list[ScoredContext] = []
|
||||
deferred: list[ScoredContext] = []
|
||||
|
||||
for sc in scored_contexts:
|
||||
source = sc.context.source
|
||||
current_count = source_counts.get(source, 0)
|
||||
|
||||
if current_count < effective_max:
|
||||
result.append(sc)
|
||||
source_counts[source] = current_count + 1
|
||||
else:
|
||||
deferred.append(sc)
|
||||
|
||||
# Add deferred items at the end
|
||||
result.extend(deferred)
|
||||
return result
|
||||
21
backend/app/services/context/scoring/__init__.py
Normal file
21
backend/app/services/context/scoring/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
Context Scoring Module.
|
||||
|
||||
Provides scoring strategies for context prioritization.
|
||||
"""
|
||||
|
||||
from .base import BaseScorer, ScorerProtocol
|
||||
from .composite import CompositeScorer, ScoredContext
|
||||
from .priority import PriorityScorer
|
||||
from .recency import RecencyScorer
|
||||
from .relevance import RelevanceScorer
|
||||
|
||||
__all__ = [
|
||||
"BaseScorer",
|
||||
"CompositeScorer",
|
||||
"PriorityScorer",
|
||||
"RecencyScorer",
|
||||
"RelevanceScorer",
|
||||
"ScoredContext",
|
||||
"ScorerProtocol",
|
||||
]
|
||||
99
backend/app/services/context/scoring/base.py
Normal file
99
backend/app/services/context/scoring/base.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
Base Scorer Protocol and Types.
|
||||
|
||||
Defines the interface for context scoring implementations.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
|
||||
|
||||
from ..types import BaseContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
@runtime_checkable
class ScorerProtocol(Protocol):
    """Structural (duck-typed) interface that every context scorer satisfies."""

    async def score(
        self,
        context: BaseContext,
        query: str,
        **kwargs: Any,
    ) -> float:
        """
        Produce a score for a single context item.

        Args:
            context: Context to score
            query: Query to score against
            **kwargs: Additional scoring parameters

        Returns:
            Score between 0.0 and 1.0
        """
        ...
|
||||
|
||||
|
||||
class BaseScorer(ABC):
    """
    Common base for all context scoring strategies.

    Holds the scorer's weight (consumed by composite scoring) and provides
    shared helpers such as score normalization.
    """

    def __init__(self, weight: float = 1.0) -> None:
        """
        Initialize scorer.

        Args:
            weight: Weight for this scorer in composite scoring
        """
        self._weight = weight

    @property
    def weight(self) -> float:
        """Current weight of this scorer."""
        return self._weight

    @weight.setter
    def weight(self, value: float) -> None:
        """Update the weight, rejecting values outside [0.0, 1.0]."""
        # Chained comparison also rejects NaN (comparison is False, negation True).
        if not 0.0 <= value <= 1.0:
            raise ValueError("Weight must be between 0.0 and 1.0")
        self._weight = value

    @abstractmethod
    async def score(
        self,
        context: BaseContext,
        query: str,
        **kwargs: Any,
    ) -> float:
        """
        Score a context item.

        Args:
            context: Context to score
            query: Query to score against
            **kwargs: Additional scoring parameters

        Returns:
            Score between 0.0 and 1.0
        """
        ...

    def normalize_score(self, score: float) -> float:
        """
        Clamp a raw score into the [0.0, 1.0] interval.

        Args:
            score: Raw score

        Returns:
            Normalized score
        """
        return max(0.0, min(1.0, score))
|
||||
368
backend/app/services/context/scoring/composite.py
Normal file
368
backend/app/services/context/scoring/composite.py
Normal file
@@ -0,0 +1,368 @@
|
||||
"""
|
||||
Composite Scorer for Context Management.
|
||||
|
||||
Combines multiple scoring strategies with configurable weights.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from ..config import ContextSettings, get_context_settings
|
||||
from ..types import BaseContext
|
||||
from .priority import PriorityScorer
|
||||
from .recency import RecencyScorer
|
||||
from .relevance import RelevanceScorer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.mcp.client_manager import MCPClientManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class ScoredContext:
    """Context with computed scores.

    Carries the weighted composite alongside the individual partial scores
    so callers can inspect how a ranking decision was reached.
    """

    context: BaseContext  # the context item that was scored
    composite_score: float  # weighted blend of the partial scores below
    relevance_score: float = 0.0  # similarity of content to the query
    recency_score: float = 0.0  # age-decay freshness score
    priority_score: float = 0.0  # explicit/type-based priority score

    def __lt__(self, other: "ScoredContext") -> bool:
        """Enable sorting by composite score."""
        return self.composite_score < other.composite_score

    def __gt__(self, other: "ScoredContext") -> bool:
        """Enable sorting by composite score."""
        return self.composite_score > other.composite_score
|
||||
|
||||
|
||||
class CompositeScorer:
    """
    Combines multiple scoring strategies.

    Weights:
    - relevance: How well content matches the query
    - recency: How recent the content is
    - priority: Explicit priority assignments
    """

    def __init__(
        self,
        mcp_manager: "MCPClientManager | None" = None,
        settings: ContextSettings | None = None,
        relevance_weight: float | None = None,
        recency_weight: float | None = None,
        priority_weight: float | None = None,
    ) -> None:
        """
        Initialize composite scorer.

        Explicit weight arguments take precedence over the configured
        scoring weights from settings.

        Args:
            mcp_manager: MCP manager for semantic scoring
            settings: Context settings (uses default if None)
            relevance_weight: Override relevance weight
            recency_weight: Override recency weight
            priority_weight: Override priority weight
        """
        self._settings = settings or get_context_settings()
        weights = self._settings.get_scoring_weights()

        # Per-dimension weights: explicit argument wins, else configured value.
        self._relevance_weight = (
            relevance_weight if relevance_weight is not None else weights["relevance"]
        )
        self._recency_weight = (
            recency_weight if recency_weight is not None else weights["recency"]
        )
        self._priority_weight = (
            priority_weight if priority_weight is not None else weights["priority"]
        )

        # Initialize scorers (each carries its own weight copy)
        self._relevance_scorer = RelevanceScorer(
            mcp_manager=mcp_manager,
            weight=self._relevance_weight,
        )
        self._recency_scorer = RecencyScorer(weight=self._recency_weight)
        self._priority_scorer = PriorityScorer(weight=self._priority_weight)

        # Per-context locks to prevent race conditions during parallel scoring
        # Uses dict with (lock, last_used_time) tuples for cleanup
        self._context_locks: dict[str, tuple[asyncio.Lock, float]] = {}
        self._locks_lock = asyncio.Lock()  # Lock to protect _context_locks access
        self._max_locks = 1000  # Maximum locks to keep (prevent memory growth)
        self._lock_ttl = 60.0  # Seconds before a lock can be cleaned up

    def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
        """Set MCP manager for semantic scoring (delegates to relevance scorer)."""
        self._relevance_scorer.set_mcp_manager(mcp_manager)

    @property
    def weights(self) -> dict[str, float]:
        """Get current scoring weights (returned as a fresh dict snapshot)."""
        return {
            "relevance": self._relevance_weight,
            "recency": self._recency_weight,
            "priority": self._priority_weight,
        }

    def update_weights(
        self,
        relevance: float | None = None,
        recency: float | None = None,
        priority: float | None = None,
    ) -> None:
        """
        Update scoring weights.

        Each provided value is clamped into [0.0, 1.0] and propagated to the
        corresponding underlying scorer; None leaves a weight unchanged.

        Args:
            relevance: New relevance weight
            recency: New recency weight
            priority: New priority weight
        """
        if relevance is not None:
            self._relevance_weight = max(0.0, min(1.0, relevance))
            self._relevance_scorer.weight = self._relevance_weight

        if recency is not None:
            self._recency_weight = max(0.0, min(1.0, recency))
            self._recency_scorer.weight = self._recency_weight

        if priority is not None:
            self._priority_weight = max(0.0, min(1.0, priority))
            self._priority_scorer.weight = self._priority_weight

    async def _get_context_lock(self, context_id: str) -> asyncio.Lock:
        """
        Get or create a lock for a specific context.

        Thread-safe access to per-context locks prevents race conditions
        when the same context is scored concurrently. Includes automatic
        cleanup of old locks to prevent memory growth.

        Args:
            context_id: The context ID to get a lock for

        Returns:
            asyncio.Lock for the context
        """
        now = time.time()

        # Fast path: check if lock exists without acquiring main lock
        # NOTE: We only READ here - no writes to avoid race conditions
        # with cleanup. The timestamp will be updated in the slow path
        # if the lock is still valid.
        lock_entry = self._context_locks.get(context_id)
        if lock_entry is not None:
            lock, _ = lock_entry
            # Return the lock but defer timestamp update to avoid race
            # The lock is still valid; timestamp update is best-effort
            return lock

        # Slow path: create lock or update timestamp while holding main lock
        async with self._locks_lock:
            # Double-check after acquiring lock - entry may have been
            # created by another coroutine or deleted by cleanup
            lock_entry = self._context_locks.get(context_id)
            if lock_entry is not None:
                lock, _ = lock_entry
                # Safe to update timestamp here since we hold the lock
                self._context_locks[context_id] = (lock, now)
                return lock

            # Cleanup old locks if we have too many
            if len(self._context_locks) >= self._max_locks:
                self._cleanup_old_locks(now)

            # Create new lock
            new_lock = asyncio.Lock()
            self._context_locks[context_id] = (new_lock, now)
            return new_lock

    def _cleanup_old_locks(self, now: float) -> None:
        """
        Remove old locks that haven't been used recently.

        Called while holding _locks_lock. Removes locks older than _lock_ttl,
        but only if they're not currently held.

        Args:
            now: Current timestamp for age calculation
        """
        cutoff = now - self._lock_ttl
        to_remove = []

        for context_id, (lock, last_used) in self._context_locks.items():
            # Only remove if old AND not currently held
            if last_used < cutoff and not lock.locked():
                to_remove.append(context_id)

        # Remove oldest 50% if still over limit after TTL filtering
        if len(self._context_locks) - len(to_remove) >= self._max_locks:
            # Sort by last used time and mark oldest for removal
            sorted_entries = sorted(
                self._context_locks.items(),
                key=lambda x: x[1][1],  # Sort by last_used time
            )
            # Remove oldest 50% that aren't locked
            target_remove = len(self._context_locks) // 2
            for context_id, (lock, _) in sorted_entries:
                if len(to_remove) >= target_remove:
                    break
                if context_id not in to_remove and not lock.locked():
                    to_remove.append(context_id)

        for context_id in to_remove:
            del self._context_locks[context_id]

        if to_remove:
            logger.debug(f"Cleaned up {len(to_remove)} context locks")

    async def score(
        self,
        context: BaseContext,
        query: str,
        **kwargs: Any,
    ) -> float:
        """
        Compute composite score for a context.

        Thin wrapper over score_with_details() that discards the
        per-dimension scores.

        Args:
            context: Context to score
            query: Query to score against
            **kwargs: Additional scoring parameters

        Returns:
            Composite score between 0.0 and 1.0
        """
        scored = await self.score_with_details(context, query, **kwargs)
        return scored.composite_score

    async def score_with_details(
        self,
        context: BaseContext,
        query: str,
        **kwargs: Any,
    ) -> ScoredContext:
        """
        Compute composite score with individual scores.

        Uses per-context locking to prevent race conditions when the same
        context is scored concurrently in parallel scoring operations.

        Args:
            context: Context to score
            query: Query to score against
            **kwargs: Additional scoring parameters

        Returns:
            ScoredContext with all scores
        """
        # Get lock for this specific context to prevent race conditions
        # within concurrent scoring operations for the same query
        context_lock = await self._get_context_lock(context.id)

        async with context_lock:
            # Compute individual scores in parallel
            # Note: We do NOT cache scores on the context because scores are
            # query-dependent. Caching without considering the query would
            # return incorrect scores for different queries.
            relevance_task = self._relevance_scorer.score(context, query, **kwargs)
            recency_task = self._recency_scorer.score(context, query, **kwargs)
            priority_task = self._priority_scorer.score(context, query, **kwargs)

            relevance_score, recency_score, priority_score = await asyncio.gather(
                relevance_task, recency_task, priority_task
            )

            # Compute weighted composite (normalized by total weight so the
            # result stays in [0, 1] regardless of how weights are set)
            total_weight = (
                self._relevance_weight + self._recency_weight + self._priority_weight
            )

            if total_weight > 0:
                composite = (
                    relevance_score * self._relevance_weight
                    + recency_score * self._recency_weight
                    + priority_score * self._priority_weight
                ) / total_weight
            else:
                # All weights zero: no meaningful blend is possible
                composite = 0.0

            return ScoredContext(
                context=context,
                composite_score=composite,
                relevance_score=relevance_score,
                recency_score=recency_score,
                priority_score=priority_score,
            )

    async def score_batch(
        self,
        contexts: list[BaseContext],
        query: str,
        parallel: bool = True,
        **kwargs: Any,
    ) -> list[ScoredContext]:
        """
        Score multiple contexts.

        Args:
            contexts: Contexts to score
            query: Query to score against
            parallel: Whether to score in parallel
            **kwargs: Additional scoring parameters

        Returns:
            List of ScoredContext (same order as input)
        """
        if parallel:
            # asyncio.gather preserves input order in its results
            tasks = [self.score_with_details(ctx, query, **kwargs) for ctx in contexts]
            return await asyncio.gather(*tasks)
        else:
            results = []
            for ctx in contexts:
                scored = await self.score_with_details(ctx, query, **kwargs)
                results.append(scored)
            return results

    async def rank(
        self,
        contexts: list[BaseContext],
        query: str,
        limit: int | None = None,
        min_score: float = 0.0,
        **kwargs: Any,
    ) -> list[ScoredContext]:
        """
        Score and rank contexts.

        Args:
            contexts: Contexts to rank
            query: Query to rank against
            limit: Maximum number of results
            min_score: Minimum score threshold
            **kwargs: Additional scoring parameters

        Returns:
            Sorted list of ScoredContext (highest first)
        """
        # Score all contexts
        scored = await self.score_batch(contexts, query, **kwargs)

        # Filter by minimum score
        if min_score > 0:
            scored = [s for s in scored if s.composite_score >= min_score]

        # Sort by score (highest first)
        scored.sort(reverse=True)

        # Apply limit
        if limit is not None:
            scored = scored[:limit]

        return scored
|
||||
135
backend/app/services/context/scoring/priority.py
Normal file
135
backend/app/services/context/scoring/priority.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""
|
||||
Priority Scorer for Context Management.
|
||||
|
||||
Scores context based on assigned priority levels.
|
||||
"""
|
||||
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from ..types import BaseContext, ContextType
|
||||
from .base import BaseScorer
|
||||
|
||||
|
||||
class PriorityScorer(BaseScorer):
    """
    Scores context items by their assigned priority.

    Maps a 0-100 priority value onto [0.0, 1.0] and layers a configurable
    per-type bonus on top, clamping the sum back into range.
    """

    # Default priority bonuses by context type
    DEFAULT_TYPE_BONUSES: ClassVar[dict[ContextType, float]] = {
        ContextType.SYSTEM: 0.2,  # System prompts get a boost
        ContextType.TASK: 0.15,  # Current task is important
        ContextType.TOOL: 0.1,  # Recent tool results matter
        ContextType.KNOWLEDGE: 0.0,  # Knowledge scored by relevance
        ContextType.CONVERSATION: 0.0,  # Conversation scored by recency
    }

    def __init__(
        self,
        weight: float = 1.0,
        type_bonuses: dict[ContextType, float] | None = None,
    ) -> None:
        """
        Initialize priority scorer.

        Args:
            weight: Scorer weight for composite scoring
            type_bonuses: Optional context-type priority bonuses
        """
        super().__init__(weight)
        # Fall back to a copy of the class defaults when no (or an empty)
        # mapping is supplied, so mutating one instance never leaks.
        if type_bonuses:
            self._type_bonuses = type_bonuses
        else:
            self._type_bonuses = self.DEFAULT_TYPE_BONUSES.copy()

    async def score(
        self,
        context: BaseContext,
        query: str,
        **kwargs: Any,
    ) -> float:
        """
        Score context based on priority.

        Args:
            context: Context to score
            query: Query (not used for priority, kept for interface)
            **kwargs: Additional parameters

        Returns:
            Priority score between 0.0 and 1.0
        """
        base = self._priority_to_score(context.priority)
        bonus = self._type_bonuses.get(context.get_type(), 0.0)
        # Clamp: a high priority plus a bonus may exceed 1.0 otherwise.
        return self.normalize_score(base + bonus)

    def _priority_to_score(self, priority: int) -> float:
        """
        Convert priority value to normalized score.

        Priority values (from ContextPriority):
        - CRITICAL (100) -> 1.0
        - HIGH (80) -> 0.8
        - NORMAL (50) -> 0.5
        - LOW (20) -> 0.2
        - MINIMAL (0) -> 0.0

        Args:
            priority: Priority value (0-100)

        Returns:
            Normalized score (0.0-1.0)
        """
        # Clamp to the valid range, then scale linearly.
        return max(0, min(100, priority)) / 100.0

    def get_type_bonus(self, context_type: ContextType) -> float:
        """
        Get priority bonus for a context type.

        Args:
            context_type: Context type

        Returns:
            Bonus value (0.0 if no bonus is configured)
        """
        return self._type_bonuses.get(context_type, 0.0)

    def set_type_bonus(self, context_type: ContextType, bonus: float) -> None:
        """
        Set priority bonus for a context type.

        Args:
            context_type: Context type
            bonus: Bonus value (0.0-1.0)

        Raises:
            ValueError: If bonus falls outside [0.0, 1.0]
        """
        if not 0.0 <= bonus <= 1.0:
            raise ValueError("Bonus must be between 0.0 and 1.0")
        self._type_bonuses[context_type] = bonus

    async def score_batch(
        self,
        contexts: list[BaseContext],
        query: str,
        **kwargs: Any,
    ) -> list[float]:
        """
        Score multiple contexts.

        Args:
            contexts: Contexts to score
            query: Query (not used)
            **kwargs: Additional parameters

        Returns:
            List of scores (same order as input)
        """
        # Priority scoring is cheap and synchronous under the hood,
        # so items are simply awaited one by one.
        results: list[float] = []
        for ctx in contexts:
            results.append(await self.score(ctx, query, **kwargs))
        return results
|
||||
141
backend/app/services/context/scoring/recency.py
Normal file
141
backend/app/services/context/scoring/recency.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
Recency Scorer for Context Management.
|
||||
|
||||
Scores context based on how recent it is.
|
||||
More recent content gets higher scores.
|
||||
"""
|
||||
|
||||
import math
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from ..types import BaseContext, ContextType
|
||||
from .base import BaseScorer
|
||||
|
||||
|
||||
class RecencyScorer(BaseScorer):
    """
    Scores context items by how fresh they are.

    Applies exponential decay to content age: a context's score halves
    every ``half_life`` hours, with per-type half-lives available.
    """

    def __init__(
        self,
        weight: float = 1.0,
        half_life_hours: float = 24.0,
        type_half_lives: dict[ContextType, float] | None = None,
    ) -> None:
        """
        Initialize recency scorer.

        Args:
            weight: Scorer weight for composite scoring
            half_life_hours: Default hours until score decays to 0.5
            type_half_lives: Optional context-type-specific half lives
        """
        super().__init__(weight)
        self._half_life_hours = half_life_hours
        self._type_half_lives = type_half_lives or {}

        # Sensible per-type defaults (hours), only where not supplied.
        self._type_half_lives.setdefault(ContextType.CONVERSATION, 1.0)  # 1 hour
        self._type_half_lives.setdefault(ContextType.TOOL, 0.5)  # 30 minutes
        self._type_half_lives.setdefault(ContextType.KNOWLEDGE, 168.0)  # 1 week
        self._type_half_lives.setdefault(ContextType.SYSTEM, 720.0)  # 30 days
        self._type_half_lives.setdefault(ContextType.TASK, 24.0)  # 1 day

    async def score(
        self,
        context: BaseContext,
        query: str,
        **kwargs: Any,
    ) -> float:
        """
        Score context based on recency.

        Args:
            context: Context to score
            query: Query (not used for recency, kept for interface)
            **kwargs: Additional parameters
                - reference_time: Time to measure recency from (default: now)

        Returns:
            Recency score between 0.0 and 1.0
        """
        # Normalize the reference point to an aware UTC datetime.
        ref = kwargs.get("reference_time")
        if ref is None:
            ref = datetime.now(UTC)
        elif ref.tzinfo is None:
            ref = ref.replace(tzinfo=UTC)

        # Same for the context's own timestamp.
        stamp = context.timestamp
        if stamp.tzinfo is None:
            stamp = stamp.replace(tzinfo=UTC)

        # Age in hours, floored at zero for timestamps in the future.
        age_hours = max(0, (ref - stamp).total_seconds() / 3600)

        # Per-type half-life, falling back to the scorer default.
        half_life = self._type_half_lives.get(
            context.get_type(), self._half_life_hours
        )

        # Exponential decay: score halves every `half_life` hours.
        decay_factor = math.exp(-math.log(2) * age_hours / half_life)

        return self.normalize_score(decay_factor)

    def get_half_life(self, context_type: ContextType) -> float:
        """
        Get half-life for a context type.

        Args:
            context_type: Context type to get half-life for

        Returns:
            Half-life in hours
        """
        return self._type_half_lives.get(context_type, self._half_life_hours)

    def set_half_life(self, context_type: ContextType, hours: float) -> None:
        """
        Set half-life for a context type.

        Args:
            context_type: Context type to set half-life for
            hours: Half-life in hours (must be positive)

        Raises:
            ValueError: If hours is not positive
        """
        if hours <= 0:
            raise ValueError("Half-life must be positive")
        self._type_half_lives[context_type] = hours

    async def score_batch(
        self,
        contexts: list[BaseContext],
        query: str,
        **kwargs: Any,
    ) -> list[float]:
        """
        Score multiple contexts.

        Args:
            contexts: Contexts to score
            query: Query (not used)
            **kwargs: Additional parameters

        Returns:
            List of scores (same order as input)
        """
        return [await self.score(ctx, query, **kwargs) for ctx in contexts]
|
||||
220
backend/app/services/context/scoring/relevance.py
Normal file
220
backend/app/services/context/scoring/relevance.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""
|
||||
Relevance Scorer for Context Management.
|
||||
|
||||
Scores context based on semantic similarity to the query.
|
||||
Uses Knowledge Base embeddings when available.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from ..config import ContextSettings, get_context_settings
|
||||
from ..types import BaseContext, KnowledgeContext
|
||||
from .base import BaseScorer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.mcp.client_manager import MCPClientManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RelevanceScorer(BaseScorer):
    """
    Scores context based on relevance to query.

    Uses multiple strategies, tried in order:
    1. Pre-computed scores (from RAG results)
    2. MCP-based semantic similarity (via Knowledge Base)
    3. Keyword matching fallback
    """

    def __init__(
        self,
        mcp_manager: "MCPClientManager | None" = None,
        weight: float = 1.0,
        keyword_fallback_weight: float | None = None,
        semantic_max_chars: int | None = None,
        settings: ContextSettings | None = None,
    ) -> None:
        """
        Initialize relevance scorer.

        Args:
            mcp_manager: MCP manager for Knowledge Base calls
            weight: Scorer weight for composite scoring
            keyword_fallback_weight: Max score for keyword-based fallback (overrides settings)
            semantic_max_chars: Max content length for semantic similarity (overrides settings)
            settings: Context settings (uses global if None)
        """
        super().__init__(weight)
        self._settings = settings or get_context_settings()
        self._mcp = mcp_manager

        # Explicit constructor arguments win over configured defaults.
        self._keyword_fallback_weight = (
            keyword_fallback_weight
            if keyword_fallback_weight is not None
            else self._settings.relevance_keyword_fallback_weight
        )
        self._semantic_max_chars = (
            semantic_max_chars
            if semantic_max_chars is not None
            else self._settings.relevance_semantic_max_chars
        )

    def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
        """Set MCP manager for semantic scoring."""
        self._mcp = mcp_manager

    async def score(
        self,
        context: BaseContext,
        query: str,
        **kwargs: Any,
    ) -> float:
        """
        Score context relevance to query.

        Args:
            context: Context to score
            query: Query to score against
            **kwargs: Additional parameters

        Returns:
            Relevance score between 0.0 and 1.0
        """
        # 1. Check for pre-computed relevance score
        if (
            isinstance(context, KnowledgeContext)
            and context.relevance_score is not None
        ):
            return self.normalize_score(context.relevance_score)

        # 2. Check metadata for score
        if "relevance_score" in context.metadata:
            return self.normalize_score(context.metadata["relevance_score"])

        if "score" in context.metadata:
            return self.normalize_score(context.metadata["score"])

        # 3. Try MCP-based semantic similarity (if compute_similarity tool is available)
        # Note: This requires the knowledge-base MCP server to implement compute_similarity
        if self._mcp is not None:
            try:
                score = await self._compute_semantic_similarity(context, query)
                if score is not None:
                    return score
            except Exception as e:
                # Log at debug level since this is expected if compute_similarity
                # tool is not implemented in the Knowledge Base server.
                # Lazy %-args: message only formatted when DEBUG is enabled.
                logger.debug(
                    "Semantic scoring unavailable, using keyword fallback: %s", e
                )

        # 4. Fall back to keyword matching
        return self._compute_keyword_score(context, query)

    async def _compute_semantic_similarity(
        self,
        context: BaseContext,
        query: str,
    ) -> float | None:
        """
        Compute semantic similarity using Knowledge Base embeddings.

        Args:
            context: Context to score
            query: Query to compare

        Returns:
            Similarity score or None if unavailable
        """
        if self._mcp is None:
            return None

        try:
            # Use Knowledge Base's search capability to compute similarity
            result = await self._mcp.call_tool(
                server="knowledge-base",
                tool="compute_similarity",
                args={
                    "text1": query,
                    "text2": context.content[
                        : self._semantic_max_chars
                    ],  # Limit content length
                },
            )

            if result.success and isinstance(result.data, dict):
                similarity = result.data.get("similarity")
                if similarity is not None:
                    return self.normalize_score(float(similarity))

        except Exception as e:
            # Lazy %-args: avoid f-string formatting on a path that is
            # expected to fail when the tool is not implemented.
            logger.debug("Semantic similarity computation failed: %s", e)

        return None

    def _compute_keyword_score(
        self,
        context: BaseContext,
        query: str,
    ) -> float:
        """
        Compute relevance score based on keyword matching.

        Simple but fast fallback when semantic search is unavailable.

        Args:
            context: Context to score
            query: Query to match

        Returns:
            Keyword-based relevance score
        """
        if not query or not context.content:
            return 0.0

        # Extract keywords from query
        query_lower = query.lower()
        content_lower = context.content.lower()

        # Simple word tokenization: words of 3+ characters only
        query_words = set(re.findall(r"\b\w{3,}\b", query_lower))
        content_words = set(re.findall(r"\b\w{3,}\b", content_lower))

        if not query_words:
            return 0.0

        # Calculate overlap
        common_words = query_words & content_words
        overlap_ratio = len(common_words) / len(query_words)

        # Apply fallback weight ceiling
        return self.normalize_score(overlap_ratio * self._keyword_fallback_weight)

    async def score_batch(
        self,
        contexts: list[BaseContext],
        query: str,
        **kwargs: Any,
    ) -> list[float]:
        """
        Score multiple contexts in parallel.

        Args:
            contexts: Contexts to score
            query: Query to score against
            **kwargs: Additional parameters

        Returns:
            List of scores (same order as input)
        """
        if not contexts:
            return []

        tasks = [self.score(context, query, **kwargs) for context in contexts]
        return await asyncio.gather(*tasks)
|
||||
43
backend/app/services/context/types/__init__.py
Normal file
43
backend/app/services/context/types/__init__.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
Context Types Module.
|
||||
|
||||
Provides all context types used in the Context Management Engine.
|
||||
"""
|
||||
|
||||
from .base import (
|
||||
AssembledContext,
|
||||
BaseContext,
|
||||
ContextPriority,
|
||||
ContextType,
|
||||
)
|
||||
from .conversation import (
|
||||
ConversationContext,
|
||||
MessageRole,
|
||||
)
|
||||
from .knowledge import KnowledgeContext
|
||||
from .system import SystemContext
|
||||
from .task import (
|
||||
TaskComplexity,
|
||||
TaskContext,
|
||||
TaskStatus,
|
||||
)
|
||||
from .tool import (
|
||||
ToolContext,
|
||||
ToolResultStatus,
|
||||
)
|
||||
|
||||
# Public API of the context types package, kept in alphabetical order.
__all__ = [
    "AssembledContext",
    "BaseContext",
    "ContextPriority",
    "ContextType",
    "ConversationContext",
    "KnowledgeContext",
    "MessageRole",
    "SystemContext",
    "TaskComplexity",
    "TaskContext",
    "TaskStatus",
    "ToolContext",
    "ToolResultStatus",
]
|
||||
347
backend/app/services/context/types/base.py
Normal file
347
backend/app/services/context/types/base.py
Normal file
@@ -0,0 +1,347 @@
|
||||
"""
|
||||
Base Context Types and Enums.
|
||||
|
||||
Provides the foundation for all context types used in
|
||||
the Context Management Engine.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
|
||||
class ContextType(str, Enum):
    """
    Types of context that can be assembled.

    Each type has specific handling, formatting, and
    budget allocation rules.
    """

    SYSTEM = "system"
    TASK = "task"
    KNOWLEDGE = "knowledge"
    CONVERSATION = "conversation"
    TOOL = "tool"

    @classmethod
    def from_string(cls, value: str) -> "ContextType":
        """
        Convert string to ContextType (case-insensitive).

        Args:
            value: String value

        Returns:
            ContextType enum value

        Raises:
            ValueError: If value is not a valid context type
        """
        try:
            return cls(value.lower())
        except ValueError:
            valid = ", ".join(t.value for t in cls)
            # "from None": the original ValueError adds no information;
            # suppress implicit chaining so the user sees a single error.
            raise ValueError(
                f"Invalid context type '{value}'. Valid types: {valid}"
            ) from None
|
||||
|
||||
|
||||
class ContextPriority(int, Enum):
    """
    Priority levels for context ordering.

    Higher values indicate higher priority.
    """

    LOWEST = 0
    LOW = 25
    NORMAL = 50
    HIGH = 75
    HIGHEST = 100
    CRITICAL = 150  # Never omit

    @classmethod
    def from_int(cls, value: int) -> "ContextPriority":
        """
        Get closest priority level for an integer.

        Returns the highest level whose value does not exceed
        ``value``; integers below LOWEST map to LOWEST.

        Args:
            value: Integer priority value

        Returns:
            Closest ContextPriority enum value
        """
        best = cls.LOWEST
        # Scan ascending: the last level we can still "afford" wins.
        for level in sorted(cls, key=lambda p: p.value):
            if value >= level.value:
                best = level
        return best
|
||||
|
||||
|
||||
@dataclass(eq=False)  # eq=False: identity-by-id __eq__/__hash__ defined below
class BaseContext(ABC):
    """
    Abstract base class for all context types.

    Provides common fields and methods for context handling,
    scoring, and serialization. Subclasses implement ``get_type``
    and should override ``from_dict`` to return their own type.
    """

    # Required fields
    content: str
    source: str

    # Optional fields with defaults
    id: str = field(default_factory=lambda: str(uuid4()))
    timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
    priority: int = field(default=ContextPriority.NORMAL.value)
    metadata: dict[str, Any] = field(default_factory=dict)

    # Computed/cached fields (populated lazily by external components;
    # excluded from repr to keep debug output compact)
    _token_count: int | None = field(default=None, repr=False)
    _score: float | None = field(default=None, repr=False)

    @property
    def token_count(self) -> int | None:
        """Get cached token count (None if not counted yet)."""
        return self._token_count

    @token_count.setter
    def token_count(self, value: int) -> None:
        """Set token count."""
        self._token_count = value

    @property
    def score(self) -> float | None:
        """Get cached score (None if not scored yet)."""
        return self._score

    @score.setter
    def score(self, value: float) -> None:
        """Set score (clamped to 0.0-1.0)."""
        self._score = max(0.0, min(1.0, value))

    @abstractmethod
    def get_type(self) -> ContextType:
        """
        Get the type of this context.

        Returns:
            ContextType enum value
        """
        ...

    def get_age_seconds(self) -> float:
        """
        Get age of context in seconds.

        Returns:
            Age in seconds since creation
        """
        # Both sides are timezone-aware (UTC), so subtraction is safe.
        now = datetime.now(UTC)
        delta = now - self.timestamp
        return delta.total_seconds()

    def get_age_hours(self) -> float:
        """
        Get age of context in hours.

        Returns:
            Age in hours since creation
        """
        return self.get_age_seconds() / 3600

    def is_stale(self, max_age_hours: float = 168.0) -> bool:
        """
        Check if context is stale.

        Args:
            max_age_hours: Maximum age before considered stale (default 7 days)

        Returns:
            True if context is older than max_age_hours
        """
        return self.get_age_hours() > max_age_hours

    def to_dict(self) -> dict[str, Any]:
        """
        Convert context to dictionary for serialization.

        Returns:
            Dictionary representation (timestamp as ISO-8601 string;
            cached token_count/score may be None)
        """
        return {
            "id": self.id,
            "type": self.get_type().value,
            "content": self.content,
            "source": self.source,
            "timestamp": self.timestamp.isoformat(),
            "priority": self.priority,
            "metadata": self.metadata,
            "token_count": self._token_count,
            "score": self._score,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "BaseContext":
        """
        Create context from dictionary.

        Note: Subclasses should override this to return correct type.

        Args:
            data: Dictionary with context data

        Returns:
            Context instance

        Raises:
            NotImplementedError: Always, on the abstract base class.
        """
        raise NotImplementedError("Subclasses must implement from_dict")

    def truncate(self, max_tokens: int, suffix: str = "... [truncated]") -> str:
        """
        Truncate content to fit within token limit.

        This is a rough estimation based on characters.
        For accurate truncation, use the TokenCalculator.

        Args:
            max_tokens: Maximum tokens allowed
            suffix: Suffix to append when truncated

        Returns:
            Truncated content (or the full content when the cached
            token count is missing or already within the limit)
        """
        # No cached count means we cannot tell if truncation is needed;
        # return content unchanged rather than guessing.
        if self._token_count is None or self._token_count <= max_tokens:
            return self.content

        # Rough estimation: 4 chars per token on average
        estimated_chars = max_tokens * 4
        suffix_chars = len(suffix)

        if len(self.content) <= estimated_chars:
            return self.content

        truncated = self.content[: estimated_chars - suffix_chars]
        # Try to break at word boundary, but only if it does not cost
        # more than ~20% of the budget.
        last_space = truncated.rfind(" ")
        if last_space > estimated_chars * 0.8:
            truncated = truncated[:last_space]

        return truncated + suffix

    def __hash__(self) -> int:
        """Hash based on ID for set/dict usage."""
        return hash(self.id)

    def __eq__(self, other: object) -> bool:
        """Equality based on ID (not content)."""
        if not isinstance(other, BaseContext):
            return False
        return self.id == other.id
|
||||
|
||||
|
||||
@dataclass
|
||||
class AssembledContext:
|
||||
"""
|
||||
Result of context assembly.
|
||||
|
||||
Contains the final formatted context ready for LLM consumption,
|
||||
along with metadata about the assembly process.
|
||||
"""
|
||||
|
||||
# Main content
|
||||
content: str
|
||||
total_tokens: int
|
||||
|
||||
# Assembly metadata
|
||||
context_count: int
|
||||
excluded_count: int = 0
|
||||
assembly_time_ms: float = 0.0
|
||||
model: str = ""
|
||||
|
||||
# Included contexts (optional - for inspection)
|
||||
contexts: list["BaseContext"] = field(default_factory=list)
|
||||
|
||||
# Additional metadata from assembly
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Budget tracking
|
||||
budget_total: int = 0
|
||||
budget_used: int = 0
|
||||
|
||||
# Context breakdown
|
||||
by_type: dict[str, int] = field(default_factory=dict)
|
||||
|
||||
# Cache info
|
||||
cache_hit: bool = False
|
||||
cache_key: str | None = None
|
||||
|
||||
# Aliases for backward compatibility
|
||||
@property
|
||||
def token_count(self) -> int:
|
||||
"""Alias for total_tokens."""
|
||||
return self.total_tokens
|
||||
|
||||
@property
|
||||
def contexts_included(self) -> int:
|
||||
"""Alias for context_count."""
|
||||
return self.context_count
|
||||
|
||||
@property
|
||||
def contexts_excluded(self) -> int:
|
||||
"""Alias for excluded_count."""
|
||||
return self.excluded_count
|
||||
|
||||
@property
|
||||
def budget_utilization(self) -> float:
|
||||
"""Get budget utilization percentage."""
|
||||
if self.budget_total == 0:
|
||||
return 0.0
|
||||
return self.budget_used / self.budget_total
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"content": self.content,
|
||||
"total_tokens": self.total_tokens,
|
||||
"context_count": self.context_count,
|
||||
"excluded_count": self.excluded_count,
|
||||
"assembly_time_ms": round(self.assembly_time_ms, 2),
|
||||
"model": self.model,
|
||||
"metadata": self.metadata,
|
||||
"budget_total": self.budget_total,
|
||||
"budget_used": self.budget_used,
|
||||
"budget_utilization": round(self.budget_utilization, 3),
|
||||
"by_type": self.by_type,
|
||||
"cache_hit": self.cache_hit,
|
||||
"cache_key": self.cache_key,
|
||||
}
|
||||
|
||||
def to_json(self) -> str:
|
||||
"""Convert to JSON string."""
|
||||
import json
|
||||
|
||||
return json.dumps(self.to_dict())
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json_str: str) -> "AssembledContext":
|
||||
"""Create from JSON string."""
|
||||
import json
|
||||
|
||||
data = json.loads(json_str)
|
||||
return cls(
|
||||
content=data["content"],
|
||||
total_tokens=data["total_tokens"],
|
||||
context_count=data["context_count"],
|
||||
excluded_count=data.get("excluded_count", 0),
|
||||
assembly_time_ms=data.get("assembly_time_ms", 0.0),
|
||||
model=data.get("model", ""),
|
||||
metadata=data.get("metadata", {}),
|
||||
budget_total=data.get("budget_total", 0),
|
||||
budget_used=data.get("budget_used", 0),
|
||||
by_type=data.get("by_type", {}),
|
||||
cache_hit=data.get("cache_hit", False),
|
||||
cache_key=data.get("cache_key"),
|
||||
)
|
||||
182
backend/app/services/context/types/conversation.py
Normal file
182
backend/app/services/context/types/conversation.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""
|
||||
Conversation Context Type.
|
||||
|
||||
Represents conversation history for context continuity.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from .base import BaseContext, ContextPriority, ContextType
|
||||
|
||||
|
||||
class MessageRole(str, Enum):
    """Roles for conversation messages."""

    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    TOOL = "tool"

    @classmethod
    def from_string(cls, value: str) -> "MessageRole":
        """
        Map a string to a MessageRole (case-insensitive).

        Unknown roles fall back to USER rather than raising.
        """
        normalized = value.lower()
        for member in cls:
            if member.value == normalized:
                return member
        return cls.USER
|
||||
|
||||
|
||||
@dataclass(eq=False)  # eq=False: inherits identity-by-id __eq__/__hash__ from BaseContext
class ConversationContext(BaseContext):
    """
    Context from conversation history.

    Represents a single turn in the conversation,
    including user messages, assistant responses,
    and tool results.
    """

    # Conversation-specific fields
    role: MessageRole = field(default=MessageRole.USER)
    turn_index: int = field(default=0)
    session_id: str | None = field(default=None)
    parent_message_id: str | None = field(default=None)

    def get_type(self) -> ContextType:
        """Return CONVERSATION context type."""
        return ContextType.CONVERSATION

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary with conversation-specific fields."""
        base = super().to_dict()
        base.update(
            {
                "role": self.role.value,
                "turn_index": self.turn_index,
                "session_id": self.session_id,
                "parent_message_id": self.parent_message_id,
            }
        )
        return base

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ConversationContext":
        """
        Create ConversationContext from dictionary.

        Accepts role as string or MessageRole; unknown role strings
        fall back to USER (see MessageRole.from_string).
        """
        role = data.get("role", "user")
        if isinstance(role, str):
            role = MessageRole.from_string(role)

        return cls(
            # NOTE(review): a missing "id" yields "" rather than a fresh
            # uuid default — confirm callers always supply an id here.
            id=data.get("id", ""),
            content=data["content"],
            source=data.get("source", "conversation"),
            # Timestamp may arrive as ISO string or datetime; default to now.
            timestamp=datetime.fromisoformat(data["timestamp"])
            if isinstance(data.get("timestamp"), str)
            else data.get("timestamp", datetime.now(UTC)),
            priority=data.get("priority", ContextPriority.NORMAL.value),
            metadata=data.get("metadata", {}),
            role=role,
            turn_index=data.get("turn_index", 0),
            session_id=data.get("session_id"),
            parent_message_id=data.get("parent_message_id"),
        )

    @classmethod
    def from_message(
        cls,
        content: str,
        role: str | MessageRole,
        turn_index: int = 0,
        session_id: str | None = None,
        timestamp: datetime | None = None,
    ) -> "ConversationContext":
        """
        Create ConversationContext from a message.

        Args:
            content: Message content
            role: Message role (user, assistant, system, tool)
            turn_index: Position in conversation
            session_id: Session identifier
            timestamp: Message timestamp

        Returns:
            ConversationContext instance
        """
        if isinstance(role, str):
            role = MessageRole.from_string(role)

        # Priority is always NORMAL here; recency-based weighting (if any)
        # happens elsewhere. TODO(review): confirm whether recency was
        # intended to affect priority at creation time.
        priority = ContextPriority.NORMAL.value

        return cls(
            content=content,
            source="conversation",
            role=role,
            turn_index=turn_index,
            session_id=session_id,
            timestamp=timestamp or datetime.now(UTC),
            priority=priority,
        )

    @classmethod
    def from_history(
        cls,
        messages: list[dict[str, Any]],
        session_id: str | None = None,
    ) -> list["ConversationContext"]:
        """
        Create multiple ConversationContexts from message history.

        Args:
            messages: List of message dicts with 'role' and 'content'
                (optionally an ISO-format 'timestamp')
            session_id: Session identifier

        Returns:
            List of ConversationContext instances, turn_index set
            from list position
        """
        contexts = []
        for i, msg in enumerate(messages):
            ctx = cls.from_message(
                content=msg.get("content", ""),
                role=msg.get("role", "user"),
                turn_index=i,
                session_id=session_id,
                timestamp=datetime.fromisoformat(msg["timestamp"])
                if "timestamp" in msg
                else None,
            )
            contexts.append(ctx)
        return contexts

    def is_user_message(self) -> bool:
        """Check if this is a user message."""
        return self.role == MessageRole.USER

    def is_assistant_message(self) -> bool:
        """Check if this is an assistant message."""
        return self.role == MessageRole.ASSISTANT

    def is_tool_result(self) -> bool:
        """Check if this is a tool result."""
        return self.role == MessageRole.TOOL

    def format_for_prompt(self) -> str:
        """
        Format message for inclusion in prompt.

        Returns:
            Formatted message string, e.g. "User: <content>"
        """
        role_labels = {
            MessageRole.USER: "User",
            MessageRole.ASSISTANT: "Assistant",
            MessageRole.SYSTEM: "System",
            MessageRole.TOOL: "Tool Result",
        }
        label = role_labels.get(self.role, "Unknown")
        return f"{label}: {self.content}"
|
||||
152
backend/app/services/context/types/knowledge.py
Normal file
152
backend/app/services/context/types/knowledge.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""
|
||||
Knowledge Context Type.
|
||||
|
||||
Represents RAG results from the Knowledge Base MCP server.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from .base import BaseContext, ContextPriority, ContextType
|
||||
|
||||
|
||||
@dataclass(eq=False)  # eq=False: inherits identity-by-id __eq__/__hash__ from BaseContext
class KnowledgeContext(BaseContext):
    """
    Context from knowledge base / RAG retrieval.

    Knowledge context represents chunks retrieved from the
    Knowledge Base MCP server, including:
    - Code snippets
    - Documentation
    - Previous conversations
    - External knowledge

    Each chunk includes relevance scoring from the search.
    """

    # Knowledge-specific fields
    collection: str = field(default="default")
    file_type: str | None = field(default=None)
    chunk_index: int = field(default=0)
    relevance_score: float = field(default=0.0)
    search_query: str = field(default="")

    def get_type(self) -> ContextType:
        """Return KNOWLEDGE context type."""
        return ContextType.KNOWLEDGE

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary with knowledge-specific fields."""
        base = super().to_dict()
        base.update(
            {
                "collection": self.collection,
                "file_type": self.file_type,
                "chunk_index": self.chunk_index,
                "relevance_score": self.relevance_score,
                "search_query": self.search_query,
            }
        )
        return base

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "KnowledgeContext":
        """
        Create KnowledgeContext from dictionary.

        Requires "content" and "source"; all other keys optional.
        """
        return cls(
            id=data.get("id", ""),
            content=data["content"],
            source=data["source"],
            # Timestamp may arrive as ISO string or datetime; default to now.
            timestamp=datetime.fromisoformat(data["timestamp"])
            if isinstance(data.get("timestamp"), str)
            else data.get("timestamp", datetime.now(UTC)),
            priority=data.get("priority", ContextPriority.NORMAL.value),
            metadata=data.get("metadata", {}),
            collection=data.get("collection", "default"),
            file_type=data.get("file_type"),
            chunk_index=data.get("chunk_index", 0),
            relevance_score=data.get("relevance_score", 0.0),
            search_query=data.get("search_query", ""),
        )

    @classmethod
    def from_search_result(
        cls,
        result: dict[str, Any],
        query: str,
    ) -> "KnowledgeContext":
        """
        Create KnowledgeContext from a Knowledge Base search result.

        Args:
            result: Search result from Knowledge Base MCP
                (keys: content, source_path, collection, file_type,
                chunk_index, score, id, content_hash — all optional)
            query: Search query used

        Returns:
            KnowledgeContext instance
        """
        return cls(
            content=result.get("content", ""),
            source=result.get("source_path", "unknown"),
            collection=result.get("collection", "default"),
            file_type=result.get("file_type"),
            chunk_index=result.get("chunk_index", 0),
            relevance_score=result.get("score", 0.0),
            search_query=query,
            # Keep provenance for de-duplication / traceability.
            metadata={
                "chunk_id": result.get("id"),
                "content_hash": result.get("content_hash"),
            },
        )

    @classmethod
    def from_search_results(
        cls,
        results: list[dict[str, Any]],
        query: str,
    ) -> list["KnowledgeContext"]:
        """
        Create multiple KnowledgeContexts from search results.

        Args:
            results: List of search results
            query: Search query used

        Returns:
            List of KnowledgeContext instances
        """
        return [cls.from_search_result(r, query) for r in results]

    def is_code(self) -> bool:
        """Check if this is code content (based on file_type)."""
        code_types = {
            "python",
            "javascript",
            "typescript",
            "go",
            "rust",
            "java",
            "c",
            "cpp",
        }
        return self.file_type is not None and self.file_type.lower() in code_types

    def is_documentation(self) -> bool:
        """Check if this is documentation content (based on file_type)."""
        doc_types = {"markdown", "rst", "txt", "md"}
        return self.file_type is not None and self.file_type.lower() in doc_types

    def get_formatted_source(self) -> str:
        """
        Get a formatted source string for display.

        Returns:
            Formatted source string, e.g. "[collection] path (file_type)"
        """
        parts = [self.source]
        if self.file_type:
            parts.append(f"({self.file_type})")
        # Only show collection when it is not the default one.
        if self.collection != "default":
            parts.insert(0, f"[{self.collection}]")
        return " ".join(parts)
|
||||
138
backend/app/services/context/types/system.py
Normal file
138
backend/app/services/context/types/system.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""
|
||||
System Context Type.
|
||||
|
||||
Represents system prompts, instructions, and agent personas.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from .base import BaseContext, ContextPriority, ContextType
|
||||
|
||||
|
||||
@dataclass(eq=False)  # eq=False: inherits identity-by-id __eq__/__hash__ from BaseContext
class SystemContext(BaseContext):
    """
    Context for system prompts and instructions.

    System context typically includes:
    - Agent persona and role definitions
    - Behavioral instructions
    - Safety guidelines
    - Output format requirements

    System context is usually high priority and should
    rarely be truncated or omitted.
    """

    # System context specific fields
    role: str = field(default="assistant")
    instructions_type: str = field(default="general")

    def __post_init__(self) -> None:
        """Set high priority for system context."""
        # Promote only the default NORMAL priority; an explicitly set
        # priority (even NORMAL-valued) cannot be distinguished and is
        # also promoted — callers wanting NORMAL must use another value.
        if self.priority == ContextPriority.NORMAL.value:
            self.priority = ContextPriority.HIGH.value

    def get_type(self) -> ContextType:
        """Return SYSTEM context type."""
        return ContextType.SYSTEM

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary with system-specific fields."""
        base = super().to_dict()
        base.update(
            {
                "role": self.role,
                "instructions_type": self.instructions_type,
            }
        )
        return base

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SystemContext":
        """
        Create SystemContext from dictionary.

        Requires "content" and "source"; priority defaults to HIGH.
        """
        return cls(
            id=data.get("id", ""),
            content=data["content"],
            source=data["source"],
            # Timestamp may arrive as ISO string or datetime; default to now.
            timestamp=datetime.fromisoformat(data["timestamp"])
            if isinstance(data.get("timestamp"), str)
            else data.get("timestamp", datetime.now(UTC)),
            priority=data.get("priority", ContextPriority.HIGH.value),
            metadata=data.get("metadata", {}),
            role=data.get("role", "assistant"),
            instructions_type=data.get("instructions_type", "general"),
        )

    @classmethod
    def create_persona(
        cls,
        name: str,
        description: str,
        capabilities: list[str] | None = None,
        constraints: list[str] | None = None,
    ) -> "SystemContext":
        """
        Create a persona system context.

        Args:
            name: Agent name/role
            description: Role description
            capabilities: List of things the agent can do
            constraints: List of limitations

        Returns:
            SystemContext with formatted persona (priority HIGHEST)
        """
        parts = [f"You are {name}.", "", description]

        if capabilities:
            parts.append("")
            parts.append("You can:")
            for cap in capabilities:
                parts.append(f"- {cap}")

        if constraints:
            parts.append("")
            parts.append("You must not:")
            for constraint in constraints:
                parts.append(f"- {constraint}")

        return cls(
            content="\n".join(parts),
            source="persona_builder",
            # Normalize the display name into a snake_case role id.
            role=name.lower().replace(" ", "_"),
            instructions_type="persona",
            priority=ContextPriority.HIGHEST.value,
        )

    @classmethod
    def create_instructions(
        cls,
        instructions: str | list[str],
        source: str = "instructions",
    ) -> "SystemContext":
        """
        Create an instructions system context.

        Args:
            instructions: Instructions string or list of instruction strings
                (a list is rendered as a "- item" bullet list)
            source: Source identifier

        Returns:
            SystemContext with instructions (priority HIGH)
        """
        if isinstance(instructions, list):
            content = "\n".join(f"- {inst}" for inst in instructions)
        else:
            content = instructions

        return cls(
            content=content,
            source=source,
            instructions_type="instructions",
            priority=ContextPriority.HIGH.value,
        )
|
||||
193
backend/app/services/context/types/task.py
Normal file
193
backend/app/services/context/types/task.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""
|
||||
Task Context Type.
|
||||
|
||||
Represents the current task or objective for the agent.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from .base import BaseContext, ContextPriority, ContextType
|
||||
|
||||
|
||||
class TaskStatus(str, Enum):
    """Status of a task in its lifecycle."""

    PENDING = "pending"  # not started yet
    IN_PROGRESS = "in_progress"  # actively being worked on
    BLOCKED = "blocked"  # waiting on an external dependency
    COMPLETED = "completed"  # finished successfully
    FAILED = "failed"  # finished unsuccessfully
|
||||
|
||||
|
||||
class TaskComplexity(str, Enum):
    """Complexity level of a task, from trivial to very complex."""

    TRIVIAL = "trivial"
    SIMPLE = "simple"
    MODERATE = "moderate"
    COMPLEX = "complex"
    VERY_COMPLEX = "very_complex"
|
||||
|
||||
|
||||
@dataclass(eq=False)
class TaskContext(BaseContext):
    """
    Context for the current task or objective.

    Task context provides information about what the agent
    should accomplish, including:
    - Task description and goals
    - Acceptance criteria
    - Constraints and requirements
    - Related issue/ticket information
    """

    # Task-specific fields (BaseContext supplies id/content/source/timestamp/
    # priority/metadata — assumed from from_dict(); confirm against base.py)
    title: str = field(default="")
    status: TaskStatus = field(default=TaskStatus.PENDING)
    complexity: TaskComplexity = field(default=TaskComplexity.MODERATE)
    issue_id: str | None = field(default=None)  # external issue/ticket reference
    project_id: str | None = field(default=None)
    acceptance_criteria: list[str] = field(default_factory=list)
    constraints: list[str] = field(default_factory=list)
    parent_task_id: str | None = field(default=None)  # for sub-task hierarchies

    # Note: TaskContext should typically have HIGH priority,
    # but we don't auto-promote to allow explicit priority setting.
    # Use TaskContext.create() for default HIGH priority behavior.

    def get_type(self) -> ContextType:
        """Return TASK context type."""
        return ContextType.TASK

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary with task-specific fields.

        Enum fields are flattened to their string values so the result
        round-trips through from_dict() and plain JSON.
        """
        base = super().to_dict()
        base.update(
            {
                "title": self.title,
                "status": self.status.value,
                "complexity": self.complexity.value,
                "issue_id": self.issue_id,
                "project_id": self.project_id,
                "acceptance_criteria": self.acceptance_criteria,
                "constraints": self.constraints,
                "parent_task_id": self.parent_task_id,
            }
        )
        return base

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "TaskContext":
        """Create TaskContext from dictionary.

        Accepts enum fields as either enum instances or their string
        values, and timestamps as either datetimes or ISO-format strings.

        Raises:
            KeyError: if ``data`` has no "content" key.
            ValueError: if a status/complexity string is not a valid member.
        """
        # Coerce string values back into enum members.
        status = data.get("status", "pending")
        if isinstance(status, str):
            status = TaskStatus(status)

        complexity = data.get("complexity", "moderate")
        if isinstance(complexity, str):
            complexity = TaskComplexity(complexity)

        return cls(
            id=data.get("id", ""),
            content=data["content"],
            source=data.get("source", "task"),
            # ISO strings are parsed; datetime objects pass through;
            # a missing key falls back to "now" (timezone-aware UTC).
            timestamp=datetime.fromisoformat(data["timestamp"])
            if isinstance(data.get("timestamp"), str)
            else data.get("timestamp", datetime.now(UTC)),
            priority=data.get("priority", ContextPriority.HIGH.value),
            metadata=data.get("metadata", {}),
            title=data.get("title", ""),
            status=status,
            complexity=complexity,
            issue_id=data.get("issue_id"),
            project_id=data.get("project_id"),
            acceptance_criteria=data.get("acceptance_criteria", []),
            constraints=data.get("constraints", []),
            parent_task_id=data.get("parent_task_id"),
        )

    @classmethod
    def create(
        cls,
        title: str,
        description: str,
        acceptance_criteria: list[str] | None = None,
        constraints: list[str] | None = None,
        issue_id: str | None = None,
        project_id: str | None = None,
        complexity: TaskComplexity | str = TaskComplexity.MODERATE,
    ) -> "TaskContext":
        """
        Create a task context.

        Args:
            title: Task title
            description: Task description
            acceptance_criteria: List of acceptance criteria
            constraints: List of constraints
            issue_id: Related issue ID
            project_id: Project ID
            complexity: Task complexity (enum member or its string value)

        Returns:
            TaskContext instance
        """
        if isinstance(complexity, str):
            complexity = TaskComplexity(complexity)

        # Creation via this factory implies work is starting, so status
        # is IN_PROGRESS rather than the dataclass default of PENDING.
        return cls(
            content=description,
            source=f"task:{issue_id}" if issue_id else "task",
            title=title,
            status=TaskStatus.IN_PROGRESS,
            complexity=complexity,
            issue_id=issue_id,
            project_id=project_id,
            acceptance_criteria=acceptance_criteria or [],
            constraints=constraints or [],
        )

    def format_for_prompt(self) -> str:
        """
        Format task for inclusion in prompt.

        Layout: optional "Task: <title>" header, the description,
        then optional "Acceptance Criteria:" and "Constraints:"
        bulleted sections, blank-line separated.

        Returns:
            Formatted task string
        """
        parts = []

        if self.title:
            parts.append(f"Task: {self.title}")
            parts.append("")

        parts.append(self.content)

        if self.acceptance_criteria:
            parts.append("")
            parts.append("Acceptance Criteria:")
            for criterion in self.acceptance_criteria:
                parts.append(f"- {criterion}")

        if self.constraints:
            parts.append("")
            parts.append("Constraints:")
            for constraint in self.constraints:
                parts.append(f"- {constraint}")

        return "\n".join(parts)

    def is_active(self) -> bool:
        """Check if task is currently active (pending or in progress)."""
        return self.status in (TaskStatus.PENDING, TaskStatus.IN_PROGRESS)

    def is_complete(self) -> bool:
        """Check if task is complete."""
        return self.status == TaskStatus.COMPLETED

    def is_blocked(self) -> bool:
        """Check if task is blocked."""
        return self.status == TaskStatus.BLOCKED
|
||||
211
backend/app/services/context/types/tool.py
Normal file
211
backend/app/services/context/types/tool.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
Tool Context Type.
|
||||
|
||||
Represents available tools and recent tool execution results.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from .base import BaseContext, ContextPriority, ContextType
|
||||
|
||||
|
||||
class ToolResultStatus(str, Enum):
    """Status of a tool execution result.

    ``str`` subclass so members serialize as their plain string values.
    """

    SUCCESS = "success"  # tool ran and returned normally
    ERROR = "error"  # tool raised or reported a failure
    TIMEOUT = "timeout"  # execution exceeded its time budget
    CANCELLED = "cancelled"  # execution was cancelled before completing
|
||||
|
||||
|
||||
@dataclass(eq=False)
class ToolContext(BaseContext):
    """
    Context for tools and tool execution results.

    Tool context includes:
    - Tool descriptions and parameters
    - Recent tool execution results
    - Tool availability information

    This helps the LLM understand what tools are available
    and what results previous tool calls produced.
    """

    # Tool-specific fields.  One ToolContext instance represents either a
    # tool *definition* (is_result=False) or a tool *result* (is_result=True).
    tool_name: str = field(default="")
    tool_description: str = field(default="")
    is_result: bool = field(default=False)
    result_status: ToolResultStatus | None = field(default=None)  # results only
    execution_time_ms: float | None = field(default=None)  # results only
    parameters: dict[str, Any] = field(default_factory=dict)
    server_name: str | None = field(default=None)  # originating MCP server

    def get_type(self) -> ContextType:
        """Return TOOL context type."""
        return ContextType.TOOL

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary with tool-specific fields."""
        base = super().to_dict()
        base.update(
            {
                "tool_name": self.tool_name,
                "tool_description": self.tool_description,
                "is_result": self.is_result,
                # Flatten the optional enum to its string value (or None).
                "result_status": self.result_status.value
                if self.result_status
                else None,
                "execution_time_ms": self.execution_time_ms,
                "parameters": self.parameters,
                "server_name": self.server_name,
            }
        )
        return base

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ToolContext":
        """Create ToolContext from dictionary.

        Accepts ``result_status`` as an enum member, its string value,
        or None; timestamps as datetimes or ISO-format strings.

        Raises:
            KeyError: if ``data`` has no "content" key.
            ValueError: if a result_status string is not a valid member.
        """
        result_status = data.get("result_status")
        if isinstance(result_status, str):
            result_status = ToolResultStatus(result_status)

        return cls(
            id=data.get("id", ""),
            content=data["content"],
            source=data.get("source", "tool"),
            timestamp=datetime.fromisoformat(data["timestamp"])
            if isinstance(data.get("timestamp"), str)
            else data.get("timestamp", datetime.now(UTC)),
            priority=data.get("priority", ContextPriority.NORMAL.value),
            metadata=data.get("metadata", {}),
            tool_name=data.get("tool_name", ""),
            tool_description=data.get("tool_description", ""),
            is_result=data.get("is_result", False),
            result_status=result_status,
            execution_time_ms=data.get("execution_time_ms"),
            parameters=data.get("parameters", {}),
            server_name=data.get("server_name"),
        )

    @classmethod
    def from_tool_definition(
        cls,
        name: str,
        description: str,
        parameters: dict[str, Any] | None = None,
        server_name: str | None = None,
    ) -> "ToolContext":
        """
        Create a ToolContext from a tool definition.

        Args:
            name: Tool name
            description: Tool description
            parameters: Tool parameter schema (per-parameter dicts with
                "type"/"description"/"required" keys — assumed; confirm
                against the MCP tool schema producer)
            server_name: MCP server name

        Returns:
            ToolContext instance (LOW priority — definitions are
            background information, unlike fresh results)
        """
        # Format content as tool documentation
        content_parts = [f"Tool: {name}", "", description]

        if parameters:
            content_parts.append("")
            content_parts.append("Parameters:")
            for param_name, param_info in parameters.items():
                param_type = param_info.get("type", "any")
                param_desc = param_info.get("description", "")
                required = param_info.get("required", False)
                req_marker = " (required)" if required else ""
                content_parts.append(f"  - {param_name}: {param_type}{req_marker}")
                if param_desc:
                    content_parts.append(f"    {param_desc}")

        return cls(
            content="\n".join(content_parts),
            source=f"tool:{server_name}:{name}" if server_name else f"tool:{name}",
            tool_name=name,
            tool_description=description,
            is_result=False,
            parameters=parameters or {},
            server_name=server_name,
            priority=ContextPriority.LOW.value,
        )

    @classmethod
    def from_tool_result(
        cls,
        tool_name: str,
        result: Any,
        status: ToolResultStatus = ToolResultStatus.SUCCESS,
        execution_time_ms: float | None = None,
        parameters: dict[str, Any] | None = None,
        server_name: str | None = None,
    ) -> "ToolContext":
        """
        Create a ToolContext from a tool execution result.

        Args:
            tool_name: Name of the tool that was executed
            result: Result content (will be converted to string)
            status: Execution status
            execution_time_ms: Execution time in milliseconds
            parameters: Parameters that were passed to the tool
            server_name: MCP server name

        Returns:
            ToolContext instance
        """
        # Convert result to string content: strings pass through, dicts
        # are pretty-printed as JSON, everything else uses str().
        if isinstance(result, str):
            content = result
        elif isinstance(result, dict):
            import json

            try:
                content = json.dumps(result, indent=2)
            except (TypeError, ValueError):
                # Dict contains non-JSON-serializable values; fall back
                # to the repr-style string rather than failing.
                content = str(result)
        else:
            content = str(result)

        return cls(
            content=content,
            source=f"tool_result:{server_name}:{tool_name}"
            if server_name
            else f"tool_result:{tool_name}",
            tool_name=tool_name,
            is_result=True,
            result_status=status,
            execution_time_ms=execution_time_ms,
            parameters=parameters or {},
            server_name=server_name,
            priority=ContextPriority.HIGH.value,  # Recent results are high priority
        )

    def is_successful(self) -> bool:
        """Check if this is a successful tool result."""
        return self.is_result and self.result_status == ToolResultStatus.SUCCESS

    def is_error(self) -> bool:
        """Check if this is an error result."""
        return self.is_result and self.result_status == ToolResultStatus.ERROR

    def format_for_prompt(self) -> str:
        """
        Format tool context for inclusion in prompt.

        Results get a "Tool Result (<name>, <status>):" header;
        definitions are returned as-is (already formatted at creation).

        Returns:
            Formatted tool string
        """
        if self.is_result:
            status_str = self.result_status.value if self.result_status else "unknown"
            header = f"Tool Result ({self.tool_name}, {status_str}):"
            return f"{header}\n{self.content}"
        else:
            return self.content
|
||||
@@ -411,7 +411,20 @@ async def shutdown_mcp_client() -> None:
|
||||
_manager_instance = None
|
||||
|
||||
|
||||
def reset_mcp_client() -> None:
|
||||
"""Reset the global MCP client manager (for testing)."""
|
||||
async def reset_mcp_client() -> None:
|
||||
"""
|
||||
Reset the global MCP client manager (for testing).
|
||||
|
||||
This is an async function to properly acquire the manager lock
|
||||
and avoid race conditions with get_mcp_client().
|
||||
"""
|
||||
global _manager_instance
|
||||
_manager_instance = None
|
||||
|
||||
async with _manager_lock:
|
||||
if _manager_instance is not None:
|
||||
# Shutdown gracefully before resetting
|
||||
try:
|
||||
await _manager_instance.shutdown()
|
||||
except Exception: # noqa: S110
|
||||
pass # Ignore errors during test cleanup
|
||||
_manager_instance = None
|
||||
|
||||
@@ -158,9 +158,7 @@ class MCPConfig(BaseModel):
|
||||
def get_enabled_servers(self) -> dict[str, MCPServerConfig]:
|
||||
"""Get all enabled server configurations."""
|
||||
return {
|
||||
name: config
|
||||
for name, config in self.mcp_servers.items()
|
||||
if config.enabled
|
||||
name: config for name, config in self.mcp_servers.items() if config.enabled
|
||||
}
|
||||
|
||||
def list_server_names(self) -> list[str]:
|
||||
|
||||
@@ -161,7 +161,7 @@ class MCPConnection:
|
||||
server_name=self.server_name,
|
||||
url=self.config.url,
|
||||
cause=e,
|
||||
)
|
||||
) from e
|
||||
else:
|
||||
# For STDIO and SSE transports, we'll implement later
|
||||
raise NotImplementedError(
|
||||
@@ -297,13 +297,13 @@ class MCPConnection:
|
||||
server_name=self.server_name,
|
||||
url=f"{self.config.url}{path}",
|
||||
cause=e,
|
||||
)
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise MCPConnectionError(
|
||||
f"Request failed: {e}",
|
||||
server_name=self.server_name,
|
||||
cause=e,
|
||||
)
|
||||
) from e
|
||||
|
||||
|
||||
class ConnectionPool:
|
||||
@@ -322,8 +322,19 @@ class ConnectionPool:
|
||||
"""
|
||||
self._connections: dict[str, MCPConnection] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
self._per_server_locks: dict[str, asyncio.Lock] = {}
|
||||
self._max_per_server = max_connections_per_server
|
||||
|
||||
def _get_server_lock(self, server_name: str) -> asyncio.Lock:
|
||||
"""Get or create a lock for a specific server.
|
||||
|
||||
Uses setdefault for atomic dict access to prevent race conditions
|
||||
where two coroutines could create different locks for the same server.
|
||||
"""
|
||||
# setdefault is atomic - if key exists, returns existing value
|
||||
# if key doesn't exist, inserts new value and returns it
|
||||
return self._per_server_locks.setdefault(server_name, asyncio.Lock())
|
||||
|
||||
async def get_connection(
|
||||
self,
|
||||
server_name: str,
|
||||
@@ -332,6 +343,9 @@ class ConnectionPool:
|
||||
"""
|
||||
Get or create a connection to a server.
|
||||
|
||||
Uses per-server locking to avoid blocking all connections
|
||||
when establishing a new connection.
|
||||
|
||||
Args:
|
||||
server_name: Name of the server
|
||||
config: Server configuration
|
||||
@@ -339,17 +353,33 @@ class ConnectionPool:
|
||||
Returns:
|
||||
Active connection
|
||||
"""
|
||||
async with self._lock:
|
||||
if server_name not in self._connections:
|
||||
connection = MCPConnection(server_name, config)
|
||||
await connection.connect()
|
||||
self._connections[server_name] = connection
|
||||
|
||||
# Quick check without lock - if connection exists and is connected, return it
|
||||
if server_name in self._connections:
|
||||
connection = self._connections[server_name]
|
||||
if connection.is_connected:
|
||||
return connection
|
||||
|
||||
# Reconnect if not connected
|
||||
if not connection.is_connected:
|
||||
# Need to create or reconnect - use per-server lock to avoid blocking others
|
||||
async with self._lock:
|
||||
server_lock = self._get_server_lock(server_name)
|
||||
|
||||
async with server_lock:
|
||||
# Double-check after acquiring per-server lock
|
||||
if server_name in self._connections:
|
||||
connection = self._connections[server_name]
|
||||
if connection.is_connected:
|
||||
return connection
|
||||
# Connection exists but not connected - reconnect
|
||||
await connection.connect()
|
||||
return connection
|
||||
|
||||
# Create new connection (outside global lock, under per-server lock)
|
||||
connection = MCPConnection(server_name, config)
|
||||
await connection.connect()
|
||||
|
||||
# Store connection under global lock
|
||||
async with self._lock:
|
||||
self._connections[server_name] = connection
|
||||
|
||||
return connection
|
||||
|
||||
@@ -374,6 +404,9 @@ class ConnectionPool:
|
||||
if server_name in self._connections:
|
||||
await self._connections[server_name].disconnect()
|
||||
del self._connections[server_name]
|
||||
# Clean up per-server lock
|
||||
if server_name in self._per_server_locks:
|
||||
del self._per_server_locks[server_name]
|
||||
|
||||
async def close_all(self) -> None:
|
||||
"""Close all connections in the pool."""
|
||||
@@ -385,6 +418,7 @@ class ConnectionPool:
|
||||
logger.warning("Error closing connection: %s", e)
|
||||
|
||||
self._connections.clear()
|
||||
self._per_server_locks.clear()
|
||||
logger.info("Closed all MCP connections")
|
||||
|
||||
async def health_check_all(self) -> dict[str, bool]:
|
||||
@@ -394,8 +428,12 @@ class ConnectionPool:
|
||||
Returns:
|
||||
Dict mapping server names to health status
|
||||
"""
|
||||
# Copy connections under lock to prevent modification during iteration
|
||||
async with self._lock:
|
||||
connections_snapshot = dict(self._connections)
|
||||
|
||||
results = {}
|
||||
for name, connection in self._connections.items():
|
||||
for name, connection in connections_snapshot.items():
|
||||
results[name] = await connection.health_check()
|
||||
return results
|
||||
|
||||
|
||||
170
backend/app/services/safety/__init__.py
Normal file
170
backend/app/services/safety/__init__.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
Safety and Guardrails Framework
|
||||
|
||||
Comprehensive safety framework for autonomous agent operation.
|
||||
Provides multi-layered protection including:
|
||||
- Pre-execution validation
|
||||
- Cost and budget controls
|
||||
- Rate limiting
|
||||
- Loop detection and prevention
|
||||
- Human-in-the-loop approval
|
||||
- Rollback and checkpointing
|
||||
- Content filtering
|
||||
- Sandboxed execution
|
||||
- Emergency controls
|
||||
- Complete audit trail
|
||||
|
||||
Usage:
|
||||
from app.services.safety import get_safety_guardian, SafetyGuardian
|
||||
|
||||
guardian = await get_safety_guardian()
|
||||
result = await guardian.validate(action_request)
|
||||
|
||||
if result.allowed:
|
||||
# Execute action
|
||||
pass
|
||||
else:
|
||||
# Handle denial
|
||||
print(f"Action denied: {result.reasons}")
|
||||
"""
|
||||
|
||||
# Exceptions
|
||||
# Audit
|
||||
from .audit import (
|
||||
AuditLogger,
|
||||
get_audit_logger,
|
||||
reset_audit_logger,
|
||||
shutdown_audit_logger,
|
||||
)
|
||||
|
||||
# Configuration
|
||||
from .config import (
|
||||
AutonomyConfig,
|
||||
SafetyConfig,
|
||||
get_autonomy_config,
|
||||
get_default_policy,
|
||||
get_policy_for_autonomy_level,
|
||||
get_safety_config,
|
||||
load_policies_from_directory,
|
||||
load_policy_from_file,
|
||||
reset_config_cache,
|
||||
)
|
||||
from .exceptions import (
|
||||
ApprovalDeniedError,
|
||||
ApprovalRequiredError,
|
||||
ApprovalTimeoutError,
|
||||
BudgetExceededError,
|
||||
CheckpointError,
|
||||
ContentFilterError,
|
||||
EmergencyStopError,
|
||||
LoopDetectedError,
|
||||
PermissionDeniedError,
|
||||
PolicyViolationError,
|
||||
RateLimitExceededError,
|
||||
RollbackError,
|
||||
SafetyError,
|
||||
SandboxError,
|
||||
SandboxTimeoutError,
|
||||
ValidationError,
|
||||
)
|
||||
|
||||
# Guardian
|
||||
from .guardian import (
|
||||
SafetyGuardian,
|
||||
get_safety_guardian,
|
||||
reset_safety_guardian,
|
||||
shutdown_safety_guardian,
|
||||
)
|
||||
|
||||
# Models
|
||||
from .models import (
|
||||
ActionMetadata,
|
||||
ActionRequest,
|
||||
ActionResult,
|
||||
ActionType,
|
||||
ApprovalRequest,
|
||||
ApprovalResponse,
|
||||
ApprovalStatus,
|
||||
AuditEvent,
|
||||
AuditEventType,
|
||||
AutonomyLevel,
|
||||
BudgetScope,
|
||||
BudgetStatus,
|
||||
Checkpoint,
|
||||
CheckpointType,
|
||||
GuardianResult,
|
||||
PermissionLevel,
|
||||
RateLimitConfig,
|
||||
RateLimitStatus,
|
||||
ResourceType,
|
||||
RollbackResult,
|
||||
SafetyDecision,
|
||||
SafetyPolicy,
|
||||
ValidationResult,
|
||||
ValidationRule,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ActionMetadata",
|
||||
"ActionRequest",
|
||||
"ActionResult",
|
||||
# Models
|
||||
"ActionType",
|
||||
"ApprovalDeniedError",
|
||||
"ApprovalRequest",
|
||||
"ApprovalRequiredError",
|
||||
"ApprovalResponse",
|
||||
"ApprovalStatus",
|
||||
"ApprovalTimeoutError",
|
||||
"AuditEvent",
|
||||
"AuditEventType",
|
||||
# Audit
|
||||
"AuditLogger",
|
||||
"AutonomyConfig",
|
||||
"AutonomyLevel",
|
||||
"BudgetExceededError",
|
||||
"BudgetScope",
|
||||
"BudgetStatus",
|
||||
"Checkpoint",
|
||||
"CheckpointError",
|
||||
"CheckpointType",
|
||||
"ContentFilterError",
|
||||
"EmergencyStopError",
|
||||
"GuardianResult",
|
||||
"LoopDetectedError",
|
||||
"PermissionDeniedError",
|
||||
"PermissionLevel",
|
||||
"PolicyViolationError",
|
||||
"RateLimitConfig",
|
||||
"RateLimitExceededError",
|
||||
"RateLimitStatus",
|
||||
"ResourceType",
|
||||
"RollbackError",
|
||||
"RollbackResult",
|
||||
# Configuration
|
||||
"SafetyConfig",
|
||||
"SafetyDecision",
|
||||
# Exceptions
|
||||
"SafetyError",
|
||||
# Guardian
|
||||
"SafetyGuardian",
|
||||
"SafetyPolicy",
|
||||
"SandboxError",
|
||||
"SandboxTimeoutError",
|
||||
"ValidationError",
|
||||
"ValidationResult",
|
||||
"ValidationRule",
|
||||
"get_audit_logger",
|
||||
"get_autonomy_config",
|
||||
"get_default_policy",
|
||||
"get_policy_for_autonomy_level",
|
||||
"get_safety_config",
|
||||
"get_safety_guardian",
|
||||
"load_policies_from_directory",
|
||||
"load_policy_from_file",
|
||||
"reset_audit_logger",
|
||||
"reset_config_cache",
|
||||
"reset_safety_guardian",
|
||||
"shutdown_audit_logger",
|
||||
"shutdown_safety_guardian",
|
||||
]
|
||||
19
backend/app/services/safety/audit/__init__.py
Normal file
19
backend/app/services/safety/audit/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
Audit System
|
||||
|
||||
Comprehensive audit logging for all safety-related events.
|
||||
"""
|
||||
|
||||
from .logger import (
|
||||
AuditLogger,
|
||||
get_audit_logger,
|
||||
reset_audit_logger,
|
||||
shutdown_audit_logger,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AuditLogger",
|
||||
"get_audit_logger",
|
||||
"reset_audit_logger",
|
||||
"shutdown_audit_logger",
|
||||
]
|
||||
601
backend/app/services/safety/audit/logger.py
Normal file
601
backend/app/services/safety/audit/logger.py
Normal file
@@ -0,0 +1,601 @@
|
||||
"""
|
||||
Audit Logger
|
||||
|
||||
Comprehensive audit logging for all safety-related events.
|
||||
Provides tamper detection, structured logging, and compliance support.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from collections import deque
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from ..config import get_safety_config
|
||||
from ..models import (
|
||||
ActionRequest,
|
||||
AuditEvent,
|
||||
AuditEventType,
|
||||
SafetyDecision,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Sentinel for distinguishing "no argument passed" from "explicitly passing None"
|
||||
_UNSET = object()
|
||||
|
||||
|
||||
class AuditLogger:
|
||||
"""
|
||||
Audit logger for safety events.
|
||||
|
||||
Features:
|
||||
- Structured event logging
|
||||
- In-memory buffer with async flush
|
||||
- Tamper detection via hash chains
|
||||
- Query/search capability
|
||||
- Retention policy enforcement
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    max_buffer_size: int = 1000,
    flush_interval_seconds: float = 10.0,
    enable_hash_chain: bool = True,
) -> None:
    """
    Initialize the audit logger.

    Args:
        max_buffer_size: Maximum events to buffer before auto-flush
        flush_interval_seconds: Interval for periodic flush
        enable_hash_chain: Enable tamper detection via hash chain
    """
    # Bounded deque: once full, appending evicts the oldest event
    # rather than raising (standard deque(maxlen=...) behavior).
    self._buffer: deque[AuditEvent] = deque(maxlen=max_buffer_size)
    # In-memory stand-in for durable storage; flush() moves events here.
    self._persisted: list[AuditEvent] = []
    self._flush_interval = flush_interval_seconds
    self._enable_hash_chain = enable_hash_chain
    # Hash of the most recently chained event; None seeds the chain.
    self._last_hash: str | None = None
    self._lock = asyncio.Lock()
    self._flush_task: asyncio.Task[None] | None = None
    self._running = False

    # Event handlers for real-time processing
    self._handlers: list[Any] = []

    # Retention and sensitivity settings come from global safety config.
    config = get_safety_config()
    self._retention_days = config.audit_retention_days
    self._include_sensitive = config.audit_include_sensitive
|
||||
|
||||
async def start(self) -> None:
    """Launch the periodic background flush; a no-op when already running."""
    if not self._running:
        self._running = True
        # Keep a handle to the task so stop() can cancel and await it.
        self._flush_task = asyncio.create_task(self._periodic_flush())
        logger.info("Audit logger started")
|
||||
|
||||
async def stop(self) -> None:
    """Shut the logger down: cancel the flush task, then drain the buffer."""
    self._running = False

    task = self._flush_task
    if task is not None:
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass  # expected — we cancelled it ourselves

    # Drain whatever is still buffered before reporting shutdown.
    await self.flush()
    logger.info("Audit logger stopped")
|
||||
|
||||
async def log(
    self,
    event_type: AuditEventType,
    *,
    agent_id: str | None = None,
    action_id: str | None = None,
    project_id: str | None = None,
    session_id: str | None = None,
    user_id: str | None = None,
    decision: SafetyDecision | None = None,
    details: dict[str, Any] | None = None,
    correlation_id: str | None = None,
) -> AuditEvent:
    """
    Log an audit event.

    Args:
        event_type: Type of audit event
        agent_id: Agent ID if applicable
        action_id: Action ID if applicable
        project_id: Project ID if applicable
        session_id: Session ID if applicable
        user_id: User ID if applicable
        decision: Safety decision if applicable
        details: Additional event details
        correlation_id: Correlation ID for tracing

    Returns:
        The created audit event
    """
    # Sanitize sensitive data if needed
    # (presumably redacts fields per audit_include_sensitive — confirm
    # in _sanitize_details, which is defined outside this view)
    sanitized_details = self._sanitize_details(details) if details else {}

    # NOTE(review): datetime.utcnow() yields a *naive* timestamp; other
    # modules here use timezone-aware datetime.now(UTC) — confirm intent.
    event = AuditEvent(
        id=str(uuid4()),
        event_type=event_type,
        timestamp=datetime.utcnow(),
        agent_id=agent_id,
        action_id=action_id,
        project_id=project_id,
        session_id=session_id,
        user_id=user_id,
        decision=decision,
        details=sanitized_details,
        correlation_id=correlation_id,
    )

    # Chain-update and buffer-append happen under the lock so the
    # _last_hash -> event ordering stays consistent across coroutines.
    async with self._lock:
        # Add hash chain for tamper detection
        if self._enable_hash_chain:
            event_hash = self._compute_hash(event)
            # Modify event.details directly (not sanitized_details)
            # to ensure the hash is stored on the actual event
            event.details["_hash"] = event_hash
            event.details["_prev_hash"] = self._last_hash
            self._last_hash = event_hash

        self._buffer.append(event)

    # Notify handlers
    await self._notify_handlers(event)

    # Log to standard logger as well
    self._log_to_logger(event)

    return event
|
||||
|
||||
async def log_action_request(
    self,
    action: ActionRequest,
    decision: SafetyDecision,
    reasons: list[str] | None = None,
) -> AuditEvent:
    """Record an action request together with the safety decision made for it."""
    # Denials get their own event type; everything else counts as validated.
    if decision == SafetyDecision.DENY:
        kind = AuditEventType.ACTION_DENIED
    else:
        kind = AuditEventType.ACTION_VALIDATED

    meta = action.metadata
    payload = {
        "action_type": action.action_type.value,
        "tool_name": action.tool_name,
        "resource": action.resource,
        "is_destructive": action.is_destructive,
        "reasons": reasons or [],
    }

    return await self.log(
        kind,
        agent_id=meta.agent_id,
        action_id=action.id,
        project_id=meta.project_id,
        session_id=meta.session_id,
        user_id=meta.user_id,
        decision=decision,
        details=payload,
        correlation_id=meta.correlation_id,
    )
|
||||
|
||||
async def log_action_executed(
    self,
    action: ActionRequest,
    success: bool,
    execution_time_ms: float,
    error: str | None = None,
) -> AuditEvent:
    """Record the outcome of an executed action (success or failure)."""
    # Map the boolean outcome onto event type and decision in one place.
    if success:
        kind = AuditEventType.ACTION_EXECUTED
        verdict = SafetyDecision.ALLOW
    else:
        kind = AuditEventType.ACTION_FAILED
        verdict = SafetyDecision.DENY

    meta = action.metadata
    payload = {
        "action_type": action.action_type.value,
        "tool_name": action.tool_name,
        "success": success,
        "execution_time_ms": execution_time_ms,
        "error": error,
    }

    return await self.log(
        kind,
        agent_id=meta.agent_id,
        action_id=action.id,
        project_id=meta.project_id,
        session_id=meta.session_id,
        decision=verdict,
        details=payload,
        correlation_id=meta.correlation_id,
    )
|
||||
|
||||
async def log_approval_event(
    self,
    event_type: AuditEventType,
    approval_id: str,
    action: ActionRequest,
    decided_by: str | None = None,
    reason: str | None = None,
) -> AuditEvent:
    """Record an approval-related event (requested/granted/denied/etc.)."""
    meta = action.metadata
    payload = {
        "approval_id": approval_id,
        "action_type": action.action_type.value,
        "tool_name": action.tool_name,
        "decided_by": decided_by,
        "reason": reason,
    }

    # The deciding user (if any) doubles as the event's user_id.
    return await self.log(
        event_type,
        agent_id=meta.agent_id,
        action_id=action.id,
        project_id=meta.project_id,
        session_id=meta.session_id,
        user_id=decided_by,
        details=payload,
        correlation_id=meta.correlation_id,
    )
|
||||
|
||||
async def log_budget_event(
    self,
    event_type: AuditEventType,
    agent_id: str,
    scope: str,
    current_usage: float,
    limit: float,
    unit: str = "tokens",
) -> AuditEvent:
    """Record a budget-related event, including percentage utilisation."""
    # Guard against a zero/negative limit to avoid division by zero.
    if limit > 0:
        pct = current_usage / limit * 100
    else:
        pct = 0

    return await self.log(
        event_type,
        agent_id=agent_id,
        details={
            "scope": scope,
            "current_usage": current_usage,
            "limit": limit,
            "unit": unit,
            "usage_percent": pct,
        },
    )
|
||||
|
||||
async def log_emergency_stop(
    self,
    stop_type: str,
    triggered_by: str,
    reason: str,
    affected_agents: list[str] | None = None,
) -> AuditEvent:
    """Record an emergency-stop event in the audit trail."""
    payload = {
        "stop_type": stop_type,
        "triggered_by": triggered_by,
        "reason": reason,
        "affected_agents": affected_agents or [],
    }

    # The triggering user is also recorded as the event's user_id.
    return await self.log(
        AuditEventType.EMERGENCY_STOP,
        user_id=triggered_by,
        details=payload,
    )
|
||||
|
||||
async def flush(self) -> int:
|
||||
"""
|
||||
Flush buffered events to persistent storage.
|
||||
|
||||
Returns:
|
||||
Number of events flushed
|
||||
"""
|
||||
async with self._lock:
|
||||
if not self._buffer:
|
||||
return 0
|
||||
|
||||
events = list(self._buffer)
|
||||
self._buffer.clear()
|
||||
|
||||
# Persist events (in production, this would go to database/storage)
|
||||
self._persisted.extend(events)
|
||||
|
||||
# Enforce retention
|
||||
self._enforce_retention()
|
||||
|
||||
logger.debug("Flushed %d audit events", len(events))
|
||||
return len(events)
|
||||
|
||||
async def query(
|
||||
self,
|
||||
*,
|
||||
event_types: list[AuditEventType] | None = None,
|
||||
agent_id: str | None = None,
|
||||
action_id: str | None = None,
|
||||
project_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
correlation_id: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[AuditEvent]:
|
||||
"""
|
||||
Query audit events with filters.
|
||||
|
||||
Args:
|
||||
event_types: Filter by event types
|
||||
agent_id: Filter by agent ID
|
||||
action_id: Filter by action ID
|
||||
project_id: Filter by project ID
|
||||
session_id: Filter by session ID
|
||||
user_id: Filter by user ID
|
||||
start_time: Filter events after this time
|
||||
end_time: Filter events before this time
|
||||
correlation_id: Filter by correlation ID
|
||||
limit: Maximum results to return
|
||||
offset: Result offset for pagination
|
||||
|
||||
Returns:
|
||||
List of matching audit events
|
||||
"""
|
||||
# Combine buffer and persisted for query
|
||||
all_events = list(self._persisted) + list(self._buffer)
|
||||
|
||||
results = []
|
||||
for event in all_events:
|
||||
if event_types and event.event_type not in event_types:
|
||||
continue
|
||||
if agent_id and event.agent_id != agent_id:
|
||||
continue
|
||||
if action_id and event.action_id != action_id:
|
||||
continue
|
||||
if project_id and event.project_id != project_id:
|
||||
continue
|
||||
if session_id and event.session_id != session_id:
|
||||
continue
|
||||
if user_id and event.user_id != user_id:
|
||||
continue
|
||||
if start_time and event.timestamp < start_time:
|
||||
continue
|
||||
if end_time and event.timestamp > end_time:
|
||||
continue
|
||||
if correlation_id and event.correlation_id != correlation_id:
|
||||
continue
|
||||
|
||||
results.append(event)
|
||||
|
||||
# Sort by timestamp descending
|
||||
results.sort(key=lambda e: e.timestamp, reverse=True)
|
||||
|
||||
# Apply pagination
|
||||
return results[offset : offset + limit]
|
||||
|
||||
async def get_action_history(
|
||||
self,
|
||||
agent_id: str,
|
||||
limit: int = 100,
|
||||
) -> list[AuditEvent]:
|
||||
"""Get action history for an agent."""
|
||||
return await self.query(
|
||||
agent_id=agent_id,
|
||||
event_types=[
|
||||
AuditEventType.ACTION_REQUESTED,
|
||||
AuditEventType.ACTION_VALIDATED,
|
||||
AuditEventType.ACTION_DENIED,
|
||||
AuditEventType.ACTION_EXECUTED,
|
||||
AuditEventType.ACTION_FAILED,
|
||||
],
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
    async def verify_integrity(self) -> tuple[bool, list[str]]:
        """
        Verify audit log integrity using hash chain.

        Walks all events (persisted + buffered) in timestamp order, checking
        that each event's stored "_prev_hash" links to the previous event's
        stored "_hash" and that the stored hash matches a recomputed one.

        Returns:
            Tuple of (is_valid, list of issues found)
        """
        # Hash chaining disabled: nothing to verify, report valid.
        if not self._enable_hash_chain:
            return True, []

        issues: list[str] = []
        all_events = list(self._persisted) + list(self._buffer)

        prev_hash: str | None = None
        # Chain order is defined by timestamp, matching the order events
        # were hashed when logged.
        for event in sorted(all_events, key=lambda e: e.timestamp):
            stored_prev = event.details.get("_prev_hash")
            stored_hash = event.details.get("_hash")

            if stored_prev != prev_hash:
                issues.append(
                    f"Hash chain broken at event {event.id}: "
                    f"expected prev_hash={prev_hash}, got {stored_prev}"
                )

            if stored_hash:
                # Pass prev_hash to compute hash with correct chain position
                computed = self._compute_hash(event, prev_hash=prev_hash)
                if computed != stored_hash:
                    issues.append(
                        f"Hash mismatch at event {event.id}: "
                        f"expected {computed}, got {stored_hash}"
                    )

            # NOTE: an event without a stored hash resets prev_hash to None,
            # so the next event is expected to start a new chain segment.
            prev_hash = stored_hash

        return len(issues) == 0, issues
|
||||
|
||||
    def add_handler(self, handler: Any) -> None:
        """Add a real-time event handler.

        Handlers may be sync or async callables taking the event (they are
        dispatched via _notify_handlers). Duplicates are not checked.
        """
        self._handlers.append(handler)
|
||||
|
||||
def remove_handler(self, handler: Any) -> None:
|
||||
"""Remove an event handler."""
|
||||
if handler in self._handlers:
|
||||
self._handlers.remove(handler)
|
||||
|
||||
def _sanitize_details(self, details: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Sanitize sensitive data from details."""
|
||||
if self._include_sensitive:
|
||||
return details
|
||||
|
||||
sanitized: dict[str, Any] = {}
|
||||
sensitive_keys = {
|
||||
"password",
|
||||
"secret",
|
||||
"token",
|
||||
"api_key",
|
||||
"apikey",
|
||||
"auth",
|
||||
"credential",
|
||||
}
|
||||
|
||||
for key, value in details.items():
|
||||
lower_key = key.lower()
|
||||
if any(s in lower_key for s in sensitive_keys):
|
||||
sanitized[key] = "[REDACTED]"
|
||||
elif isinstance(value, dict):
|
||||
sanitized[key] = self._sanitize_details(value)
|
||||
else:
|
||||
sanitized[key] = value
|
||||
|
||||
return sanitized
|
||||
|
||||
    def _compute_hash(
        self, event: AuditEvent, prev_hash: str | None | object = _UNSET
    ) -> str:
        """Compute hash for an event (excluding hash fields).

        Args:
            event: The audit event to hash.
            prev_hash: Optional previous hash to use instead of self._last_hash.
                Pass this during verification to use the correct chain.
                Use None explicitly to indicate no previous hash.

        Returns:
            Hex-encoded SHA-256 digest of the event's canonical JSON form.
        """
        # Use passed prev_hash if explicitly provided, otherwise use instance state
        effective_prev: str | None = (
            self._last_hash if prev_hash is _UNSET else prev_hash  # type: ignore[assignment]
        )

        data: dict[str, str | dict[str, str] | None] = {
            "id": event.id,
            "event_type": event.event_type.value,
            "timestamp": event.timestamp.isoformat(),
            "agent_id": event.agent_id,
            "action_id": event.action_id,
            "project_id": event.project_id,
            "session_id": event.session_id,
            "user_id": event.user_id,
            "decision": event.decision.value if event.decision else None,
            # Underscore-prefixed detail keys (the stored hash fields) are
            # excluded so the digest never depends on itself.
            "details": {
                k: v for k, v in event.details.items() if not k.startswith("_")
            },
            "correlation_id": event.correlation_id,
        }

        if effective_prev:
            data["_prev_hash"] = effective_prev

        # sort_keys makes the serialization canonical across runs;
        # default=str covers non-JSON-native values inside details.
        serialized = json.dumps(data, sort_keys=True, default=str)
        return hashlib.sha256(serialized.encode()).hexdigest()
|
||||
|
||||
def _log_to_logger(self, event: AuditEvent) -> None:
|
||||
"""Log event to standard Python logger."""
|
||||
log_data = {
|
||||
"audit_event": event.event_type.value,
|
||||
"event_id": event.id,
|
||||
"agent_id": event.agent_id,
|
||||
"action_id": event.action_id,
|
||||
"decision": event.decision.value if event.decision else None,
|
||||
}
|
||||
|
||||
# Use appropriate log level based on event type
|
||||
if event.event_type in {
|
||||
AuditEventType.ACTION_DENIED,
|
||||
AuditEventType.POLICY_VIOLATION,
|
||||
AuditEventType.EMERGENCY_STOP,
|
||||
}:
|
||||
logger.warning("Audit: %s", log_data)
|
||||
elif event.event_type in {
|
||||
AuditEventType.ACTION_FAILED,
|
||||
AuditEventType.ROLLBACK_FAILED,
|
||||
}:
|
||||
logger.error("Audit: %s", log_data)
|
||||
else:
|
||||
logger.info("Audit: %s", log_data)
|
||||
|
||||
    def _enforce_retention(self) -> None:
        """Enforce retention policy on persisted events.

        Drops persisted events older than self._retention_days; a falsy
        retention setting disables pruning entirely.
        """
        if not self._retention_days:
            return

        # NOTE(review): utcnow() is naive and deprecated since 3.12; this
        # assumes event timestamps are also naive UTC -- confirm before
        # migrating to timezone-aware datetimes.
        cutoff = datetime.utcnow() - timedelta(days=self._retention_days)
        before_count = len(self._persisted)

        self._persisted = [e for e in self._persisted if e.timestamp >= cutoff]

        removed = before_count - len(self._persisted)
        if removed > 0:
            logger.info("Removed %d expired audit events", removed)
|
||||
|
||||
    async def _periodic_flush(self) -> None:
        """Background task for periodic flushing.

        Sleeps self._flush_interval seconds between flushes and loops while
        self._running is true (presumably cleared on shutdown -- confirm
        against the start/stop methods).
        """
        while self._running:
            try:
                await asyncio.sleep(self._flush_interval)
                await self.flush()
            except asyncio.CancelledError:
                # Task cancellation (e.g. at shutdown): exit cleanly.
                break
            except Exception as e:
                # Keep the loop alive -- one failed flush must not kill the task.
                logger.error("Error in periodic audit flush: %s", e)
|
||||
|
||||
async def _notify_handlers(self, event: AuditEvent) -> None:
|
||||
"""Notify all registered handlers of a new event."""
|
||||
for handler in self._handlers:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(handler):
|
||||
await handler(event)
|
||||
else:
|
||||
handler(event)
|
||||
except Exception as e:
|
||||
logger.error("Error in audit event handler: %s", e)
|
||||
|
||||
|
||||
# Singleton instance
_audit_logger: AuditLogger | None = None
# Guards lazy creation/teardown of the singleton across concurrent tasks.
_audit_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def get_audit_logger() -> AuditLogger:
    """Get the global audit logger instance.

    Lazily creates and starts the singleton on first use; the module-level
    lock ensures concurrent first calls create only one instance.
    """
    global _audit_logger

    async with _audit_lock:
        if _audit_logger is None:
            _audit_logger = AuditLogger()
            await _audit_logger.start()

    return _audit_logger
|
||||
|
||||
|
||||
async def shutdown_audit_logger() -> None:
    """Shutdown the global audit logger.

    Stops the singleton (if one exists) and clears the global so the next
    get_audit_logger() call creates a fresh instance.
    """
    global _audit_logger

    async with _audit_lock:
        if _audit_logger is not None:
            await _audit_logger.stop()
            _audit_logger = None
|
||||
|
||||
|
||||
def reset_audit_logger() -> None:
    """Reset the audit logger (for testing).

    NOTE(review): unlike shutdown_audit_logger() this does not stop the
    existing instance, so any running periodic-flush task is left behind;
    intended only for test isolation.
    """
    global _audit_logger
    _audit_logger = None
|
||||
304
backend/app/services/safety/config.py
Normal file
304
backend/app/services/safety/config.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""
|
||||
Safety Framework Configuration
|
||||
|
||||
Pydantic settings for the safety and guardrails framework.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from .models import AutonomyLevel, SafetyPolicy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SafetyConfig(BaseSettings):
    """Configuration for the safety framework.

    Values are read from SAFETY_-prefixed environment variables and an
    optional .env file; unknown keys are ignored.
    """

    model_config = SettingsConfigDict(
        env_prefix="SAFETY_",
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",
    )

    # General settings
    enabled: bool = Field(True, description="Enable safety framework")
    strict_mode: bool = Field(True, description="Strict mode (fail closed on errors)")
    log_level: str = Field("INFO", description="Logging level")

    # Default autonomy level
    default_autonomy_level: AutonomyLevel = Field(
        AutonomyLevel.MILESTONE,
        description="Default autonomy level for new agents",
    )

    # Default budget limits
    default_session_token_budget: int = Field(
        100_000, description="Default tokens per session"
    )
    default_daily_token_budget: int = Field(
        1_000_000, description="Default tokens per day"
    )
    default_session_cost_limit: float = Field(
        10.0, description="Default USD per session"
    )
    default_daily_cost_limit: float = Field(100.0, description="Default USD per day")

    # Default rate limits
    default_actions_per_minute: int = Field(60, description="Default actions per min")
    default_llm_calls_per_minute: int = Field(20, description="Default LLM calls/min")
    default_file_ops_per_minute: int = Field(100, description="Default file ops/min")

    # Loop detection
    loop_detection_enabled: bool = Field(True, description="Enable loop detection")
    max_repeated_actions: int = Field(5, description="Max exact repetitions")
    max_similar_actions: int = Field(10, description="Max similar actions")
    loop_history_size: int = Field(100, description="Action history size for loops")

    # HITL (human-in-the-loop) settings
    hitl_enabled: bool = Field(True, description="Enable human-in-the-loop")
    hitl_default_timeout: int = Field(300, description="Default approval timeout (s)")
    hitl_notification_channels: list[str] = Field(
        default_factory=list, description="Notification channels"
    )

    # Rollback settings
    rollback_enabled: bool = Field(True, description="Enable rollback capability")
    checkpoint_dir: str = Field(
        "/tmp/syndarix_checkpoints",  # noqa: S108
        description="Directory for checkpoint storage",
    )
    checkpoint_retention_hours: int = Field(24, description="Checkpoint retention")
    auto_checkpoint_destructive: bool = Field(
        True, description="Auto-checkpoint destructive actions"
    )

    # Sandbox settings
    sandbox_enabled: bool = Field(False, description="Enable sandbox execution")
    sandbox_timeout: int = Field(300, description="Sandbox timeout (s)")
    sandbox_memory_mb: int = Field(1024, description="Sandbox memory limit (MB)")
    sandbox_cpu_limit: float = Field(1.0, description="Sandbox CPU limit")
    sandbox_network_enabled: bool = Field(False, description="Allow sandbox network")

    # Audit settings
    audit_enabled: bool = Field(True, description="Enable audit logging")
    audit_retention_days: int = Field(90, description="Audit log retention (days)")
    audit_include_sensitive: bool = Field(
        False, description="Include sensitive data in audit"
    )

    # Content filtering
    content_filter_enabled: bool = Field(True, description="Enable content filtering")
    filter_pii: bool = Field(True, description="Filter PII")
    filter_secrets: bool = Field(True, description="Filter secrets")

    # Emergency controls
    emergency_stop_enabled: bool = Field(True, description="Enable emergency stop")
    emergency_webhook_url: str | None = Field(None, description="Emergency webhook")

    # Policy file path
    policy_file: str | None = Field(None, description="Path to policy YAML file")

    # Validation cache
    validation_cache_ttl: int = Field(60, description="Validation cache TTL (s)")
    validation_cache_size: int = Field(1000, description="Validation cache size")
|
||||
|
||||
|
||||
class AutonomyConfig(BaseSettings):
    """Configuration for autonomy levels.

    Values are read from AUTONOMY_-prefixed environment variables and an
    optional .env file; unknown keys are ignored.
    """

    model_config = SettingsConfigDict(
        env_prefix="AUTONOMY_",
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",
    )

    # FULL_CONTROL settings (tightest level)
    full_control_cost_limit: float = Field(1.0, description="USD limit per session")
    full_control_require_all_approval: bool = Field(
        True, description="Require approval for all"
    )
    full_control_block_destructive: bool = Field(
        True, description="Block destructive actions"
    )

    # MILESTONE settings (default level)
    milestone_cost_limit: float = Field(10.0, description="USD limit per session")
    milestone_require_critical_approval: bool = Field(
        True, description="Require approval for critical"
    )
    milestone_auto_checkpoint: bool = Field(
        True, description="Auto-checkpoint destructive"
    )

    # AUTONOMOUS settings (most permissive level)
    autonomous_cost_limit: float = Field(100.0, description="USD limit per session")
    autonomous_auto_approve_normal: bool = Field(
        True, description="Auto-approve normal actions"
    )
    autonomous_auto_checkpoint: bool = Field(True, description="Auto-checkpoint all")
|
||||
|
||||
|
||||
def _expand_env_vars(value: Any) -> Any:
    """Recursively expand ``$VAR``-style environment references in *value*.

    Dicts and lists are walked recursively; strings go through
    os.path.expandvars; everything else is returned untouched.
    """
    if isinstance(value, dict):
        return {key: _expand_env_vars(item) for key, item in value.items()}
    if isinstance(value, list):
        return [_expand_env_vars(item) for item in value]
    if isinstance(value, str):
        return os.path.expandvars(value)
    return value
|
||||
|
||||
|
||||
def load_policy_from_file(file_path: str | Path) -> SafetyPolicy | None:
    """Load a safety policy from a YAML file.

    Returns None (with a log message) when the file is missing, empty, or
    cannot be parsed/validated; never raises.
    """
    policy_path = Path(file_path)
    if not policy_path.exists():
        logger.warning("Policy file not found: %s", policy_path)
        return None

    try:
        with open(policy_path) as stream:
            raw = yaml.safe_load(stream)

        if raw is None:
            logger.warning("Empty policy file: %s", policy_path)
            return None

        # Resolve $VAR references before validation.
        expanded = _expand_env_vars(raw)
        return SafetyPolicy(**expanded)

    except Exception as e:
        logger.error("Failed to load policy file %s: %s", policy_path, e)
        return None
|
||||
|
||||
|
||||
def load_policies_from_directory(directory: str | Path) -> dict[str, SafetyPolicy]:
    """Load all safety policies from a directory.

    Scans for both ``*.yaml`` and ``*.yml`` files (previously only ``.yaml``
    was picked up) and loads each via load_policy_from_file. Files are
    processed in sorted order so name collisions resolve deterministically
    (the last file loaded wins).

    Args:
        directory: Directory to scan; a warning is logged and an empty dict
            returned when it does not exist or is not a directory.

    Returns:
        Mapping of policy name to loaded SafetyPolicy.
    """
    policies: dict[str, SafetyPolicy] = {}
    path = Path(directory)

    if not path.exists() or not path.is_dir():
        logger.warning("Policy directory not found: %s", path)
        return policies

    candidates = sorted(path.glob("*.yaml")) + sorted(path.glob("*.yml"))
    for file_path in candidates:
        policy = load_policy_from_file(file_path)
        if policy:
            policies[policy.name] = policy
            logger.info("Loaded policy: %s from %s", policy.name, file_path.name)

    return policies
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def get_safety_config() -> SafetyConfig:
    """Get the safety configuration (cached singleton).

    The first call constructs SafetyConfig (reading SAFETY_* env vars /
    .env); later calls return the cached instance. Use reset_config_cache()
    to force a re-read.
    """
    return SafetyConfig()
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def get_autonomy_config() -> AutonomyConfig:
    """Get the autonomy configuration (cached singleton).

    Constructed once from AUTONOMY_* env vars / .env; cleared via
    reset_config_cache().
    """
    return AutonomyConfig()
|
||||
|
||||
|
||||
def get_default_policy() -> SafetyPolicy:
    """Build the default safety policy from the cached SafetyConfig values."""
    cfg = get_safety_config()

    # Gather the policy limits from configuration, then construct in one go.
    limits = {
        "max_tokens_per_session": cfg.default_session_token_budget,
        "max_tokens_per_day": cfg.default_daily_token_budget,
        "max_cost_per_session_usd": cfg.default_session_cost_limit,
        "max_cost_per_day_usd": cfg.default_daily_cost_limit,
        "max_actions_per_minute": cfg.default_actions_per_minute,
        "max_llm_calls_per_minute": cfg.default_llm_calls_per_minute,
        "max_file_operations_per_minute": cfg.default_file_ops_per_minute,
        "max_repeated_actions": cfg.max_repeated_actions,
        "max_similar_actions": cfg.max_similar_actions,
        "require_sandbox": cfg.sandbox_enabled,
        "sandbox_timeout_seconds": cfg.sandbox_timeout,
        "sandbox_memory_mb": cfg.sandbox_memory_mb,
    }
    return SafetyPolicy(
        name="default",
        description="Default safety policy",
        **limits,
    )
|
||||
|
||||
|
||||
def get_policy_for_autonomy_level(level: AutonomyLevel) -> SafetyPolicy:
|
||||
"""Get the safety policy for a given autonomy level."""
|
||||
autonomy = get_autonomy_config()
|
||||
|
||||
base_policy = get_default_policy()
|
||||
|
||||
if level == AutonomyLevel.FULL_CONTROL:
|
||||
return SafetyPolicy(
|
||||
name="full_control",
|
||||
description="Full control mode - all actions require approval",
|
||||
max_cost_per_session_usd=autonomy.full_control_cost_limit,
|
||||
max_cost_per_day_usd=autonomy.full_control_cost_limit * 10,
|
||||
require_approval_for=["*"], # All actions
|
||||
max_tokens_per_session=base_policy.max_tokens_per_session // 10,
|
||||
max_tokens_per_day=base_policy.max_tokens_per_day // 10,
|
||||
max_actions_per_minute=base_policy.max_actions_per_minute // 2,
|
||||
max_llm_calls_per_minute=base_policy.max_llm_calls_per_minute // 2,
|
||||
max_file_operations_per_minute=base_policy.max_file_operations_per_minute
|
||||
// 2,
|
||||
denied_tools=["delete_*", "destroy_*", "drop_*"],
|
||||
)
|
||||
|
||||
elif level == AutonomyLevel.MILESTONE:
|
||||
return SafetyPolicy(
|
||||
name="milestone",
|
||||
description="Milestone mode - approval at milestones only",
|
||||
max_cost_per_session_usd=autonomy.milestone_cost_limit,
|
||||
max_cost_per_day_usd=autonomy.milestone_cost_limit * 10,
|
||||
require_approval_for=[
|
||||
"delete_file",
|
||||
"push_to_remote",
|
||||
"deploy_*",
|
||||
"modify_critical_*",
|
||||
"create_pull_request",
|
||||
],
|
||||
max_tokens_per_session=base_policy.max_tokens_per_session,
|
||||
max_tokens_per_day=base_policy.max_tokens_per_day,
|
||||
max_actions_per_minute=base_policy.max_actions_per_minute,
|
||||
max_llm_calls_per_minute=base_policy.max_llm_calls_per_minute,
|
||||
max_file_operations_per_minute=base_policy.max_file_operations_per_minute,
|
||||
)
|
||||
|
||||
else: # AUTONOMOUS
|
||||
return SafetyPolicy(
|
||||
name="autonomous",
|
||||
description="Autonomous mode - minimal intervention",
|
||||
max_cost_per_session_usd=autonomy.autonomous_cost_limit,
|
||||
max_cost_per_day_usd=autonomy.autonomous_cost_limit * 10,
|
||||
require_approval_for=[
|
||||
"deploy_to_production",
|
||||
"delete_repository",
|
||||
"modify_production_config",
|
||||
],
|
||||
max_tokens_per_session=base_policy.max_tokens_per_session * 5,
|
||||
max_tokens_per_day=base_policy.max_tokens_per_day * 5,
|
||||
max_actions_per_minute=base_policy.max_actions_per_minute * 2,
|
||||
max_llm_calls_per_minute=base_policy.max_llm_calls_per_minute * 2,
|
||||
max_file_operations_per_minute=base_policy.max_file_operations_per_minute
|
||||
* 2,
|
||||
)
|
||||
|
||||
|
||||
def reset_config_cache() -> None:
    """Reset configuration caches (for testing).

    Clears both lru_cache singletons so the next accessor call re-reads
    the environment.
    """
    get_safety_config.cache_clear()
    get_autonomy_config.cache_clear()
|
||||
23
backend/app/services/safety/content/__init__.py
Normal file
23
backend/app/services/safety/content/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Content filtering for safety."""
|
||||
|
||||
from .filter import (
|
||||
ContentCategory,
|
||||
ContentFilter,
|
||||
FilterAction,
|
||||
FilterMatch,
|
||||
FilterPattern,
|
||||
FilterResult,
|
||||
filter_content,
|
||||
scan_for_secrets,
|
||||
)
|
||||
|
||||
# Explicit public API of the content-filtering package.
__all__ = [
    "ContentCategory",
    "ContentFilter",
    "FilterAction",
    "FilterMatch",
    "FilterPattern",
    "FilterResult",
    "filter_content",
    "scan_for_secrets",
]
|
||||
550
backend/app/services/safety/content/filter.py
Normal file
550
backend/app/services/safety/content/filter.py
Normal file
@@ -0,0 +1,550 @@
|
||||
"""
|
||||
Content Filter
|
||||
|
||||
Filters and sanitizes content for safety, including PII detection and secret scanning.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field, replace
|
||||
from enum import Enum
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from ..exceptions import ContentFilterError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContentCategory(str, Enum):
    """Categories of sensitive content a filter pattern can target."""

    PII = "pii"  # personally identifiable info (emails, phones, SSNs, IPs)
    SECRETS = "secrets"  # API keys, tokens, private keys
    CREDENTIALS = "credentials"  # passwords, credentials embedded in URLs
    FINANCIAL = "financial"  # e.g. credit card numbers
    HEALTH = "health"
    PROFANITY = "profanity"
    INJECTION = "injection"  # SQL/command injection payloads
    CUSTOM = "custom"  # user-supplied patterns
|
||||
|
||||
|
||||
class FilterAction(str, Enum):
    """Actions to take on detected content."""

    ALLOW = "allow"  # take no action on the match
    REDACT = "redact"  # replace the match with the pattern's replacement text
    BLOCK = "block"  # mark the whole content as blocked (and redact)
    WARN = "warn"  # keep the content but record a warning
|
||||
|
||||
|
||||
@dataclass
class FilterMatch:
    """A single occurrence of a filter pattern within scanned content."""

    category: ContentCategory  # category of the pattern that matched
    pattern_name: str  # FilterPattern.name that produced this match
    matched_text: str  # the exact substring that matched
    start_pos: int  # match start offset in the original content
    end_pos: int  # match end offset (exclusive, as from re.Match.end())
    confidence: float = 1.0  # pattern confidence; lower = more false-positive prone
    redacted_text: str | None = None  # replacement text applied when redacting
|
||||
|
||||
|
||||
@dataclass
class FilterResult:
    """Outcome of running the content filter over a piece of text."""

    original_content: str  # the text as submitted
    filtered_content: str  # the text after redactions were applied
    matches: list[FilterMatch] = field(default_factory=list)  # all pattern hits
    blocked: bool = False  # True when any BLOCK-action pattern fired
    block_reason: str | None = None  # explanation when blocked
    warnings: list[str] = field(default_factory=list)  # WARN-action notices

    @property
    def has_sensitive_content(self) -> bool:
        """True when at least one filter pattern matched."""
        return bool(self.matches)
|
||||
|
||||
|
||||
@dataclass
class FilterPattern:
    """A regex-based rule for detecting one kind of sensitive content."""

    name: str  # unique identifier, echoed into FilterMatch.pattern_name
    category: ContentCategory
    pattern: str  # Regex pattern source
    action: FilterAction = FilterAction.REDACT
    replacement: str = "[REDACTED]"  # text substituted when redacting
    confidence: float = 1.0
    enabled: bool = True

    def __post_init__(self) -> None:
        """Compile the regex once so scanning reuses the compiled form."""
        self._compiled = re.compile(self.pattern, re.IGNORECASE | re.MULTILINE)

    def find_matches(self, content: str) -> list[FilterMatch]:
        """Return a FilterMatch for every occurrence of the pattern in *content*."""
        return [
            FilterMatch(
                category=self.category,
                pattern_name=self.name,
                matched_text=hit.group(),
                start_pos=hit.start(),
                end_pos=hit.end(),
                confidence=self.confidence,
                redacted_text=self.replacement,
            )
            for hit in self._compiled.finditer(content)
        ]
|
||||
|
||||
|
||||
class ContentFilter:
|
||||
"""
|
||||
Filters content for sensitive information.
|
||||
|
||||
Features:
|
||||
- PII detection (emails, phones, SSN, etc.)
|
||||
- Secret scanning (API keys, tokens, passwords)
|
||||
- Credential detection
|
||||
- Injection attack prevention
|
||||
- Custom pattern support
|
||||
- Configurable actions (allow, redact, block, warn)
|
||||
"""
|
||||
|
||||
# Default patterns for common sensitive data
|
||||
DEFAULT_PATTERNS: ClassVar[list[FilterPattern]] = [
|
||||
# PII Patterns
|
||||
FilterPattern(
|
||||
name="email",
|
||||
category=ContentCategory.PII,
|
||||
pattern=r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b",
|
||||
action=FilterAction.REDACT,
|
||||
replacement="[EMAIL]",
|
||||
),
|
||||
FilterPattern(
|
||||
name="phone_us",
|
||||
category=ContentCategory.PII,
|
||||
pattern=r"\b(?:\+1[-.\s]?)?(?:\(?\d{3}\)?[-.\s]?)?\d{3}[-.\s]?\d{4}\b",
|
||||
action=FilterAction.REDACT,
|
||||
replacement="[PHONE]",
|
||||
),
|
||||
FilterPattern(
|
||||
name="ssn",
|
||||
category=ContentCategory.PII,
|
||||
pattern=r"\b\d{3}[-\s]?\d{2}[-\s]?\d{4}\b",
|
||||
action=FilterAction.REDACT,
|
||||
replacement="[SSN]",
|
||||
),
|
||||
FilterPattern(
|
||||
name="credit_card",
|
||||
category=ContentCategory.FINANCIAL,
|
||||
pattern=r"\b(?:\d{4}[-\s]?){3}\d{4}\b",
|
||||
action=FilterAction.REDACT,
|
||||
replacement="[CREDIT_CARD]",
|
||||
),
|
||||
FilterPattern(
|
||||
name="ip_address",
|
||||
category=ContentCategory.PII,
|
||||
pattern=r"\b(?:\d{1,3}\.){3}\d{1,3}\b",
|
||||
action=FilterAction.WARN,
|
||||
replacement="[IP]",
|
||||
confidence=0.8,
|
||||
),
|
||||
# Secret Patterns
|
||||
FilterPattern(
|
||||
name="api_key_generic",
|
||||
category=ContentCategory.SECRETS,
|
||||
pattern=r"\b(?:api[_-]?key|apikey)\s*[:=]\s*['\"]?([A-Za-z0-9_-]{20,})['\"]?",
|
||||
action=FilterAction.BLOCK,
|
||||
replacement="[API_KEY]",
|
||||
),
|
||||
FilterPattern(
|
||||
name="aws_access_key",
|
||||
category=ContentCategory.SECRETS,
|
||||
pattern=r"\bAKIA[0-9A-Z]{16}\b",
|
||||
action=FilterAction.BLOCK,
|
||||
replacement="[AWS_KEY]",
|
||||
),
|
||||
FilterPattern(
|
||||
name="aws_secret_key",
|
||||
category=ContentCategory.SECRETS,
|
||||
pattern=r"\b[A-Za-z0-9/+=]{40}\b",
|
||||
action=FilterAction.WARN,
|
||||
replacement="[AWS_SECRET]",
|
||||
confidence=0.6, # Lower confidence - might be false positive
|
||||
),
|
||||
FilterPattern(
|
||||
name="github_token",
|
||||
category=ContentCategory.SECRETS,
|
||||
pattern=r"\b(ghp|gho|ghu|ghs|ghr)_[A-Za-z0-9]{36,}\b",
|
||||
action=FilterAction.BLOCK,
|
||||
replacement="[GITHUB_TOKEN]",
|
||||
),
|
||||
FilterPattern(
|
||||
name="jwt_token",
|
||||
category=ContentCategory.SECRETS,
|
||||
pattern=r"\beyJ[A-Za-z0-9_-]*\.eyJ[A-Za-z0-9_-]*\.[A-Za-z0-9_-]*\b",
|
||||
action=FilterAction.BLOCK,
|
||||
replacement="[JWT]",
|
||||
),
|
||||
# Credential Patterns
|
||||
FilterPattern(
|
||||
name="password_in_url",
|
||||
category=ContentCategory.CREDENTIALS,
|
||||
pattern=r"://[^:]+:([^@]+)@",
|
||||
action=FilterAction.BLOCK,
|
||||
replacement="://[REDACTED]@",
|
||||
),
|
||||
FilterPattern(
|
||||
name="password_assignment",
|
||||
category=ContentCategory.CREDENTIALS,
|
||||
pattern=r"\b(?:password|passwd|pwd)\s*[:=]\s*['\"]?([^\s'\"]+)['\"]?",
|
||||
action=FilterAction.REDACT,
|
||||
replacement="[PASSWORD]",
|
||||
),
|
||||
FilterPattern(
|
||||
name="private_key",
|
||||
category=ContentCategory.SECRETS,
|
||||
pattern=r"-----BEGIN (?:RSA |DSA |EC |OPENSSH )?PRIVATE KEY-----",
|
||||
action=FilterAction.BLOCK,
|
||||
replacement="[PRIVATE_KEY]",
|
||||
),
|
||||
# Injection Patterns
|
||||
FilterPattern(
|
||||
name="sql_injection",
|
||||
category=ContentCategory.INJECTION,
|
||||
pattern=r"(?:'\s*(?:OR|AND)\s*')|(?:--\s*$)|(?:;\s*(?:DROP|DELETE|UPDATE|INSERT))",
|
||||
action=FilterAction.BLOCK,
|
||||
replacement="[BLOCKED]",
|
||||
),
|
||||
FilterPattern(
|
||||
name="command_injection",
|
||||
category=ContentCategory.INJECTION,
|
||||
pattern=r"[;&|`$]|\$\(|\$\{",
|
||||
action=FilterAction.WARN,
|
||||
replacement="[CMD]",
|
||||
confidence=0.5, # Low confidence - common in code
|
||||
),
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
enable_pii_filter: bool = True,
|
||||
enable_secret_filter: bool = True,
|
||||
enable_injection_filter: bool = True,
|
||||
custom_patterns: list[FilterPattern] | None = None,
|
||||
default_action: FilterAction = FilterAction.REDACT,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the ContentFilter.
|
||||
|
||||
Args:
|
||||
enable_pii_filter: Enable PII detection
|
||||
enable_secret_filter: Enable secret scanning
|
||||
enable_injection_filter: Enable injection detection
|
||||
custom_patterns: Additional custom patterns
|
||||
default_action: Default action for matches
|
||||
"""
|
||||
self._patterns: list[FilterPattern] = []
|
||||
self._default_action = default_action
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
# Load default patterns based on configuration
|
||||
# Use replace() to create a copy of each pattern to avoid mutating shared defaults
|
||||
for pattern in self.DEFAULT_PATTERNS:
|
||||
if pattern.category == ContentCategory.PII and not enable_pii_filter:
|
||||
continue
|
||||
if pattern.category == ContentCategory.SECRETS and not enable_secret_filter:
|
||||
continue
|
||||
if (
|
||||
pattern.category == ContentCategory.CREDENTIALS
|
||||
and not enable_secret_filter
|
||||
):
|
||||
continue
|
||||
if (
|
||||
pattern.category == ContentCategory.INJECTION
|
||||
and not enable_injection_filter
|
||||
):
|
||||
continue
|
||||
self._patterns.append(replace(pattern))
|
||||
|
||||
# Add custom patterns
|
||||
if custom_patterns:
|
||||
self._patterns.extend(custom_patterns)
|
||||
|
||||
logger.info("ContentFilter initialized with %d patterns", len(self._patterns))
|
||||
|
||||
    def add_pattern(self, pattern: FilterPattern) -> None:
        """Register an additional filter pattern on this instance.

        No de-duplication is performed; a pattern with a duplicate name will
        be matched alongside the existing one.
        """
        self._patterns.append(pattern)
        logger.debug("Added pattern: %s", pattern.name)
|
||||
|
||||
def remove_pattern(self, pattern_name: str) -> bool:
|
||||
"""Remove a pattern by name."""
|
||||
for i, pattern in enumerate(self._patterns):
|
||||
if pattern.name == pattern_name:
|
||||
del self._patterns[i]
|
||||
logger.debug("Removed pattern: %s", pattern_name)
|
||||
return True
|
||||
return False
|
||||
|
||||
def enable_pattern(self, pattern_name: str, enabled: bool = True) -> bool:
    """Set the enabled flag on the first pattern with the given name.

    Args:
        pattern_name: Name of the pattern to update.
        enabled: New enabled state (defaults to enabling).

    Returns:
        True if a matching pattern was found, False otherwise.
    """
    target = next(
        (p for p in self._patterns if p.name == pattern_name), None
    )
    if target is None:
        return False
    target.enabled = enabled
    return True
|
||||
|
||||
async def filter(
    self,
    content: str,
    context: dict[str, Any] | None = None,
    raise_on_block: bool = False,
) -> FilterResult:
    """
    Scan *content* against all enabled patterns and apply their actions.

    Args:
        content: Text to scan and (possibly) redact.
        context: Optional extra context for filtering decisions.
        raise_on_block: If True, raise when a BLOCK pattern matches.

    Returns:
        FilterResult describing matches, redactions, and block status.

    Raises:
        ContentFilterError: If content is blocked and raise_on_block=True.
    """
    found: list[FilterMatch] = []
    warnings: list[str] = []
    blocked = False
    block_reason: str | None = None

    # Collect matches from every enabled pattern and record
    # block/warn consequences as we go.
    for pattern in self._patterns:
        if not pattern.enabled:
            continue
        for match in pattern.find_matches(content):
            found.append(match)
            if pattern.action == FilterAction.BLOCK:
                blocked = True
                block_reason = f"Blocked by pattern: {pattern.name}"
            elif pattern.action == FilterAction.WARN:
                warnings.append(
                    f"Warning: {pattern.name} detected at position {match.start_pos}"
                )

    # Splice replacements in from the end of the string backwards so
    # earlier match offsets remain valid during substitution.
    found.sort(key=lambda m: m.start_pos, reverse=True)
    filtered_content = content
    for match in found:
        owner = self._get_pattern(match.pattern_name)
        if owner is None:
            continue
        if owner.action not in (FilterAction.REDACT, FilterAction.BLOCK):
            continue
        replacement = match.redacted_text or "[REDACTED]"
        filtered_content = (
            filtered_content[: match.start_pos]
            + replacement
            + filtered_content[match.end_pos :]
        )

    # Restore ascending order for the returned result.
    found.sort(key=lambda m: m.start_pos)

    result = FilterResult(
        original_content=content,
        filtered_content="" if blocked else filtered_content,
        matches=found,
        blocked=blocked,
        block_reason=block_reason,
        warnings=warnings,
    )

    if blocked:
        logger.warning(
            "Content blocked: %s (%d matches)",
            block_reason,
            len(found),
        )
        if raise_on_block:
            first_category = found[0].category.value if found else "unknown"
            raise ContentFilterError(
                block_reason or "Content blocked",
                filter_type=first_category,
                detected_patterns=[m.pattern_name for m in found],
            )
    elif found:
        logger.debug(
            "Content filtered: %d matches, %d warnings",
            len(found),
            len(warnings),
        )

    return result
|
||||
|
||||
async def filter_dict(
|
||||
self,
|
||||
data: dict[str, Any],
|
||||
keys_to_filter: list[str] | None = None,
|
||||
recursive: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Filter string values in a dictionary.
|
||||
|
||||
Args:
|
||||
data: Dictionary to filter
|
||||
keys_to_filter: Specific keys to filter (None = all)
|
||||
recursive: Filter nested dictionaries
|
||||
|
||||
Returns:
|
||||
Filtered dictionary
|
||||
"""
|
||||
result: dict[str, Any] = {}
|
||||
|
||||
for key, value in data.items():
|
||||
if isinstance(value, str):
|
||||
if keys_to_filter is None or key in keys_to_filter:
|
||||
filter_result = await self.filter(value)
|
||||
result[key] = filter_result.filtered_content
|
||||
else:
|
||||
result[key] = value
|
||||
elif isinstance(value, dict) and recursive:
|
||||
result[key] = await self.filter_dict(value, keys_to_filter, recursive)
|
||||
elif isinstance(value, list):
|
||||
result[key] = [
|
||||
(await self.filter(item)).filtered_content
|
||||
if isinstance(item, str)
|
||||
else item
|
||||
for item in value
|
||||
]
|
||||
else:
|
||||
result[key] = value
|
||||
|
||||
return result
|
||||
|
||||
async def scan(
    self,
    content: str,
    categories: list[ContentCategory] | None = None,
) -> list[FilterMatch]:
    """
    Detect matches in *content* without modifying it.

    Args:
        content: Text to scan.
        categories: If given, restrict scanning to these categories.

    Returns:
        Matches sorted by their starting position.
    """
    hits: list[FilterMatch] = []

    for pattern in self._patterns:
        if not pattern.enabled:
            continue
        if categories and pattern.category not in categories:
            continue
        hits.extend(pattern.find_matches(content))

    return sorted(hits, key=lambda m: m.start_pos)
|
||||
|
||||
async def validate_safe(
    self,
    content: str,
    categories: list[ContentCategory] | None = None,
    allow_warnings: bool = True,
) -> tuple[bool, list[str]]:
    """
    Check whether *content* is free of blocked (and optionally warned) patterns.

    Args:
        content: Text to validate.
        categories: If given, restrict validation to these categories.
        allow_warnings: When False, WARN matches also count as issues.

    Returns:
        Tuple of (is_safe, list of issues); is_safe is True only when
        no issues were collected.
    """
    issues: list[str] = []

    for pattern in self._patterns:
        if not pattern.enabled:
            continue
        if categories and pattern.category not in categories:
            continue

        for match in pattern.find_matches(content):
            if pattern.action == FilterAction.BLOCK:
                issues.append(
                    f"Blocked: {pattern.name} at position {match.start_pos}"
                )
            elif pattern.action == FilterAction.WARN and not allow_warnings:
                issues.append(
                    f"Warning: {pattern.name} at position {match.start_pos}"
                )

    return not issues, issues
|
||||
|
||||
def _get_pattern(self, name: str) -> FilterPattern | None:
    """Return the first registered pattern called *name*, or None."""
    return next((p for p in self._patterns if p.name == name), None)
|
||||
|
||||
def get_pattern_stats(self) -> dict[str, Any]:
    """Summarize configured patterns by category, action, and enabled count."""
    by_category: dict[str, int] = {}
    by_action: dict[str, int] = {}
    enabled_total = 0

    # Single pass: tally category, action, and enabled counts together.
    for pattern in self._patterns:
        category_key = pattern.category.value
        by_category[category_key] = by_category.get(category_key, 0) + 1

        action_key = pattern.action.value
        by_action[action_key] = by_action.get(action_key, 0) + 1

        if pattern.enabled:
            enabled_total += 1

    return {
        "total_patterns": len(self._patterns),
        "enabled_patterns": enabled_total,
        "by_category": by_category,
        "by_action": by_action,
    }
|
||||
|
||||
|
||||
# Convenience function for quick filtering
|
||||
async def filter_content(content: str) -> str:
    """Filter *content* with a fresh ContentFilter using default settings."""
    result = await ContentFilter().filter(content)
    return result.filtered_content
|
||||
|
||||
|
||||
async def scan_for_secrets(content: str) -> list[FilterMatch]:
    """Scan *content* for secret and credential patterns only."""
    secrets_only = ContentFilter(
        enable_pii_filter=False,
        enable_injection_filter=False,
    )
    return await secrets_only.scan(
        content,
        categories=[ContentCategory.SECRETS, ContentCategory.CREDENTIALS],
    )
|
||||
15
backend/app/services/safety/costs/__init__.py
Normal file
15
backend/app/services/safety/costs/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
Cost Control Module
|
||||
|
||||
Budget management and cost tracking.
|
||||
"""
|
||||
|
||||
from .controller import (
|
||||
BudgetTracker,
|
||||
CostController,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BudgetTracker",
|
||||
"CostController",
|
||||
]
|
||||
498
backend/app/services/safety/costs/controller.py
Normal file
498
backend/app/services/safety/costs/controller.py
Normal file
@@ -0,0 +1,498 @@
|
||||
"""
|
||||
Cost Controller
|
||||
|
||||
Budget management and cost tracking for agent operations.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from ..config import get_safety_config
|
||||
from ..exceptions import BudgetExceededError
|
||||
from ..models import (
|
||||
ActionRequest,
|
||||
BudgetScope,
|
||||
BudgetStatus,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BudgetTracker:
    """Tracks token and USD spending against a single budget's limits."""

    def __init__(
        self,
        scope: BudgetScope,
        scope_id: str,
        tokens_limit: int,
        cost_limit_usd: float,
        reset_interval: timedelta | None = None,
        warning_threshold: float = 0.8,
    ) -> None:
        """
        Args:
            scope: Kind of budget (session, daily, ...).
            scope_id: Identifier of the entity this budget belongs to.
            tokens_limit: Maximum tokens allowed.
            cost_limit_usd: Maximum USD spend allowed.
            reset_interval: If set, usage resets to zero once this much
                time has passed since the last reset.
            warning_threshold: Usage ratio at which is_warning turns on.
        """
        self.scope = scope
        self.scope_id = scope_id
        self.tokens_limit = tokens_limit
        self.cost_limit_usd = cost_limit_usd
        self.warning_threshold = warning_threshold
        self._reset_interval = reset_interval

        # Mutable usage state; every access goes through _lock.
        self._tokens_used = 0
        self._cost_used_usd = 0.0
        self._created_at = datetime.utcnow()
        self._last_reset = datetime.utcnow()
        self._lock = asyncio.Lock()

    async def add_usage(self, tokens: int, cost_usd: float) -> None:
        """Record actual usage against this budget."""
        async with self._lock:
            self._check_reset()
            self._tokens_used += tokens
            self._cost_used_usd += cost_usd

    async def get_status(self) -> BudgetStatus:
        """Produce a point-in-time snapshot of the budget."""
        async with self._lock:
            self._check_reset()

            tokens_remaining = max(0, self.tokens_limit - self._tokens_used)
            cost_remaining = max(0, self.cost_limit_usd - self._cost_used_usd)

            # Usage ratios guard against zero (or negative) limits.
            token_ratio = (
                0 if self.tokens_limit <= 0 else self._tokens_used / self.tokens_limit
            )
            cost_ratio = (
                0
                if self.cost_limit_usd <= 0
                else self._cost_used_usd / self.cost_limit_usd
            )

            reset_at = (
                self._last_reset + self._reset_interval
                if self._reset_interval
                else None
            )

            return BudgetStatus(
                scope=self.scope,
                scope_id=self.scope_id,
                tokens_used=self._tokens_used,
                tokens_limit=self.tokens_limit,
                cost_used_usd=self._cost_used_usd,
                cost_limit_usd=self.cost_limit_usd,
                tokens_remaining=tokens_remaining,
                cost_remaining_usd=cost_remaining,
                warning_threshold=self.warning_threshold,
                is_warning=max(token_ratio, cost_ratio) >= self.warning_threshold,
                is_exceeded=(
                    self._tokens_used >= self.tokens_limit
                    or self._cost_used_usd >= self.cost_limit_usd
                ),
                reset_at=reset_at,
            )

    async def check_budget(
        self, estimated_tokens: int, estimated_cost_usd: float
    ) -> bool:
        """Return True if the estimated usage would still fit both limits."""
        async with self._lock:
            self._check_reset()

            fits_tokens = (
                self._tokens_used + estimated_tokens
            ) <= self.tokens_limit
            fits_cost = (
                self._cost_used_usd + estimated_cost_usd
            ) <= self.cost_limit_usd
            return fits_tokens and fits_cost

    def _check_reset(self) -> None:
        """Zero the counters if the reset interval elapsed (caller holds _lock)."""
        if self._reset_interval is None:
            return

        now = datetime.utcnow()
        if now < self._last_reset + self._reset_interval:
            return

        logger.info(
            "Resetting budget for %s:%s",
            self.scope.value,
            self.scope_id,
        )
        self._tokens_used = 0
        self._cost_used_usd = 0.0
        self._last_reset = now

    async def reset(self) -> None:
        """Manually zero the counters and restart the reset clock."""
        async with self._lock:
            self._tokens_used = 0
            self._cost_used_usd = 0.0
            self._last_reset = datetime.utcnow()
|
||||
|
||||
|
||||
class CostController:
    """
    Controls costs and budgets for agent operations.

    Features:
    - Per-agent, per-project, per-session budgets
    - Real-time cost tracking
    - Budget alerts at configurable thresholds
    - Cost prediction for planned actions
    - Budget rollover policies
    """

    def __init__(
        self,
        default_session_tokens: int | None = None,
        default_session_cost_usd: float | None = None,
        default_daily_tokens: int | None = None,
        default_daily_cost_usd: float | None = None,
    ) -> None:
        """
        Initialize the CostController.

        Args:
            default_session_tokens: Default token budget per session
            default_session_cost_usd: Default USD budget per session
            default_daily_tokens: Default token budget per day
            default_daily_cost_usd: Default USD budget per day
        """
        config = get_safety_config()

        # Fall back to the safety config for any limit not supplied.
        # NOTE(review): `or` treats an explicit 0/0.0 as "not provided" and
        # silently uses the config default — confirm that is intended.
        self._default_session_tokens = (
            default_session_tokens or config.default_session_token_budget
        )
        self._default_session_cost = (
            default_session_cost_usd or config.default_session_cost_limit
        )
        self._default_daily_tokens = (
            default_daily_tokens or config.default_daily_token_budget
        )
        self._default_daily_cost = (
            default_daily_cost_usd or config.default_daily_cost_limit
        )

        # Trackers keyed by "<scope.value>:<scope_id>"; guarded by _lock.
        self._trackers: dict[str, BudgetTracker] = {}
        self._lock = asyncio.Lock()

        # Alert handlers (sync or async callables; see _send_alert)
        self._alert_handlers: list[Any] = []

        # Track which budgets have had warning alerts sent (to avoid spam)
        self._warned_budgets: set[str] = set()

    async def get_or_create_tracker(
        self,
        scope: BudgetScope,
        scope_id: str,
    ) -> BudgetTracker:
        """Get or create a budget tracker."""
        key = f"{scope.value}:{scope_id}"

        async with self._lock:
            if key not in self._trackers:
                if scope == BudgetScope.SESSION:
                    tracker = BudgetTracker(
                        scope=scope,
                        scope_id=scope_id,
                        tokens_limit=self._default_session_tokens,
                        cost_limit_usd=self._default_session_cost,
                    )
                elif scope == BudgetScope.DAILY:
                    # Daily budgets roll over automatically every 24h.
                    tracker = BudgetTracker(
                        scope=scope,
                        scope_id=scope_id,
                        tokens_limit=self._default_daily_tokens,
                        cost_limit_usd=self._default_daily_cost,
                        reset_interval=timedelta(days=1),
                    )
                else:
                    # Default
                    # NOTE(review): other scopes (e.g. weekly/monthly) created
                    # here get session-sized limits and NO reset interval;
                    # set_budget() is the path that configures those properly.
                    tracker = BudgetTracker(
                        scope=scope,
                        scope_id=scope_id,
                        tokens_limit=self._default_session_tokens,
                        cost_limit_usd=self._default_session_cost,
                    )

                self._trackers[key] = tracker

        return self._trackers[key]

    async def check_budget(
        self,
        agent_id: str,
        session_id: str | None,
        estimated_tokens: int,
        estimated_cost_usd: float,
    ) -> bool:
        """
        Check if there's enough budget for an operation.

        Checks the session budget (if a session_id is given) and then the
        agent's daily budget; both must have headroom.

        NOTE(review): this check and the later record_usage() are separate
        lock acquisitions, so two concurrent callers can both pass the check
        and jointly overshoot the budget — confirm this best-effort model
        is acceptable.

        Args:
            agent_id: ID of the agent
            session_id: Optional session ID
            estimated_tokens: Estimated token usage
            estimated_cost_usd: Estimated USD cost

        Returns:
            True if budget is available
        """
        # Check session budget
        if session_id:
            session_tracker = await self.get_or_create_tracker(
                BudgetScope.SESSION, session_id
            )
            if not await session_tracker.check_budget(
                estimated_tokens, estimated_cost_usd
            ):
                return False

        # Check agent daily budget
        agent_tracker = await self.get_or_create_tracker(BudgetScope.DAILY, agent_id)
        if not await agent_tracker.check_budget(estimated_tokens, estimated_cost_usd):
            return False

        return True

    async def check_action(self, action: ActionRequest) -> bool:
        """
        Check if an action is within budget.

        Thin adapter that unpacks the action's metadata and estimates
        into a check_budget() call.

        Args:
            action: The action to check

        Returns:
            True if within budget
        """
        return await self.check_budget(
            agent_id=action.metadata.agent_id,
            session_id=action.metadata.session_id,
            estimated_tokens=action.estimated_cost_tokens,
            estimated_cost_usd=action.estimated_cost_usd,
        )

    async def require_budget(
        self,
        agent_id: str,
        session_id: str | None,
        estimated_tokens: int,
        estimated_cost_usd: float,
    ) -> None:
        """
        Require budget or raise exception.

        Args:
            agent_id: ID of the agent
            session_id: Optional session ID
            estimated_tokens: Estimated token usage
            estimated_cost_usd: Estimated USD cost

        Raises:
            BudgetExceededError: If budget is exceeded
        """
        if not await self.check_budget(
            agent_id, session_id, estimated_tokens, estimated_cost_usd
        ):
            # Determine which budget was exceeded
            # NOTE(review): if the check failed only because the *estimate*
            # would overshoot (current usage not yet at the limit), neither
            # status reports is_exceeded, and the daily error below is raised
            # even when the session budget caused the failure — confirm.
            if session_id:
                session_tracker = await self.get_or_create_tracker(
                    BudgetScope.SESSION, session_id
                )
                session_status = await session_tracker.get_status()
                if session_status.is_exceeded:
                    raise BudgetExceededError(
                        "Session budget exceeded",
                        budget_type="session",
                        current_usage=session_status.tokens_used,
                        budget_limit=session_status.tokens_limit,
                        agent_id=agent_id,
                    )

            agent_tracker = await self.get_or_create_tracker(
                BudgetScope.DAILY, agent_id
            )
            agent_status = await agent_tracker.get_status()
            raise BudgetExceededError(
                "Daily budget exceeded",
                budget_type="daily",
                current_usage=agent_status.tokens_used,
                budget_limit=agent_status.tokens_limit,
                agent_id=agent_id,
            )

    async def record_usage(
        self,
        agent_id: str,
        session_id: str | None,
        tokens: int,
        cost_usd: float,
    ) -> None:
        """
        Record actual usage.

        Updates the session tracker (if any) and the agent's daily tracker,
        emitting a one-shot warning alert per budget when the warning
        threshold is crossed.

        Args:
            agent_id: ID of the agent
            session_id: Optional session ID
            tokens: Actual token usage
            cost_usd: Actual USD cost
        """
        # Update session budget
        if session_id:
            # Key format mirrors get_or_create_tracker's "<scope.value>:<id>";
            # assumes BudgetScope.SESSION.value == "session" — TODO confirm.
            session_key = f"session:{session_id}"
            session_tracker = await self.get_or_create_tracker(
                BudgetScope.SESSION, session_id
            )
            await session_tracker.add_usage(tokens, cost_usd)

            # Check for warning (only alert once per budget to avoid spam)
            status = await session_tracker.get_status()
            if status.is_warning and not status.is_exceeded:
                if session_key not in self._warned_budgets:
                    self._warned_budgets.add(session_key)
                    await self._send_alert(
                        "warning",
                        f"Session {session_id} at {status.tokens_used}/{status.tokens_limit} tokens",
                        status,
                    )
            elif not status.is_warning:
                # Clear warning flag if usage dropped below threshold (e.g., after reset)
                self._warned_budgets.discard(session_key)

        # Update agent daily budget
        # Assumes BudgetScope.DAILY.value == "daily" — TODO confirm.
        daily_key = f"daily:{agent_id}"
        agent_tracker = await self.get_or_create_tracker(BudgetScope.DAILY, agent_id)
        await agent_tracker.add_usage(tokens, cost_usd)

        # Check for warning (only alert once per budget to avoid spam)
        status = await agent_tracker.get_status()
        if status.is_warning and not status.is_exceeded:
            if daily_key not in self._warned_budgets:
                self._warned_budgets.add(daily_key)
                await self._send_alert(
                    "warning",
                    f"Agent {agent_id} at {status.tokens_used}/{status.tokens_limit} daily tokens",
                    status,
                )
        elif not status.is_warning:
            # Clear warning flag if usage dropped below threshold (e.g., after reset)
            self._warned_budgets.discard(daily_key)

    async def get_status(
        self,
        scope: BudgetScope,
        scope_id: str,
    ) -> BudgetStatus | None:
        """
        Get budget status.

        Args:
            scope: Budget scope
            scope_id: ID within scope

        Returns:
            Budget status or None if not tracked
        """
        key = f"{scope.value}:{scope_id}"
        async with self._lock:
            tracker = self._trackers.get(key)
            # Get status while holding lock to prevent TOCTOU race
            if tracker:
                return await tracker.get_status()
            return None

    async def get_all_statuses(self) -> list[BudgetStatus]:
        """Get status of all tracked budgets."""
        statuses = []
        async with self._lock:
            # Get all statuses while holding lock to prevent TOCTOU race
            for tracker in self._trackers.values():
                statuses.append(await tracker.get_status())
            return statuses

    async def set_budget(
        self,
        scope: BudgetScope,
        scope_id: str,
        tokens_limit: int,
        cost_limit_usd: float,
    ) -> None:
        """
        Set a custom budget limit.

        Replaces any existing tracker for this scope/id, discarding its
        accumulated usage.

        Args:
            scope: Budget scope
            scope_id: ID within scope
            tokens_limit: Token limit
            cost_limit_usd: USD limit
        """
        key = f"{scope.value}:{scope_id}"

        # Periodic scopes get an automatic reset interval;
        # MONTHLY approximates a month as a fixed 30 days.
        reset_interval = None
        if scope == BudgetScope.DAILY:
            reset_interval = timedelta(days=1)
        elif scope == BudgetScope.WEEKLY:
            reset_interval = timedelta(weeks=1)
        elif scope == BudgetScope.MONTHLY:
            reset_interval = timedelta(days=30)

        async with self._lock:
            self._trackers[key] = BudgetTracker(
                scope=scope,
                scope_id=scope_id,
                tokens_limit=tokens_limit,
                cost_limit_usd=cost_limit_usd,
                reset_interval=reset_interval,
            )

    async def reset_budget(self, scope: BudgetScope, scope_id: str) -> bool:
        """
        Reset a budget tracker.

        Args:
            scope: Budget scope
            scope_id: ID within scope

        Returns:
            True if tracker was found and reset
        """
        key = f"{scope.value}:{scope_id}"
        async with self._lock:
            tracker = self._trackers.get(key)
            # Reset while holding lock to prevent TOCTOU race
            if tracker:
                await tracker.reset()
                return True
            return False

    def add_alert_handler(self, handler: Any) -> None:
        """Add an alert handler (sync or async callable)."""
        self._alert_handlers.append(handler)

    def remove_alert_handler(self, handler: Any) -> None:
        """Remove an alert handler."""
        if handler in self._alert_handlers:
            self._alert_handlers.remove(handler)

    async def _send_alert(
        self,
        alert_type: str,
        message: str,
        status: BudgetStatus,
    ) -> None:
        """Send alert to all handlers; handler errors are logged, not raised."""
        for handler in self._alert_handlers:
            try:
                if asyncio.iscoroutinefunction(handler):
                    await handler(alert_type, message, status)
                else:
                    handler(alert_type, message, status)
            except Exception as e:
                logger.error("Error in alert handler: %s", e)
|
||||
23
backend/app/services/safety/emergency/__init__.py
Normal file
23
backend/app/services/safety/emergency/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Emergency controls for agent safety."""
|
||||
|
||||
from .controls import (
|
||||
EmergencyControls,
|
||||
EmergencyEvent,
|
||||
EmergencyReason,
|
||||
EmergencyState,
|
||||
EmergencyTrigger,
|
||||
check_emergency_allowed,
|
||||
emergency_stop_global,
|
||||
get_emergency_controls,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"EmergencyControls",
|
||||
"EmergencyEvent",
|
||||
"EmergencyReason",
|
||||
"EmergencyState",
|
||||
"EmergencyTrigger",
|
||||
"check_emergency_allowed",
|
||||
"emergency_stop_global",
|
||||
"get_emergency_controls",
|
||||
]
|
||||
596
backend/app/services/safety/emergency/controls.py
Normal file
596
backend/app/services/safety/emergency/controls.py
Normal file
@@ -0,0 +1,596 @@
|
||||
"""
|
||||
Emergency Controls
|
||||
|
||||
Emergency stop and pause functionality for agent safety.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from ..exceptions import EmergencyStopError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmergencyState(str, Enum):
    """Emergency control states (str mixin: members compare equal to their values)."""

    # Operations proceed normally.
    NORMAL = "normal"
    # Operations are suspended but can be brought back via resume().
    PAUSED = "paused"
    # Operations are halted; recovering requires an explicit reset().
    STOPPED = "stopped"
|
||||
|
||||
|
||||
class EmergencyReason(str, Enum):
    """Reasons for emergency actions, recorded on each EmergencyEvent.

    str mixin: members compare equal to (and serialize as) their values.
    """

    # Explicitly requested (e.g. by an operator), not auto-triggered.
    MANUAL = "manual"
    SAFETY_VIOLATION = "safety_violation"
    BUDGET_EXCEEDED = "budget_exceeded"
    LOOP_DETECTED = "loop_detected"
    RATE_LIMIT = "rate_limit"
    CONTENT_VIOLATION = "content_violation"
    SYSTEM_ERROR = "system_error"
    # Initiated by something outside this module's own triggers.
    EXTERNAL_TRIGGER = "external_trigger"
|
||||
|
||||
|
||||
@dataclass
class EmergencyEvent:
    """Record of an emergency action (stop/pause) kept in the audit trail."""

    # Sequential identifier, e.g. "emerg-000001" (see _generate_event_id).
    id: str
    # State this event moved the scope into (PAUSED or STOPPED).
    state: EmergencyState
    reason: EmergencyReason
    # Who/what triggered the action (user, system component, ...).
    triggered_by: str
    # Human-readable explanation of the action.
    message: str
    scope: str  # "global", "project:<id>", "agent:<id>"
    # Naive UTC timestamp — datetime.utcnow() is deprecated in newer
    # Pythons; NOTE(review): consider timezone-aware datetimes file-wide.
    timestamp: datetime = field(default_factory=datetime.utcnow)
    metadata: dict[str, Any] = field(default_factory=dict)
    # Set by resume()/reset() when the event is cleared; None while active.
    resolved_at: datetime | None = None
    resolved_by: str | None = None
|
||||
|
||||
|
||||
class EmergencyControls:
|
||||
"""
|
||||
Emergency stop and pause controls for agent safety.
|
||||
|
||||
Features:
|
||||
- Global emergency stop
|
||||
- Per-project/agent emergency controls
|
||||
- Graceful pause with state preservation
|
||||
- Automatic triggers from safety violations
|
||||
- Manual override capabilities
|
||||
- Event history and audit trail
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    notification_handlers: list[Callable[..., Any]] | None = None,
) -> None:
    """
    Initialize EmergencyControls.

    Args:
        notification_handlers: Handlers invoked on every emergency
            event (stop, pause, resume, reset).
    """
    # Global state plus per-scope overrides ("project:<id>", "agent:<id>").
    self._global_state = EmergencyState.NORMAL
    self._scoped_states: dict[str, EmergencyState] = {}

    # Audit trail of every emergency action, with a sequential ID counter.
    self._events: list[EmergencyEvent] = []
    self._event_id_counter = 0

    self._notification_handlers = notification_handlers or []
    self._lock = asyncio.Lock()

    # State-change callbacks, executed while the lock is held.
    self._on_stop_callbacks: list[Callable[..., Any]] = []
    self._on_pause_callbacks: list[Callable[..., Any]] = []
    self._on_resume_callbacks: list[Callable[..., Any]] = []
|
||||
|
||||
def _generate_event_id(self) -> str:
|
||||
"""Generate a unique event ID."""
|
||||
self._event_id_counter += 1
|
||||
return f"emerg-{self._event_id_counter:06d}"
|
||||
|
||||
async def emergency_stop(
    self,
    reason: EmergencyReason,
    triggered_by: str,
    message: str,
    scope: str = "global",
    metadata: dict[str, Any] | None = None,
) -> EmergencyEvent:
    """
    Trigger an emergency stop for *scope*.

    A stopped scope cannot be resumed; it requires an explicit reset().

    Args:
        reason: Why the stop was triggered.
        triggered_by: Identity of the trigger (user, component, ...).
        message: Human-readable explanation.
        scope: "global", "project:<id>", or "agent:<id>".
        metadata: Extra context stored on the event.

    Returns:
        The recorded EmergencyEvent.
    """
    async with self._lock:
        event = EmergencyEvent(
            id=self._generate_event_id(),
            state=EmergencyState.STOPPED,
            reason=reason,
            triggered_by=triggered_by,
            message=message,
            scope=scope,
            metadata=metadata or {},
        )

        # Flip the state for the requested scope, then record the event.
        if scope == "global":
            self._global_state = EmergencyState.STOPPED
        else:
            self._scoped_states[scope] = EmergencyState.STOPPED

        self._events.append(event)

        logger.critical(
            "EMERGENCY STOP: scope=%s, reason=%s, by=%s - %s",
            scope,
            reason.value,
            triggered_by,
            message,
        )

        # Fire stop callbacks and external notifications while still
        # holding the lock, so observers see a consistent state.
        await self._execute_callbacks(self._on_stop_callbacks, event)
        await self._notify_handlers("emergency_stop", event)

        return event
|
||||
|
||||
async def pause(
    self,
    reason: EmergencyReason,
    triggered_by: str,
    message: str,
    scope: str = "global",
    metadata: dict[str, Any] | None = None,
) -> EmergencyEvent:
    """
    Pause operations for *scope*; a paused scope can later be resumed.

    Args:
        reason: Why the pause was triggered.
        triggered_by: Identity of the trigger.
        message: Human-readable explanation.
        scope: "global", "project:<id>", or "agent:<id>".
        metadata: Extra context stored on the event.

    Returns:
        The recorded EmergencyEvent.
    """
    async with self._lock:
        event = EmergencyEvent(
            id=self._generate_event_id(),
            state=EmergencyState.PAUSED,
            reason=reason,
            triggered_by=triggered_by,
            message=message,
            scope=scope,
            metadata=metadata or {},
        )

        # Flip the state for the requested scope, then record the event.
        if scope == "global":
            self._global_state = EmergencyState.PAUSED
        else:
            self._scoped_states[scope] = EmergencyState.PAUSED

        self._events.append(event)

        logger.warning(
            "PAUSE: scope=%s, reason=%s, by=%s - %s",
            scope,
            reason.value,
            triggered_by,
            message,
        )

        # Fire pause callbacks and notifications while holding the lock.
        await self._execute_callbacks(self._on_pause_callbacks, event)
        await self._notify_handlers("pause", event)

        return event
|
||||
|
||||
async def resume(
    self,
    scope: str = "global",
    resumed_by: str = "system",
    message: str | None = None,
) -> bool:
    """
    Bring a paused scope back to normal operation.

    Args:
        scope: Scope to resume.
        resumed_by: Who/what is resuming.
        message: Optional note appended to the log line.

    Returns:
        True on success (or if already normal); False when the scope
        is STOPPED, which requires reset() instead.
    """
    async with self._lock:
        state = self._get_state(scope)

        if state == EmergencyState.STOPPED:
            logger.warning(
                "Cannot resume from STOPPED state: %s (requires reset)",
                scope,
            )
            return False

        if state == EmergencyState.NORMAL:
            # Nothing to do.
            return True

        # Mark the most recent pause event for this scope as resolved.
        latest_pause = next(
            (
                e
                for e in reversed(self._events)
                if e.scope == scope and e.state == EmergencyState.PAUSED
            ),
            None,
        )
        if latest_pause is not None and latest_pause.resolved_at is None:
            latest_pause.resolved_at = datetime.utcnow()
            latest_pause.resolved_by = resumed_by

        if scope == "global":
            self._global_state = EmergencyState.NORMAL
        else:
            self._scoped_states[scope] = EmergencyState.NORMAL

        logger.info(
            "RESUMED: scope=%s, by=%s%s",
            scope,
            resumed_by,
            f" - {message}" if message else "",
        )

        await self._execute_callbacks(
            self._on_resume_callbacks,
            {"scope": scope, "resumed_by": resumed_by},
        )
        await self._notify_handlers(
            "resume", {"scope": scope, "resumed_by": resumed_by}
        )

        return True
|
||||
|
||||
async def reset(
    self,
    scope: str = "global",
    reset_by: str = "admin",
    message: str | None = None,
) -> bool:
    """
    Clear a stopped scope back to NORMAL.

    Unlike resume(), this is the only way out of a STOPPED state and is
    meant to be an explicit administrative action.

    Args:
        scope: Scope to reset
        reset_by: Identity performing the reset (should be an admin)
        message: Optional note appended to the log line

    Returns:
        True (the scope is NORMAL on return)
    """
    async with self._lock:
        if self._get_state(scope) == EmergencyState.NORMAL:
            return True

        # Resolve the most recent open STOP event for this scope.
        # NOTE(review): a PAUSED scope is also cleared by this method, but
        # its pause event stays unresolved — confirm whether that is intended.
        for evt in reversed(self._events):
            if evt.scope == scope and evt.state == EmergencyState.STOPPED:
                if evt.resolved_at is None:
                    evt.resolved_at = datetime.utcnow()
                    evt.resolved_by = reset_by
                break

        if scope == "global":
            self._global_state = EmergencyState.NORMAL
        else:
            self._scoped_states[scope] = EmergencyState.NORMAL

        logger.warning(
            "EMERGENCY RESET: scope=%s, by=%s%s",
            scope,
            reset_by,
            f" - {message}" if message else "",
        )

        await self._notify_handlers("reset", {"scope": scope, "reset_by": reset_by})

        return True
|
||||
|
||||
async def check_allowed(
    self,
    scope: str | None = None,
    raise_if_blocked: bool = True,
) -> bool:
    """
    Check whether operations are currently allowed.

    The global state is always consulted; a specific scope is checked
    additionally when one is given.

    Args:
        scope: Specific scope to check (global is checked regardless)
        raise_if_blocked: Raise instead of returning False when blocked

    Returns:
        True if operations are allowed

    Raises:
        EmergencyStopError: If blocked and raise_if_blocked=True
    """
    async with self._lock:
        # Global emergency state gates everything.
        if self._global_state != EmergencyState.NORMAL:
            if not raise_if_blocked:
                return False
            raise EmergencyStopError(
                f"Global emergency state: {self._global_state.value}",
                stop_type=self._get_last_reason("global") or "emergency",
                triggered_by=self._get_last_triggered_by("global"),
            )

        # Then the requested scope, if we track any state for it.
        if scope and scope in self._scoped_states:
            scoped_state = self._scoped_states[scope]
            if scoped_state != EmergencyState.NORMAL:
                if not raise_if_blocked:
                    return False
                raise EmergencyStopError(
                    f"Emergency state for {scope}: {scoped_state.value}",
                    stop_type=self._get_last_reason(scope) or "emergency",
                    triggered_by=self._get_last_triggered_by(scope),
                    details={"scope": scope},
                )

        return True
|
||||
|
||||
def _get_state(self, scope: str) -> EmergencyState:
    """Return the current state for *scope* (caller should hold the lock)."""
    if scope != "global":
        # Unknown scopes default to NORMAL.
        return self._scoped_states.get(scope, EmergencyState.NORMAL)
    return self._global_state
|
||||
|
||||
def _get_last_reason(self, scope: str) -> str:
    """Return the reason of the newest unresolved event for *scope*, or "unknown"."""
    return next(
        (
            evt.reason.value
            for evt in reversed(self._events)
            if evt.scope == scope and evt.resolved_at is None
        ),
        "unknown",
    )
|
||||
|
||||
def _get_last_triggered_by(self, scope: str) -> str:
    """Return triggered_by of the newest unresolved event for *scope*, or "unknown"."""
    return next(
        (
            evt.triggered_by
            for evt in reversed(self._events)
            if evt.scope == scope and evt.resolved_at is None
        ),
        "unknown",
    )
|
||||
|
||||
async def get_state(self, scope: str = "global") -> EmergencyState:
    """Get current state for a scope (thread-safe public variant of _get_state)."""
    # Take the lock so callers never observe a half-applied transition.
    async with self._lock:
        return self._get_state(scope)
|
||||
|
||||
async def get_all_states(self) -> dict[str, EmergencyState]:
    """Return a snapshot of the global state plus every scoped state."""
    async with self._lock:
        # Build a fresh dict so callers cannot mutate internal state.
        return {"global": self._global_state, **self._scoped_states}
|
||||
|
||||
async def get_active_events(self) -> list[EmergencyEvent]:
    """Get all unresolved emergency events."""
    async with self._lock:
        # An event stays "active" until resolved_at is stamped by resume()/reset().
        return [e for e in self._events if e.resolved_at is None]
|
||||
|
||||
async def get_event_history(
    self,
    scope: str | None = None,
    limit: int = 100,
) -> list[EmergencyEvent]:
    """Return up to *limit* most recent events, oldest first, optionally filtered by scope."""
    async with self._lock:
        if scope:
            history = [evt for evt in self._events if evt.scope == scope]
        else:
            history = list(self._events)
        # Keep only the newest `limit` entries.
        return history[-limit:]
|
||||
|
||||
def on_stop(self, callback: Callable[..., Any]) -> None:
    """Register callback for stop events.

    The callback may be sync or async (dispatched via _execute_callbacks).
    Presumably invoked with the triggering EmergencyEvent, mirroring the
    pause path — confirm against emergency_stop().
    """
    self._on_stop_callbacks.append(callback)
|
||||
|
||||
def on_pause(self, callback: Callable[..., Any]) -> None:
    """Register callback for pause events.

    Invoked with the EmergencyEvent when a pause occurs; may be sync or
    async (dispatched via _execute_callbacks).
    """
    self._on_pause_callbacks.append(callback)
|
||||
|
||||
def on_resume(self, callback: Callable[..., Any]) -> None:
    """Register callback for resume events.

    Invoked with a dict {"scope": ..., "resumed_by": ...} (not an
    EmergencyEvent); may be sync or async.
    """
    self._on_resume_callbacks.append(callback)
|
||||
|
||||
def add_notification_handler(self, handler: Callable[..., Any]) -> None:
    """Add a notification handler.

    Handlers receive (event_type, data) for every stop/pause/resume/reset
    via _notify_handlers; may be sync or async.
    """
    self._notification_handlers.append(handler)
|
||||
|
||||
async def _execute_callbacks(
    self,
    callbacks: list[Callable[..., Any]],
    data: Any,
) -> None:
    """Invoke every callback with *data*, isolating failures.

    A raising callback is logged and skipped so one bad handler cannot
    block emergency processing. Coroutine functions are awaited.
    """
    for cb in callbacks:
        try:
            if asyncio.iscoroutinefunction(cb):
                await cb(data)
                continue
            cb(data)
        except Exception as exc:
            logger.error("Error in callback: %s", exc)
|
||||
|
||||
async def _notify_handlers(self, event_type: str, data: Any) -> None:
    """Fan *event_type*/*data* out to every notification handler.

    Handler errors are logged and swallowed so notification failures
    never interfere with the emergency state machine itself.
    """
    for fn in self._notification_handlers:
        try:
            if asyncio.iscoroutinefunction(fn):
                await fn(event_type, data)
                continue
            fn(event_type, data)
        except Exception as exc:
            logger.error("Error in notification handler: %s", exc)
|
||||
|
||||
|
||||
class EmergencyTrigger:
    """
    Maps automated detector findings onto emergency actions.

    Each trigger_* method translates one kind of finding (safety
    violation, budget breach, loop, content violation) into a hard stop
    or pause on the wrapped EmergencyControls, with a standardized
    reason, scope, and metadata payload.
    """

    def __init__(self, controls: EmergencyControls) -> None:
        """
        Initialize EmergencyTrigger.

        Args:
            controls: EmergencyControls instance to trigger
        """
        self._controls = controls

    async def trigger_on_safety_violation(
        self,
        violation_type: str,
        details: dict[str, Any],
        scope: str = "global",
    ) -> EmergencyEvent:
        """
        Hard-stop in response to a safety violation.

        Args:
            violation_type: Type of violation
            details: Violation details (merged into event metadata)
            scope: Scope for the emergency

        Returns:
            Emergency event
        """
        meta = {"violation_type": violation_type, **details}
        return await self._controls.emergency_stop(
            reason=EmergencyReason.SAFETY_VIOLATION,
            triggered_by="safety_system",
            message=f"Safety violation: {violation_type}",
            scope=scope,
            metadata=meta,
        )

    async def trigger_on_budget_exceeded(
        self,
        budget_type: str,
        current: float,
        limit: float,
        scope: str = "global",
    ) -> EmergencyEvent:
        """
        Pause (not stop) when a budget is exceeded.

        Args:
            budget_type: Type of budget
            current: Current usage
            limit: Budget limit
            scope: Scope for the emergency

        Returns:
            Emergency event
        """
        meta = {"budget_type": budget_type, "current": current, "limit": limit}
        return await self._controls.pause(
            reason=EmergencyReason.BUDGET_EXCEEDED,
            triggered_by="budget_controller",
            message=f"Budget exceeded: {budget_type} ({current:.2f}/{limit:.2f})",
            scope=scope,
            metadata=meta,
        )

    async def trigger_on_loop_detected(
        self,
        loop_type: str,
        agent_id: str,
        details: dict[str, Any],
    ) -> EmergencyEvent:
        """
        Pause the offending agent's scope when a loop is detected.

        Args:
            loop_type: Type of loop
            agent_id: Agent that's looping (scope becomes "agent:<id>")
            details: Loop details (merged into event metadata)

        Returns:
            Emergency event
        """
        meta = {"loop_type": loop_type, "agent_id": agent_id, **details}
        return await self._controls.pause(
            reason=EmergencyReason.LOOP_DETECTED,
            triggered_by="loop_detector",
            message=f"Loop detected: {loop_type} in agent {agent_id}",
            scope=f"agent:{agent_id}",
            metadata=meta,
        )

    async def trigger_on_content_violation(
        self,
        category: str,
        pattern: str,
        scope: str = "global",
    ) -> EmergencyEvent:
        """
        Hard-stop when the content filter finds a violation.

        Args:
            category: Content category
            pattern: Pattern that matched
            scope: Scope for the emergency

        Returns:
            Emergency event
        """
        meta = {"category": category, "pattern": pattern}
        return await self._controls.emergency_stop(
            reason=EmergencyReason.CONTENT_VIOLATION,
            triggered_by="content_filter",
            message=f"Content violation: {category} ({pattern})",
            scope=scope,
            metadata=meta,
        )
|
||||
|
||||
|
||||
# Singleton instance
# Lazily created by get_emergency_controls(); _lock guards first creation.
# NOTE(review): asyncio.Lock() is created at import time; on Python < 3.10
# this bound to the then-current event loop — confirm the target runtime.
_emergency_controls: EmergencyControls | None = None
_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def get_emergency_controls() -> EmergencyControls:
    """Get the singleton EmergencyControls instance."""
    global _emergency_controls

    # Hold the module lock across check-and-create so concurrent first
    # callers cannot construct two instances. Every call pays the lock
    # cost, which keeps the pattern simple.
    async with _lock:
        if _emergency_controls is None:
            _emergency_controls = EmergencyControls()
        return _emergency_controls
|
||||
|
||||
|
||||
async def emergency_stop_global(
    reason: str,
    triggered_by: str = "system",
) -> EmergencyEvent:
    """Quick global emergency stop.

    Note: *reason* here is free-form text used as the event message; the
    recorded EmergencyReason is always MANUAL.
    """
    controls = await get_emergency_controls()
    return await controls.emergency_stop(
        reason=EmergencyReason.MANUAL,
        triggered_by=triggered_by,
        message=reason,
        scope="global",
    )
|
||||
|
||||
|
||||
async def check_emergency_allowed(scope: str | None = None) -> bool:
    """Quick check if operations are allowed (never raises; returns False when blocked)."""
    controls = await get_emergency_controls()
    return await controls.check_allowed(scope=scope, raise_if_blocked=False)
|
||||
277
backend/app/services/safety/exceptions.py
Normal file
277
backend/app/services/safety/exceptions.py
Normal file
@@ -0,0 +1,277 @@
|
||||
"""
|
||||
Safety Framework Exceptions
|
||||
|
||||
Custom exception classes for the safety and guardrails framework.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class SafetyError(Exception):
|
||||
"""Base exception for all safety-related errors."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
action_id: str | None = None,
|
||||
agent_id: str | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.action_id = action_id
|
||||
self.agent_id = agent_id
|
||||
self.details = details or {}
|
||||
|
||||
|
||||
class PermissionDeniedError(SafetyError):
    """The requested action is not permitted by the current policy."""

    def __init__(
        self,
        message: str = "Permission denied",
        *,
        action_type: str | None = None,
        resource: str | None = None,
        required_permission: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Record what was attempted, on what, and which permission was missing."""
        super().__init__(message, **kwargs)
        self.required_permission = required_permission
        self.resource = resource
        self.action_type = action_type
|
||||
|
||||
|
||||
class BudgetExceededError(SafetyError):
    """A cost budget (tokens, USD, ...) has been exhausted."""

    def __init__(
        self,
        message: str = "Budget exceeded",
        *,
        budget_type: str = "session",
        current_usage: float = 0.0,
        budget_limit: float = 0.0,
        unit: str = "tokens",
        **kwargs: Any,
    ) -> None:
        """Record which budget tripped and by how much."""
        super().__init__(message, **kwargs)
        self.unit = unit
        self.budget_limit = budget_limit
        self.current_usage = current_usage
        self.budget_type = budget_type
|
||||
|
||||
|
||||
class RateLimitExceededError(SafetyError):
    """Too many actions within the configured window."""

    def __init__(
        self,
        message: str = "Rate limit exceeded",
        *,
        limit_type: str = "actions",
        limit_value: int = 0,
        window_seconds: int = 60,
        retry_after_seconds: float = 0.0,
        **kwargs: Any,
    ) -> None:
        """Record the limit that tripped and when a retry may succeed."""
        super().__init__(message, **kwargs)
        self.retry_after_seconds = retry_after_seconds
        self.window_seconds = window_seconds
        self.limit_value = limit_value
        self.limit_type = limit_type
|
||||
|
||||
|
||||
class LoopDetectedError(SafetyError):
    """An agent appears stuck repeating the same action pattern."""

    def __init__(
        self,
        message: str = "Loop detected",
        *,
        loop_type: str = "exact",
        repetition_count: int = 0,
        action_pattern: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Record the detected loop kind, repetition count, and pattern."""
        super().__init__(message, **kwargs)
        self.action_pattern = action_pattern or []
        self.repetition_count = repetition_count
        self.loop_type = loop_type
|
||||
|
||||
|
||||
class ApprovalRequiredError(SafetyError):
    """The action cannot proceed without human sign-off."""

    def __init__(
        self,
        message: str = "Human approval required",
        *,
        approval_id: str | None = None,
        reason: str | None = None,
        timeout_seconds: int = 300,
        **kwargs: Any,
    ) -> None:
        """Record the pending approval request and its timeout."""
        super().__init__(message, **kwargs)
        self.timeout_seconds = timeout_seconds
        self.reason = reason
        self.approval_id = approval_id
|
||||
|
||||
|
||||
class ApprovalDeniedError(SafetyError):
    """A human explicitly rejected the action."""

    def __init__(
        self,
        message: str = "Approval denied by human",
        *,
        approval_id: str | None = None,
        denied_by: str | None = None,
        denial_reason: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Record who denied the request and why."""
        super().__init__(message, **kwargs)
        self.denial_reason = denial_reason
        self.denied_by = denied_by
        self.approval_id = approval_id
|
||||
|
||||
|
||||
class ApprovalTimeoutError(SafetyError):
    """No human responded to the approval request in time."""

    def __init__(
        self,
        message: str = "Approval request timed out",
        *,
        approval_id: str | None = None,
        timeout_seconds: int = 300,
        **kwargs: Any,
    ) -> None:
        """Record the expired approval request and its configured timeout."""
        super().__init__(message, **kwargs)
        self.timeout_seconds = timeout_seconds
        self.approval_id = approval_id
|
||||
|
||||
|
||||
class RollbackError(SafetyError):
    """Restoring a checkpoint did not complete successfully."""

    def __init__(
        self,
        message: str = "Rollback failed",
        *,
        checkpoint_id: str | None = None,
        failed_actions: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Record the target checkpoint and which undo steps failed."""
        super().__init__(message, **kwargs)
        self.failed_actions = failed_actions or []
        self.checkpoint_id = checkpoint_id
|
||||
|
||||
|
||||
class CheckpointError(SafetyError):
    """A checkpoint could not be created."""

    def __init__(
        self,
        message: str = "Checkpoint creation failed",
        *,
        checkpoint_type: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Record the kind of checkpoint that failed to be created."""
        super().__init__(message, **kwargs)
        self.checkpoint_type = checkpoint_type
|
||||
|
||||
|
||||
class ValidationError(SafetyError):
    """One or more validation rules rejected the action."""

    def __init__(
        self,
        message: str = "Validation failed",
        *,
        validation_rules: list[str] | None = None,
        failed_rules: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Record the rules evaluated and the subset that failed."""
        super().__init__(message, **kwargs)
        self.failed_rules = failed_rules or []
        self.validation_rules = validation_rules or []
|
||||
|
||||
|
||||
class ContentFilterError(SafetyError):
    """The content filter matched prohibited content."""

    def __init__(
        self,
        message: str = "Prohibited content detected",
        *,
        filter_type: str | None = None,
        detected_patterns: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Record which filter fired and the patterns that matched."""
        super().__init__(message, **kwargs)
        self.detected_patterns = detected_patterns or []
        self.filter_type = filter_type
|
||||
|
||||
|
||||
class SandboxError(SafetyError):
    """Execution inside the sandbox failed."""

    def __init__(
        self,
        message: str = "Sandbox execution failed",
        *,
        exit_code: int | None = None,
        stderr: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Record the sandboxed process exit code and stderr, when available."""
        super().__init__(message, **kwargs)
        self.stderr = stderr
        self.exit_code = exit_code
|
||||
|
||||
|
||||
class SandboxTimeoutError(SandboxError):
    """Sandboxed execution exceeded its time limit."""

    def __init__(
        self,
        message: str = "Sandbox execution timed out",
        *,
        timeout_seconds: int = 300,
        **kwargs: Any,
    ) -> None:
        """Record the configured timeout that was exceeded."""
        super().__init__(message, **kwargs)
        self.timeout_seconds = timeout_seconds
|
||||
|
||||
|
||||
class EmergencyStopError(SafetyError):
    """Operations are blocked by an active emergency stop."""

    def __init__(
        self,
        message: str = "Emergency stop triggered",
        *,
        stop_type: str = "kill",
        triggered_by: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Record the kind of stop and who triggered it."""
        super().__init__(message, **kwargs)
        self.triggered_by = triggered_by
        self.stop_type = stop_type
|
||||
|
||||
|
||||
class PolicyViolationError(SafetyError):
    """The action violates one or more rules of a safety policy."""

    def __init__(
        self,
        message: str = "Policy violation",
        *,
        policy_name: str | None = None,
        violated_rules: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Record the policy involved and the rules that were violated."""
        super().__init__(message, **kwargs)
        self.violated_rules = violated_rules or []
        self.policy_name = policy_name
|
||||
864
backend/app/services/safety/guardian.py
Normal file
864
backend/app/services/safety/guardian.py
Normal file
@@ -0,0 +1,864 @@
|
||||
"""
|
||||
Safety Guardian
|
||||
|
||||
Main facade for the safety framework. Orchestrates all safety checks
|
||||
before, during, and after action execution.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from .audit import AuditLogger, get_audit_logger
|
||||
from .config import (
|
||||
SafetyConfig,
|
||||
get_policy_for_autonomy_level,
|
||||
get_safety_config,
|
||||
)
|
||||
from .costs.controller import CostController
|
||||
from .exceptions import (
|
||||
BudgetExceededError,
|
||||
LoopDetectedError,
|
||||
RateLimitExceededError,
|
||||
SafetyError,
|
||||
)
|
||||
from .limits.limiter import RateLimiter
|
||||
from .loops.detector import LoopDetector
|
||||
from .models import (
|
||||
ActionRequest,
|
||||
ActionResult,
|
||||
AuditEventType,
|
||||
BudgetScope,
|
||||
GuardianResult,
|
||||
SafetyDecision,
|
||||
SafetyPolicy,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SafetyGuardian:
|
||||
"""
|
||||
Central orchestrator for all safety checks.
|
||||
|
||||
The SafetyGuardian is the main entry point for validating agent actions.
|
||||
It coordinates multiple safety subsystems:
|
||||
- Permission checking
|
||||
- Cost/budget control
|
||||
- Rate limiting
|
||||
- Loop detection
|
||||
- Human-in-the-loop approval
|
||||
- Rollback/checkpoint management
|
||||
- Content filtering
|
||||
- Sandbox execution
|
||||
|
||||
Usage:
|
||||
guardian = SafetyGuardian()
|
||||
await guardian.initialize()
|
||||
|
||||
# Before executing an action
|
||||
result = await guardian.validate(action_request)
|
||||
if not result.allowed:
|
||||
# Handle denial
|
||||
|
||||
# After action execution
|
||||
await guardian.record_execution(action_request, action_result)
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    config: SafetyConfig | None = None,
    audit_logger: AuditLogger | None = None,
    cost_controller: CostController | None = None,
    rate_limiter: RateLimiter | None = None,
    loop_detector: LoopDetector | None = None,
) -> None:
    """
    Initialize the SafetyGuardian.

    Construction is cheap and side-effect free; subsystems that were not
    injected are created later by initialize().

    Args:
        config: Optional safety configuration. If None, loads from environment.
        audit_logger: Optional audit logger. If None, uses global instance.
        cost_controller: Optional cost controller. If None, creates default.
        rate_limiter: Optional rate limiter. If None, creates default.
        loop_detector: Optional loop detector. If None, creates default.
    """
    self._config = config or get_safety_config()
    self._audit_logger = audit_logger
    self._initialized = False
    self._lock = asyncio.Lock()

    # Core safety subsystems (always initialized)
    # None here means "create a default in initialize()".
    self._cost_controller: CostController | None = cost_controller
    self._rate_limiter: RateLimiter | None = rate_limiter
    self._loop_detector: LoopDetector | None = loop_detector

    # Optional subsystems (will be initialized when available)
    self._permission_manager: Any = None
    self._hitl_manager: Any = None
    self._rollback_manager: Any = None
    self._content_filter: Any = None
    self._sandbox_executor: Any = None
    self._emergency_controls: Any = None

    # Policy cache, keyed by autonomy-level value (see _get_policy).
    self._policies: dict[str, SafetyPolicy] = {}
    self._default_policy: SafetyPolicy | None = None
|
||||
|
||||
@property
def is_initialized(self) -> bool:
    """Check if the guardian is initialized (set by initialize()/shutdown())."""
    return self._initialized
|
||||
|
||||
@property
def cost_controller(self) -> CostController | None:
    """Get the cost controller instance (None until injected or initialize() runs)."""
    return self._cost_controller
|
||||
|
||||
@property
def rate_limiter(self) -> RateLimiter | None:
    """Get the rate limiter instance (None until injected or initialize() runs)."""
    return self._rate_limiter
|
||||
|
||||
@property
def loop_detector(self) -> LoopDetector | None:
    """Get the loop detector instance (None until injected or initialize() runs)."""
    return self._loop_detector
|
||||
|
||||
async def initialize(self) -> None:
    """Initialize the SafetyGuardian and all subsystems.

    Idempotent: a second call logs a warning and returns. The lock makes
    concurrent first calls safe.
    """
    async with self._lock:
        if self._initialized:
            logger.warning("SafetyGuardian already initialized")
            return

        logger.info("Initializing SafetyGuardian")

        # Get audit logger (fall back to the global instance when not injected)
        if self._audit_logger is None:
            self._audit_logger = await get_audit_logger()

        # Initialize core safety subsystems — only those not supplied
        # via the constructor are replaced with defaults.
        if self._cost_controller is None:
            self._cost_controller = CostController()
            logger.debug("Initialized CostController")

        if self._rate_limiter is None:
            self._rate_limiter = RateLimiter()
            logger.debug("Initialized RateLimiter")

        if self._loop_detector is None:
            self._loop_detector = LoopDetector()
            logger.debug("Initialized LoopDetector")

        self._initialized = True
        logger.info(
            "SafetyGuardian initialized with CostController, RateLimiter, LoopDetector"
        )
|
||||
|
||||
async def shutdown(self) -> None:
    """Shutdown the SafetyGuardian and all subsystems.

    Idempotent: returns immediately when not initialized. Currently only
    flips the flag; subsystem teardown is a placeholder.
    """
    async with self._lock:
        if not self._initialized:
            return

        logger.info("Shutting down SafetyGuardian")

        # Shutdown subsystems
        # (Will be implemented as subsystems are added)

        self._initialized = False
        logger.info("SafetyGuardian shutdown complete")
|
||||
|
||||
async def validate(
    self,
    action: ActionRequest,
    policy: SafetyPolicy | None = None,
) -> GuardianResult:
    """
    Validate an action before execution.

    Runs all safety checks in order:
    1. Permission check
    2. Cost/budget check
    3. Rate limit check
    4. Loop detection
    5. HITL check (if required)
    6. Checkpoint creation (if destructive)

    The first failing check short-circuits the pipeline. Unknown errors
    fail closed in strict mode and fail open (with a warning) otherwise.

    Args:
        action: The action to validate
        policy: Optional policy override. If None, uses autonomy-level policy.

    Returns:
        GuardianResult with decision and details
    """
    # Lazy self-initialization so callers need not call initialize() first.
    if not self._initialized:
        await self.initialize()

    if not self._config.enabled:
        # Safety disabled - allow everything (NOT RECOMMENDED)
        logger.warning("Safety framework disabled - allowing action %s", action.id)
        return GuardianResult(
            action_id=action.id,
            allowed=True,
            decision=SafetyDecision.ALLOW,
            reasons=["Safety framework disabled"],
        )

    # Get policy for this action
    effective_policy = policy or self._get_policy(action)

    reasons: list[str] = []
    # Audit events accumulated across the pipeline; included in every
    # returned result so callers can trace the decision.
    audit_events = []

    try:
        # Log action request
        if self._audit_logger:
            event = await self._audit_logger.log(
                AuditEventType.ACTION_REQUESTED,
                agent_id=action.metadata.agent_id,
                action_id=action.id,
                project_id=action.metadata.project_id,
                session_id=action.metadata.session_id,
                details={
                    "action_type": action.action_type.value,
                    "tool_name": action.tool_name,
                    "resource": action.resource,
                },
                correlation_id=action.metadata.correlation_id,
            )
            audit_events.append(event)

        # 1. Permission check
        permission_result = await self._check_permissions(action, effective_policy)
        if permission_result.decision == SafetyDecision.DENY:
            return await self._create_denial_result(
                action, permission_result.reasons, audit_events
            )

        # 2. Cost/budget check
        budget_result = await self._check_budget(action, effective_policy)
        if budget_result.decision == SafetyDecision.DENY:
            return await self._create_denial_result(
                action, budget_result.reasons, audit_events
            )

        # 3. Rate limit check — can DENY outright or DELAY with a retry hint.
        rate_result = await self._check_rate_limit(action, effective_policy)
        if rate_result.decision == SafetyDecision.DENY:
            return await self._create_denial_result(
                action,
                rate_result.reasons,
                audit_events,
                retry_after=rate_result.retry_after_seconds,
            )
        if rate_result.decision == SafetyDecision.DELAY:
            # Return delay decision
            return GuardianResult(
                action_id=action.id,
                allowed=False,
                decision=SafetyDecision.DELAY,
                reasons=rate_result.reasons,
                retry_after_seconds=rate_result.retry_after_seconds,
                audit_events=audit_events,
            )

        # 4. Loop detection
        loop_result = await self._check_loops(action, effective_policy)
        if loop_result.decision == SafetyDecision.DENY:
            return await self._create_denial_result(
                action, loop_result.reasons, audit_events
            )

        # 5. HITL check — not a denial; caller must obtain human approval.
        hitl_result = await self._check_hitl(action, effective_policy)
        if hitl_result.decision == SafetyDecision.REQUIRE_APPROVAL:
            return GuardianResult(
                action_id=action.id,
                allowed=False,
                decision=SafetyDecision.REQUIRE_APPROVAL,
                reasons=hitl_result.reasons,
                approval_id=hitl_result.approval_id,
                audit_events=audit_events,
            )

        # 6. Create checkpoint if destructive
        checkpoint_id = None
        if action.is_destructive and self._config.auto_checkpoint_destructive:
            checkpoint_id = await self._create_checkpoint(action)

        # All checks passed
        reasons.append("All safety checks passed")

        if self._audit_logger:
            event = await self._audit_logger.log_action_request(
                action, SafetyDecision.ALLOW, reasons
            )
            audit_events.append(event)

        return GuardianResult(
            action_id=action.id,
            allowed=True,
            decision=SafetyDecision.ALLOW,
            reasons=reasons,
            checkpoint_id=checkpoint_id,
            audit_events=audit_events,
        )

    except SafetyError as e:
        # Known safety error — treat as a denial with the error text.
        return await self._create_denial_result(action, [str(e)], audit_events)
    except Exception as e:
        # Unknown error - fail closed in strict mode
        logger.error("Unexpected error in safety validation: %s", e)
        if self._config.strict_mode:
            return await self._create_denial_result(
                action,
                [f"Safety validation error: {e}"],
                audit_events,
            )
        else:
            # Non-strict mode - allow with warning
            logger.warning("Non-strict mode: allowing action despite error")
            return GuardianResult(
                action_id=action.id,
                allowed=True,
                decision=SafetyDecision.ALLOW,
                reasons=["Allowed despite validation error (non-strict mode)"],
                audit_events=audit_events,
            )
|
||||
|
||||
async def record_execution(
    self,
    action: ActionRequest,
    result: ActionResult,
) -> None:
    """
    Record action execution result for auditing and tracking.

    Best-effort: each subsystem update is wrapped in its own try/except
    so a failure in one tracker never prevents the others from updating.

    Args:
        action: The executed action
        result: The execution result
    """
    if self._audit_logger:
        await self._audit_logger.log_action_executed(
            action,
            success=result.success,
            execution_time_ms=result.execution_time_ms,
            error=result.error,
        )

    # Update cost tracking
    if self._cost_controller:
        try:
            # Use explicit None check - 0 is a valid cost value
            # Prefer actual costs; fall back to the pre-execution estimates.
            tokens = (
                result.actual_cost_tokens
                if result.actual_cost_tokens is not None
                else action.estimated_cost_tokens
            )
            cost_usd = (
                result.actual_cost_usd
                if result.actual_cost_usd is not None
                else action.estimated_cost_usd
            )
            await self._cost_controller.record_usage(
                agent_id=action.metadata.agent_id,
                session_id=action.metadata.session_id,
                tokens=tokens,
                cost_usd=cost_usd,
            )
        except Exception as e:
            logger.warning("Failed to record cost: %s", e)

    # Update rate limiter - consume slots for executed actions
    if self._rate_limiter:
        try:
            await self._rate_limiter.record_action(action)
        except Exception as e:
            logger.warning("Failed to record action in rate limiter: %s", e)

    # Update loop detection history
    if self._loop_detector:
        try:
            await self._loop_detector.record(action)
        except Exception as e:
            logger.warning("Failed to record action in loop detector: %s", e)
|
||||
|
||||
async def rollback(self, checkpoint_id: str) -> bool:
    """
    Rollback to a checkpoint.

    Delegates entirely to the rollback manager; returns False (with a
    warning) when that subsystem has not been wired in.

    Args:
        checkpoint_id: ID of the checkpoint to rollback to

    Returns:
        True if rollback succeeded
    """
    if self._rollback_manager is None:
        logger.warning("Rollback manager not available")
        return False

    # Delegate to rollback manager
    return await self._rollback_manager.rollback(checkpoint_id)
|
||||
|
||||
async def emergency_stop(
    self,
    stop_type: str = "kill",
    reason: str = "Manual emergency stop",
    triggered_by: str = "system",
) -> None:
    """
    Trigger emergency stop.

    Logs at CRITICAL, writes an audit record when an audit logger is
    configured, then delegates the actual stop to the emergency controls.

    Args:
        stop_type: Type of stop (kill, pause, lockdown)
        reason: Reason for the stop
        triggered_by: Who triggered the stop
    """
    # Always make the event loud in the logs, even if auditing is off.
    logger.critical(
        "Emergency stop triggered: type=%s, reason=%s, by=%s",
        stop_type,
        reason,
        triggered_by,
    )

    audit = self._audit_logger
    if audit:
        await audit.log_emergency_stop(
            stop_type=stop_type,
            triggered_by=triggered_by,
            reason=reason,
        )

    controls = self._emergency_controls
    if controls:
        await controls.execute_stop(stop_type)
|
||||
|
||||
def _get_policy(self, action: ActionRequest) -> SafetyPolicy:
    """Get the effective policy for an action.

    Policies are memoized per autonomy level in self._policies so the
    lookup helper is only invoked once per level.
    """
    level = action.metadata.autonomy_level
    key = level.value

    cached = self._policies.get(key)
    if cached is None:
        cached = get_policy_for_autonomy_level(level)
        self._policies[key] = cached

    return cached
|
||||
|
||||
async def _check_permissions(
    self,
    action: ActionRequest,
    policy: SafetyPolicy,
) -> GuardianResult:
    """Check if action is permitted.

    Order of evaluation: tool deny-list, tool allow-list (unless it
    contains the "*" wildcard), then resource deny patterns. The first
    match wins and produces a DENY result.
    """
    reasons: list[str] = []

    def _deny() -> GuardianResult:
        # Shared shape for every permission denial.
        return GuardianResult(
            action_id=action.id,
            allowed=False,
            decision=SafetyDecision.DENY,
            reasons=reasons,
        )

    tool = action.tool_name

    # Explicit denials take precedence over everything else.
    if tool:
        for pattern in policy.denied_tools:
            if self._matches_pattern(tool, pattern):
                reasons.append(
                    f"Tool '{action.tool_name}' denied by pattern '{pattern}'"
                )
                return _deny()

    # A "*" entry disables the allow-list check entirely.
    if tool and "*" not in policy.allowed_tools:
        permitted = any(
            self._matches_pattern(tool, pattern)
            for pattern in policy.allowed_tools
        )
        if not permitted:
            reasons.append(f"Tool '{action.tool_name}' not in allowed list")
            return _deny()

    # Resource (e.g. file path) deny patterns.
    if action.resource:
        for pattern in policy.denied_file_patterns:
            if self._matches_pattern(action.resource, pattern):
                reasons.append(
                    f"Resource '{action.resource}' denied by pattern '{pattern}'"
                )
                return _deny()

    return GuardianResult(
        action_id=action.id,
        allowed=True,
        decision=SafetyDecision.ALLOW,
        reasons=["Permission check passed"],
    )
|
||||
|
||||
async def _check_budget(
    self,
    action: ActionRequest,
    policy: SafetyPolicy,
) -> GuardianResult:
    """Check if action is within budget.

    Fails open (ALLOW) when no CostController is configured. When the
    controller reports insufficient budget, the session-scoped status is
    inspected first, then the agent's daily status, to build a specific
    denial message; otherwise a generic "Budget exceeded" DENY is
    returned.
    """
    if self._cost_controller is None:
        # Fail open: missing controller means the check cannot run.
        logger.warning("CostController not initialized - skipping budget check")
        return GuardianResult(
            action_id=action.id,
            allowed=True,
            decision=SafetyDecision.ALLOW,
            reasons=["Budget check skipped (controller not initialized)"],
        )

    agent_id = action.metadata.agent_id
    session_id = action.metadata.session_id

    try:
        # Check if we have budget for this action
        has_budget = await self._cost_controller.check_budget(
            agent_id=agent_id,
            session_id=session_id,
            estimated_tokens=action.estimated_cost_tokens,
            estimated_cost_usd=action.estimated_cost_usd,
        )

        if not has_budget:
            # Get current status for better error message
            if session_id:
                session_status = await self._cost_controller.get_status(
                    BudgetScope.SESSION, session_id
                )
                if session_status and session_status.is_exceeded:
                    return GuardianResult(
                        action_id=action.id,
                        allowed=False,
                        decision=SafetyDecision.DENY,
                        reasons=[
                            f"Session budget exceeded: {session_status.tokens_used}"
                            f"/{session_status.tokens_limit} tokens"
                        ],
                    )

            # Session budget was fine (or no session) - check the
            # agent's daily allowance next.
            agent_status = await self._cost_controller.get_status(
                BudgetScope.DAILY, agent_id
            )
            if agent_status and agent_status.is_exceeded:
                return GuardianResult(
                    action_id=action.id,
                    allowed=False,
                    decision=SafetyDecision.DENY,
                    reasons=[
                        f"Daily budget exceeded: {agent_status.tokens_used}"
                        f"/{agent_status.tokens_limit} tokens"
                    ],
                )

            # Generic budget exceeded
            return GuardianResult(
                action_id=action.id,
                allowed=False,
                decision=SafetyDecision.DENY,
                reasons=["Budget exceeded"],
            )

        return GuardianResult(
            action_id=action.id,
            allowed=True,
            decision=SafetyDecision.ALLOW,
            reasons=["Budget check passed"],
        )

    except BudgetExceededError as e:
        # The controller may raise instead of returning False; translate
        # to a DENY result so callers get a uniform GuardianResult.
        return GuardianResult(
            action_id=action.id,
            allowed=False,
            decision=SafetyDecision.DENY,
            reasons=[str(e)],
        )
|
||||
|
||||
async def _check_rate_limit(
    self,
    action: ActionRequest,
    policy: SafetyPolicy,
) -> GuardianResult:
    """Check if action is within rate limits.

    Fails open (ALLOW) when no RateLimiter is configured. When limited,
    a short retry window (0 < retry <= 5s) yields a soft DELAY decision
    with a retry hint; anything longer is a hard DENY.
    """
    if self._rate_limiter is None:
        # Fail open: missing limiter means the check cannot run.
        logger.warning("RateLimiter not initialized - skipping rate limit check")
        return GuardianResult(
            action_id=action.id,
            allowed=True,
            decision=SafetyDecision.ALLOW,
            reasons=["Rate limit check skipped (limiter not initialized)"],
        )

    try:
        # Check all applicable rate limits for this action
        allowed, statuses = await self._rate_limiter.check_action(action)

        if not allowed:
            # Find the first exceeded limit for the error message
            # (falls back to the first status if none is flagged).
            exceeded_status = next(
                (s for s in statuses if s.is_limited),
                statuses[0] if statuses else None,
            )

            if exceeded_status:
                retry_after = exceeded_status.retry_after_seconds

                # Determine if this is a soft limit (delay) or hard limit (deny)
                if retry_after > 0 and retry_after <= 5.0:
                    # Short wait - suggest delay
                    return GuardianResult(
                        action_id=action.id,
                        allowed=False,
                        decision=SafetyDecision.DELAY,
                        reasons=[
                            f"Rate limit '{exceeded_status.name}' exceeded. "
                            f"Current: {exceeded_status.current_count}/{exceeded_status.limit}"
                        ],
                        retry_after_seconds=retry_after,
                    )
                else:
                    # Hard deny
                    return GuardianResult(
                        action_id=action.id,
                        allowed=False,
                        decision=SafetyDecision.DENY,
                        reasons=[
                            f"Rate limit '{exceeded_status.name}' exceeded. "
                            f"Current: {exceeded_status.current_count}/{exceeded_status.limit}. "
                            f"Retry after {retry_after:.1f}s"
                        ],
                        retry_after_seconds=retry_after,
                    )

            # No status to explain the denial - generic message.
            return GuardianResult(
                action_id=action.id,
                allowed=False,
                decision=SafetyDecision.DENY,
                reasons=["Rate limit exceeded"],
            )

        return GuardianResult(
            action_id=action.id,
            allowed=True,
            decision=SafetyDecision.ALLOW,
            reasons=["Rate limit check passed"],
        )

    except RateLimitExceededError as e:
        # The limiter may raise instead of returning False; translate to
        # a DENY result, preserving the retry hint from the exception.
        return GuardianResult(
            action_id=action.id,
            allowed=False,
            decision=SafetyDecision.DENY,
            reasons=[str(e)],
            retry_after_seconds=e.retry_after_seconds,
        )
|
||||
|
||||
async def _check_loops(
    self,
    action: ActionRequest,
    policy: SafetyPolicy,
) -> GuardianResult:
    """Check for action loops.

    Fails open (ALLOW) when no LoopDetector is configured. On a detected
    loop, the denial reasons include loop-breaking suggestions.
    """
    detector = self._loop_detector
    if detector is None:
        logger.warning("LoopDetector not initialized - skipping loop check")
        return GuardianResult(
            action_id=action.id,
            allowed=True,
            decision=SafetyDecision.ALLOW,
            reasons=["Loop check skipped (detector not initialized)"],
        )

    try:
        is_loop, loop_type = await detector.check(action)

        if not is_loop:
            return GuardianResult(
                action_id=action.id,
                allowed=True,
                decision=SafetyDecision.ALLOW,
                reasons=["Loop check passed"],
            )

        # Loop found: ask the breaker for alternatives so the agent
        # gets actionable guidance alongside the denial.
        from .loops.detector import LoopBreaker

        suggestions = await LoopBreaker.suggest_alternatives(
            action, loop_type or "unknown"
        )

        return GuardianResult(
            action_id=action.id,
            allowed=False,
            decision=SafetyDecision.DENY,
            reasons=[
                f"Loop detected: {loop_type}",
                *suggestions,
            ],
        )

    except LoopDetectedError as e:
        return GuardianResult(
            action_id=action.id,
            allowed=False,
            decision=SafetyDecision.DENY,
            reasons=[str(e)],
        )
|
||||
|
||||
async def _check_hitl(
    self,
    action: ActionRequest,
    policy: SafetyPolicy,
) -> GuardianResult:
    """Check if human approval is required.

    Skipped entirely when HITL is disabled in the config. A pattern in
    policy.require_approval_for triggers approval when it is "*", or
    matches the tool name, or matches the action type value.
    """
    if not self._config.hitl_enabled:
        return GuardianResult(
            action_id=action.id,
            allowed=True,
            decision=SafetyDecision.ALLOW,
            reasons=["HITL disabled"],
        )

    def _pattern_triggers(pattern: str) -> bool:
        # A bare wildcard means every action needs sign-off.
        if pattern == "*":
            return True
        if action.tool_name and self._matches_pattern(action.tool_name, pattern):
            return True
        return bool(
            action.action_type.value
            and self._matches_pattern(action.action_type.value, pattern)
        )

    requires_approval = any(
        _pattern_triggers(pattern) for pattern in policy.require_approval_for
    )

    if requires_approval:
        # TODO: Create approval request with HITLManager
        return GuardianResult(
            action_id=action.id,
            allowed=False,
            decision=SafetyDecision.REQUIRE_APPROVAL,
            reasons=["Action requires human approval"],
            approval_id=None,  # Will be set by HITLManager
        )

    return GuardianResult(
        action_id=action.id,
        allowed=True,
        decision=SafetyDecision.ALLOW,
        reasons=["No approval required"],
    )
|
||||
|
||||
async def _create_checkpoint(self, action: ActionRequest) -> str | None:
|
||||
"""Create a checkpoint before destructive action."""
|
||||
if self._rollback_manager is None:
|
||||
logger.warning("Rollback manager not available - skipping checkpoint")
|
||||
return None
|
||||
|
||||
# TODO: Implement with RollbackManager
|
||||
return None
|
||||
|
||||
async def _create_denial_result(
    self,
    action: ActionRequest,
    reasons: list[str],
    audit_events: list[Any],
    retry_after: float | None = None,
) -> GuardianResult:
    """Create a denial result with audit logging.

    Appends the audit event (when an audit logger is configured) to the
    caller-supplied audit_events list and embeds that list in the result.
    """
    audit = self._audit_logger
    if audit:
        # Record the denial before handing the result back.
        event = await audit.log_action_request(
            action, SafetyDecision.DENY, reasons
        )
        audit_events.append(event)

    return GuardianResult(
        action_id=action.id,
        allowed=False,
        decision=SafetyDecision.DENY,
        reasons=reasons,
        retry_after_seconds=retry_after,
        audit_events=audit_events,
    )
|
||||
|
||||
def _matches_pattern(self, value: str, pattern: str) -> bool:
|
||||
"""Check if value matches a pattern (supports * wildcard)."""
|
||||
if pattern == "*":
|
||||
return True
|
||||
|
||||
if "*" not in pattern:
|
||||
return value == pattern
|
||||
|
||||
# Simple wildcard matching
|
||||
if pattern.startswith("*") and pattern.endswith("*"):
|
||||
return pattern[1:-1] in value
|
||||
elif pattern.startswith("*"):
|
||||
return value.endswith(pattern[1:])
|
||||
elif pattern.endswith("*"):
|
||||
return value.startswith(pattern[:-1])
|
||||
else:
|
||||
# Pattern like "foo*bar"
|
||||
parts = pattern.split("*")
|
||||
if len(parts) == 2:
|
||||
return value.startswith(parts[0]) and value.endswith(parts[1])
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# Singleton instance
# Module-level singleton guarded by _guardian_lock. Access it only through
# get_safety_guardian() / shutdown_safety_guardian() / reset_safety_guardian().
_guardian_instance: SafetyGuardian | None = None
_guardian_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def get_safety_guardian() -> SafetyGuardian:
    """Get the global SafetyGuardian instance.

    Lazily constructs and initializes the singleton under the module
    lock so concurrent first calls create it exactly once.
    """
    global _guardian_instance

    async with _guardian_lock:
        if _guardian_instance is None:
            _guardian_instance = SafetyGuardian()
            await _guardian_instance.initialize()
        return _guardian_instance
|
||||
|
||||
|
||||
async def shutdown_safety_guardian() -> None:
    """Shutdown the global SafetyGuardian.

    No-op when no instance exists. The instance is cleared only after a
    successful shutdown so a failed shutdown leaves it inspectable.
    """
    global _guardian_instance

    async with _guardian_lock:
        if _guardian_instance is None:
            return
        await _guardian_instance.shutdown()
        _guardian_instance = None
|
||||
|
||||
|
||||
async def reset_safety_guardian() -> None:
    """
    Reset the SafetyGuardian (for testing).

    This is an async function to properly acquire the guardian lock
    and avoid race conditions with get_safety_guardian(). Unlike
    shutdown_safety_guardian(), shutdown errors are swallowed and the
    instance is always cleared.
    """
    global _guardian_instance

    async with _guardian_lock:
        guardian = _guardian_instance
        if guardian is None:
            return
        try:
            await guardian.shutdown()
        except Exception:  # noqa: S110
            # Best-effort teardown; failures are irrelevant in test cleanup.
            pass
        _guardian_instance = None
|
||||
5
backend/app/services/safety/hitl/__init__.py
Normal file
5
backend/app/services/safety/hitl/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Human-in-the-Loop approval workflows."""
|
||||
|
||||
from .manager import ApprovalQueue, HITLManager
|
||||
|
||||
__all__ = ["ApprovalQueue", "HITLManager"]
|
||||
449
backend/app/services/safety/hitl/manager.py
Normal file
449
backend/app/services/safety/hitl/manager.py
Normal file
@@ -0,0 +1,449 @@
|
||||
"""
|
||||
Human-in-the-Loop (HITL) Manager
|
||||
|
||||
Manages approval workflows for actions requiring human oversight.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from ..config import get_safety_config
|
||||
from ..exceptions import (
|
||||
ApprovalDeniedError,
|
||||
ApprovalRequiredError,
|
||||
ApprovalTimeoutError,
|
||||
)
|
||||
from ..models import (
|
||||
ActionRequest,
|
||||
ApprovalRequest,
|
||||
ApprovalResponse,
|
||||
ApprovalStatus,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ApprovalQueue:
    """Queue for pending approval requests.

    Pending requests are keyed by ID; a per-request asyncio.Event lets
    callers block in wait_for_response() until a decision is recorded.

    Fix: waiter Events are now removed from the registry as soon as a
    final response is recorded (complete/cancel/timeout). Previously
    every entry in _waiters lived forever, leaking one Event per request.
    """

    def __init__(self) -> None:
        # request_id -> request awaiting a decision
        self._pending: dict[str, ApprovalRequest] = {}
        # request_id -> final response (approved/denied/timeout/cancelled)
        # NOTE(review): completed responses are retained indefinitely so
        # late callers can still read the outcome; add a TTL if volume grows.
        self._completed: dict[str, ApprovalResponse] = {}
        # request_id -> event signalled when a response is recorded
        self._waiters: dict[str, asyncio.Event] = {}
        self._lock = asyncio.Lock()

    def _finish_locked(self, request_id: str, response: ApprovalResponse) -> None:
        """Record a final response and wake waiters (caller holds the lock).

        Pops the waiter entry so the per-request Event does not leak;
        any coroutine already blocked in wait_for_response() holds its
        own reference to the Event and is still woken by set().
        """
        self._completed[request_id] = response
        waiter = self._waiters.pop(request_id, None)
        if waiter is not None:
            waiter.set()

    async def add(self, request: ApprovalRequest) -> None:
        """Add an approval request to the queue."""
        async with self._lock:
            self._pending[request.id] = request
            self._waiters[request.id] = asyncio.Event()

    async def get_pending(self, request_id: str) -> ApprovalRequest | None:
        """Get a pending request by ID."""
        async with self._lock:
            return self._pending.get(request_id)

    async def complete(self, response: ApprovalResponse) -> bool:
        """Complete an approval request.

        Returns:
            True if the request was pending; False if unknown or already
            decided.
        """
        async with self._lock:
            if response.request_id not in self._pending:
                return False

            del self._pending[response.request_id]
            self._finish_locked(response.request_id, response)
            return True

    async def wait_for_response(
        self,
        request_id: str,
        timeout_seconds: float,
    ) -> ApprovalResponse | None:
        """Wait for a response to an approval request.

        Returns:
            The response, or None if the timeout expired first.
        """
        async with self._lock:
            waiter = self._waiters.get(request_id)
            if not waiter:
                # Already decided (or never queued) - return what we have.
                return self._completed.get(request_id)

        try:
            await asyncio.wait_for(waiter.wait(), timeout=timeout_seconds)
        except TimeoutError:
            return None

        async with self._lock:
            return self._completed.get(request_id)

    async def list_pending(self) -> list[ApprovalRequest]:
        """List all pending requests."""
        async with self._lock:
            return list(self._pending.values())

    async def cancel(self, request_id: str) -> bool:
        """Cancel a pending request."""
        async with self._lock:
            if request_id not in self._pending:
                return False

            del self._pending[request_id]

            # Create cancelled response
            response = ApprovalResponse(
                request_id=request_id,
                status=ApprovalStatus.CANCELLED,
                reason="Cancelled",
            )
            self._finish_locked(request_id, response)
            return True

    async def cleanup_expired(self) -> int:
        """Clean up expired requests.

        Returns:
            The number of requests timed out.
        """
        now = datetime.utcnow()

        async with self._lock:
            to_timeout = [
                request_id
                for request_id, request in self._pending.items()
                if request.expires_at and request.expires_at < now
            ]

        count = 0
        # Re-acquire the lock per request so a large sweep does not starve
        # other queue operations; re-check membership in case of races.
        for request_id in to_timeout:
            async with self._lock:
                if request_id in self._pending:
                    del self._pending[request_id]
                    self._finish_locked(
                        request_id,
                        ApprovalResponse(
                            request_id=request_id,
                            status=ApprovalStatus.TIMEOUT,
                            reason="Request timed out",
                        ),
                    )
                    count += 1

        return count
|
||||
|
||||
|
||||
class HITLManager:
    """
    Manages Human-in-the-Loop approval workflows.

    Features:
    - Approval request queue
    - Configurable timeout handling (default deny)
    - Approval delegation
    - Batch approval for similar actions
    - Approval with modifications
    - Notification channels
    """

    def __init__(
        self,
        default_timeout: int | None = None,
    ) -> None:
        """
        Initialize the HITLManager.

        Args:
            default_timeout: Default timeout for approval requests in seconds
                (falls back to the safety config's hitl_default_timeout)
        """
        config = get_safety_config()

        self._default_timeout = default_timeout or config.hitl_default_timeout
        self._queue = ApprovalQueue()
        # Callbacks (sync or async) invoked on approval lifecycle events.
        self._notification_handlers: list[Callable[..., Any]] = []
        self._running = False
        # Background task that times out expired requests.
        self._cleanup_task: asyncio.Task[None] | None = None

    async def start(self) -> None:
        """Start the HITL manager background tasks (idempotent)."""
        if self._running:
            return

        self._running = True
        self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
        logger.info("HITL Manager started")

    async def stop(self) -> None:
        """Stop the HITL manager and await cleanup-task cancellation."""
        self._running = False

        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass

        logger.info("HITL Manager stopped")

    async def request_approval(
        self,
        action: ActionRequest,
        reason: str,
        timeout_seconds: int | None = None,
        urgency: str = "normal",
        context: dict[str, Any] | None = None,
    ) -> ApprovalRequest:
        """
        Create an approval request for an action.

        Args:
            action: The action requiring approval
            reason: Why approval is required
            timeout_seconds: Timeout for this request
            urgency: Urgency level (low, normal, high, critical)
            context: Additional context for the approver

        Returns:
            The created approval request
        """
        timeout = timeout_seconds or self._default_timeout
        # Expiry is enforced by the periodic cleanup task (naive UTC times).
        expires_at = datetime.utcnow() + timedelta(seconds=timeout)

        request = ApprovalRequest(
            id=str(uuid4()),
            action=action,
            reason=reason,
            urgency=urgency,
            timeout_seconds=timeout,
            expires_at=expires_at,
            context=context or {},
        )

        await self._queue.add(request)

        # Notify handlers
        await self._notify_handlers("approval_requested", request)

        logger.info(
            "Approval requested: %s for action %s (timeout: %ds)",
            request.id,
            action.id,
            timeout,
        )

        return request

    async def wait_for_approval(
        self,
        request_id: str,
        timeout_seconds: int | None = None,
    ) -> ApprovalResponse:
        """
        Wait for an approval decision.

        Args:
            request_id: ID of the approval request
            timeout_seconds: Override timeout

        Returns:
            The approval response (only APPROVED responses are returned)

        Raises:
            ApprovalRequiredError: If the request is not pending (unknown ID)
            ApprovalTimeoutError: If timeout expires
            ApprovalDeniedError: If approval is denied or cancelled
        """
        request = await self._queue.get_pending(request_id)
        if not request:
            raise ApprovalRequiredError(
                f"Approval request not found: {request_id}",
                approval_id=request_id,
            )

        timeout = timeout_seconds or request.timeout_seconds or self._default_timeout
        response = await self._queue.wait_for_response(request_id, timeout)

        if response is None:
            # Timeout - default deny
            response = ApprovalResponse(
                request_id=request_id,
                status=ApprovalStatus.TIMEOUT,
                reason="Request timed out (default deny)",
            )
            await self._queue.complete(response)

            raise ApprovalTimeoutError(
                "Approval request timed out",
                approval_id=request_id,
                timeout_seconds=timeout,
            )

        if response.status == ApprovalStatus.DENIED:
            raise ApprovalDeniedError(
                response.reason or "Approval denied",
                approval_id=request_id,
                denied_by=response.decided_by,
                denial_reason=response.reason,
            )

        if response.status == ApprovalStatus.TIMEOUT:
            # Cleanup task may have timed the request out before we woke.
            raise ApprovalTimeoutError(
                "Approval request timed out",
                approval_id=request_id,
                timeout_seconds=timeout,
            )

        if response.status == ApprovalStatus.CANCELLED:
            # Cancellation surfaces as a denial to the waiting caller.
            raise ApprovalDeniedError(
                "Approval request was cancelled",
                approval_id=request_id,
                denial_reason="Cancelled",
            )

        return response

    async def approve(
        self,
        request_id: str,
        decided_by: str,
        reason: str | None = None,
        modifications: dict[str, Any] | None = None,
    ) -> bool:
        """
        Approve a pending request.

        Args:
            request_id: ID of the approval request
            decided_by: Who approved
            reason: Optional approval reason
            modifications: Optional modifications to the action

        Returns:
            True if approval was recorded
        """
        response = ApprovalResponse(
            request_id=request_id,
            status=ApprovalStatus.APPROVED,
            decided_by=decided_by,
            reason=reason,
            modifications=modifications,
        )

        success = await self._queue.complete(response)

        if success:
            logger.info(
                "Approval granted: %s by %s",
                request_id,
                decided_by,
            )
            await self._notify_handlers("approval_granted", response)

        return success

    async def deny(
        self,
        request_id: str,
        decided_by: str,
        reason: str | None = None,
    ) -> bool:
        """
        Deny a pending request.

        Args:
            request_id: ID of the approval request
            decided_by: Who denied
            reason: Denial reason

        Returns:
            True if denial was recorded
        """
        response = ApprovalResponse(
            request_id=request_id,
            status=ApprovalStatus.DENIED,
            decided_by=decided_by,
            reason=reason,
        )

        success = await self._queue.complete(response)

        if success:
            logger.info(
                "Approval denied: %s by %s - %s",
                request_id,
                decided_by,
                reason,
            )
            await self._notify_handlers("approval_denied", response)

        return success

    async def cancel(self, request_id: str) -> bool:
        """
        Cancel a pending request.

        Args:
            request_id: ID of the approval request

        Returns:
            True if request was cancelled
        """
        success = await self._queue.cancel(request_id)

        if success:
            logger.info("Approval request cancelled: %s", request_id)

        return success

    async def list_pending(self) -> list[ApprovalRequest]:
        """List all pending approval requests."""
        return await self._queue.list_pending()

    async def get_request(self, request_id: str) -> ApprovalRequest | None:
        """Get an approval request by ID."""
        return await self._queue.get_pending(request_id)

    def add_notification_handler(
        self,
        handler: Callable[..., Any],
    ) -> None:
        """Add a notification handler (called as handler(event_type, data))."""
        self._notification_handlers.append(handler)

    def remove_notification_handler(
        self,
        handler: Callable[..., Any],
    ) -> None:
        """Remove a notification handler (no-op if not registered)."""
        if handler in self._notification_handlers:
            self._notification_handlers.remove(handler)

    async def _notify_handlers(
        self,
        event_type: str,
        data: Any,
    ) -> None:
        """Notify all handlers of an event.

        Handler failures are logged and swallowed so one bad handler
        cannot break the approval workflow.
        """
        for handler in self._notification_handlers:
            try:
                if asyncio.iscoroutinefunction(handler):
                    await handler(event_type, data)
                else:
                    handler(event_type, data)
            except Exception as e:
                logger.error("Error in notification handler: %s", e)

    async def _periodic_cleanup(self) -> None:
        """Background task for cleaning up expired requests."""
        while self._running:
            try:
                await asyncio.sleep(30)  # Check every 30 seconds
                count = await self._queue.cleanup_expired()
                if count:
                    logger.debug("Cleaned up %d expired approval requests", count)
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the loop alive across transient failures.
                logger.error("Error in approval cleanup: %s", e)
|
||||
15
backend/app/services/safety/limits/__init__.py
Normal file
15
backend/app/services/safety/limits/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
Rate Limiting Module
|
||||
|
||||
Sliding window rate limiting for agent operations.
|
||||
"""
|
||||
|
||||
from .limiter import (
|
||||
RateLimiter,
|
||||
SlidingWindowCounter,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"RateLimiter",
|
||||
"SlidingWindowCounter",
|
||||
]
|
||||
396
backend/app/services/safety/limits/limiter.py
Normal file
396
backend/app/services/safety/limits/limiter.py
Normal file
@@ -0,0 +1,396 @@
|
||||
"""
|
||||
Rate Limiter
|
||||
|
||||
Sliding window rate limiting for agent operations.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections import deque
|
||||
|
||||
from ..config import get_safety_config
|
||||
from ..exceptions import RateLimitExceededError
|
||||
from ..models import (
|
||||
ActionRequest,
|
||||
RateLimitConfig,
|
||||
RateLimitStatus,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SlidingWindowCounter:
|
||||
"""Sliding window counter for rate limiting."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
limit: int,
|
||||
window_seconds: int,
|
||||
burst_limit: int | None = None,
|
||||
) -> None:
|
||||
self.limit = limit
|
||||
self.window_seconds = window_seconds
|
||||
self.burst_limit = burst_limit or limit
|
||||
self._timestamps: deque[float] = deque()
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def try_acquire(self) -> tuple[bool, float]:
|
||||
"""
|
||||
Try to acquire a slot.
|
||||
|
||||
Returns:
|
||||
Tuple of (allowed, retry_after_seconds)
|
||||
"""
|
||||
now = time.time()
|
||||
window_start = now - self.window_seconds
|
||||
|
||||
async with self._lock:
|
||||
# Remove expired entries
|
||||
while self._timestamps and self._timestamps[0] < window_start:
|
||||
self._timestamps.popleft()
|
||||
|
||||
current_count = len(self._timestamps)
|
||||
|
||||
# Check burst limit (instant check)
|
||||
if current_count >= self.burst_limit:
|
||||
# Calculate retry time
|
||||
oldest = self._timestamps[0] if self._timestamps else now
|
||||
retry_after = oldest + self.window_seconds - now
|
||||
return False, max(0, retry_after)
|
||||
|
||||
# Check window limit
|
||||
if current_count >= self.limit:
|
||||
oldest = self._timestamps[0] if self._timestamps else now
|
||||
retry_after = oldest + self.window_seconds - now
|
||||
return False, max(0, retry_after)
|
||||
|
||||
# Allow and record
|
||||
self._timestamps.append(now)
|
||||
return True, 0.0
|
||||
|
||||
async def get_status(self) -> tuple[int, int, float]:
|
||||
"""
|
||||
Get current status.
|
||||
|
||||
Returns:
|
||||
Tuple of (current_count, remaining, reset_in_seconds)
|
||||
"""
|
||||
now = time.time()
|
||||
window_start = now - self.window_seconds
|
||||
|
||||
async with self._lock:
|
||||
# Remove expired entries
|
||||
while self._timestamps and self._timestamps[0] < window_start:
|
||||
self._timestamps.popleft()
|
||||
|
||||
current_count = len(self._timestamps)
|
||||
remaining = max(0, self.limit - current_count)
|
||||
|
||||
if self._timestamps:
|
||||
reset_in = self._timestamps[0] + self.window_seconds - now
|
||||
else:
|
||||
reset_in = 0.0
|
||||
|
||||
return current_count, remaining, max(0, reset_in)
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""
|
||||
Rate limiter for agent operations.
|
||||
|
||||
Features:
|
||||
- Per-tool rate limits
|
||||
- Per-agent rate limits
|
||||
- Per-resource rate limits
|
||||
- Sliding window implementation
|
||||
- Burst allowance with recovery
|
||||
- Slowdown before hard block
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the RateLimiter."""
|
||||
config = get_safety_config()
|
||||
|
||||
self._configs: dict[str, RateLimitConfig] = {}
|
||||
self._counters: dict[str, SlidingWindowCounter] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
# Default rate limits
|
||||
self._default_limits = {
|
||||
"actions": RateLimitConfig(
|
||||
name="actions",
|
||||
limit=config.default_actions_per_minute,
|
||||
window_seconds=60,
|
||||
),
|
||||
"llm_calls": RateLimitConfig(
|
||||
name="llm_calls",
|
||||
limit=config.default_llm_calls_per_minute,
|
||||
window_seconds=60,
|
||||
),
|
||||
"file_ops": RateLimitConfig(
|
||||
name="file_ops",
|
||||
limit=config.default_file_ops_per_minute,
|
||||
window_seconds=60,
|
||||
),
|
||||
}
|
||||
|
||||
def configure(self, config: RateLimitConfig) -> None:
|
||||
"""
|
||||
Configure a rate limit.
|
||||
|
||||
Args:
|
||||
config: Rate limit configuration
|
||||
"""
|
||||
self._configs[config.name] = config
|
||||
logger.debug(
|
||||
"Configured rate limit: %s = %d/%ds",
|
||||
config.name,
|
||||
config.limit,
|
||||
config.window_seconds,
|
||||
)
|
||||
|
||||
async def check(
|
||||
self,
|
||||
limit_name: str,
|
||||
key: str,
|
||||
) -> RateLimitStatus:
|
||||
"""
|
||||
Check rate limit without consuming a slot.
|
||||
|
||||
Args:
|
||||
limit_name: Name of the rate limit
|
||||
key: Key for tracking (e.g., agent_id)
|
||||
|
||||
Returns:
|
||||
Rate limit status
|
||||
"""
|
||||
counter = await self._get_counter(limit_name, key)
|
||||
config = self._get_config(limit_name)
|
||||
|
||||
current, remaining, reset_in = await counter.get_status()
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
return RateLimitStatus(
|
||||
name=limit_name,
|
||||
current_count=current,
|
||||
limit=config.limit,
|
||||
window_seconds=config.window_seconds,
|
||||
remaining=remaining,
|
||||
reset_at=datetime.utcnow() + timedelta(seconds=reset_in),
|
||||
is_limited=remaining <= 0,
|
||||
retry_after_seconds=reset_in if remaining <= 0 else 0.0,
|
||||
)
|
||||
|
||||
async def acquire(
|
||||
self,
|
||||
limit_name: str,
|
||||
key: str,
|
||||
) -> tuple[bool, RateLimitStatus]:
|
||||
"""
|
||||
Try to acquire a rate limit slot.
|
||||
|
||||
Args:
|
||||
limit_name: Name of the rate limit
|
||||
key: Key for tracking (e.g., agent_id)
|
||||
|
||||
Returns:
|
||||
Tuple of (allowed, status)
|
||||
"""
|
||||
counter = await self._get_counter(limit_name, key)
|
||||
config = self._get_config(limit_name)
|
||||
|
||||
allowed, retry_after = await counter.try_acquire()
|
||||
current, remaining, reset_in = await counter.get_status()
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
status = RateLimitStatus(
|
||||
name=limit_name,
|
||||
current_count=current,
|
||||
limit=config.limit,
|
||||
window_seconds=config.window_seconds,
|
||||
remaining=remaining,
|
||||
reset_at=datetime.utcnow() + timedelta(seconds=reset_in),
|
||||
is_limited=not allowed,
|
||||
retry_after_seconds=retry_after,
|
||||
)
|
||||
|
||||
return allowed, status
|
||||
|
||||
async def check_action(
|
||||
self,
|
||||
action: ActionRequest,
|
||||
) -> tuple[bool, list[RateLimitStatus]]:
|
||||
"""
|
||||
Check all applicable rate limits for an action WITHOUT consuming slots.
|
||||
|
||||
Use this during validation to check if action would be allowed.
|
||||
Call record_action() after successful execution to consume slots.
|
||||
|
||||
Args:
|
||||
action: The action to check
|
||||
|
||||
Returns:
|
||||
Tuple of (allowed, list of statuses)
|
||||
"""
|
||||
agent_id = action.metadata.agent_id
|
||||
statuses: list[RateLimitStatus] = []
|
||||
allowed = True
|
||||
|
||||
# Check general actions limit (read-only)
|
||||
actions_status = await self.check("actions", agent_id)
|
||||
statuses.append(actions_status)
|
||||
if actions_status.is_limited:
|
||||
allowed = False
|
||||
|
||||
# Check LLM-specific limit for LLM calls
|
||||
if action.action_type.value == "llm_call":
|
||||
llm_status = await self.check("llm_calls", agent_id)
|
||||
statuses.append(llm_status)
|
||||
if llm_status.is_limited:
|
||||
allowed = False
|
||||
|
||||
# Check file ops limit for file operations
|
||||
if action.action_type.value in {"file_read", "file_write", "file_delete"}:
|
||||
file_status = await self.check("file_ops", agent_id)
|
||||
statuses.append(file_status)
|
||||
if file_status.is_limited:
|
||||
allowed = False
|
||||
|
||||
return allowed, statuses
|
||||
|
||||
async def record_action(
|
||||
self,
|
||||
action: ActionRequest,
|
||||
) -> None:
|
||||
"""
|
||||
Record an action by consuming rate limit slots.
|
||||
|
||||
Call this AFTER successful execution to properly count the action.
|
||||
|
||||
Args:
|
||||
action: The executed action
|
||||
"""
|
||||
agent_id = action.metadata.agent_id
|
||||
|
||||
# Consume general actions slot
|
||||
await self.acquire("actions", agent_id)
|
||||
|
||||
# Consume LLM-specific slot for LLM calls
|
||||
if action.action_type.value == "llm_call":
|
||||
await self.acquire("llm_calls", agent_id)
|
||||
|
||||
# Consume file ops slot for file operations
|
||||
if action.action_type.value in {"file_read", "file_write", "file_delete"}:
|
||||
await self.acquire("file_ops", agent_id)
|
||||
|
||||
async def require(
|
||||
self,
|
||||
limit_name: str,
|
||||
key: str,
|
||||
) -> None:
|
||||
"""
|
||||
Require rate limit slot or raise exception.
|
||||
|
||||
Args:
|
||||
limit_name: Name of the rate limit
|
||||
key: Key for tracking
|
||||
|
||||
Raises:
|
||||
RateLimitExceededError: If rate limit exceeded
|
||||
"""
|
||||
allowed, status = await self.acquire(limit_name, key)
|
||||
if not allowed:
|
||||
raise RateLimitExceededError(
|
||||
f"Rate limit exceeded: {limit_name}",
|
||||
limit_type=limit_name,
|
||||
limit_value=status.limit,
|
||||
window_seconds=status.window_seconds,
|
||||
retry_after_seconds=status.retry_after_seconds,
|
||||
)
|
||||
|
||||
async def get_all_statuses(self, key: str) -> dict[str, RateLimitStatus]:
|
||||
"""
|
||||
Get status of all rate limits for a key.
|
||||
|
||||
Args:
|
||||
key: Key for tracking
|
||||
|
||||
Returns:
|
||||
Dict of limit name to status
|
||||
"""
|
||||
statuses = {}
|
||||
for name in self._default_limits:
|
||||
statuses[name] = await self.check(name, key)
|
||||
for name in self._configs:
|
||||
if name not in statuses:
|
||||
statuses[name] = await self.check(name, key)
|
||||
return statuses
|
||||
|
||||
async def reset(self, limit_name: str, key: str) -> bool:
|
||||
"""
|
||||
Reset a rate limit counter.
|
||||
|
||||
Args:
|
||||
limit_name: Name of the rate limit
|
||||
key: Key for tracking
|
||||
|
||||
Returns:
|
||||
True if counter was found and reset
|
||||
"""
|
||||
counter_key = f"{limit_name}:{key}"
|
||||
async with self._lock:
|
||||
if counter_key in self._counters:
|
||||
del self._counters[counter_key]
|
||||
return True
|
||||
return False
|
||||
|
||||
async def reset_all(self, key: str) -> int:
|
||||
"""
|
||||
Reset all rate limit counters for a key.
|
||||
|
||||
Args:
|
||||
key: Key for tracking
|
||||
|
||||
Returns:
|
||||
Number of counters reset
|
||||
"""
|
||||
count = 0
|
||||
async with self._lock:
|
||||
to_remove = [k for k in self._counters if k.endswith(f":{key}")]
|
||||
for k in to_remove:
|
||||
del self._counters[k]
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def _get_config(self, limit_name: str) -> RateLimitConfig:
|
||||
"""Get configuration for a rate limit."""
|
||||
if limit_name in self._configs:
|
||||
return self._configs[limit_name]
|
||||
if limit_name in self._default_limits:
|
||||
return self._default_limits[limit_name]
|
||||
# Return default
|
||||
return RateLimitConfig(
|
||||
name=limit_name,
|
||||
limit=60,
|
||||
window_seconds=60,
|
||||
)
|
||||
|
||||
async def _get_counter(
|
||||
self,
|
||||
limit_name: str,
|
||||
key: str,
|
||||
) -> SlidingWindowCounter:
|
||||
"""Get or create a counter."""
|
||||
counter_key = f"{limit_name}:{key}"
|
||||
config = self._get_config(limit_name)
|
||||
|
||||
async with self._lock:
|
||||
if counter_key not in self._counters:
|
||||
self._counters[counter_key] = SlidingWindowCounter(
|
||||
limit=config.limit,
|
||||
window_seconds=config.window_seconds,
|
||||
burst_limit=config.burst_limit,
|
||||
)
|
||||
return self._counters[counter_key]
|
||||
17
backend/app/services/safety/loops/__init__.py
Normal file
17
backend/app/services/safety/loops/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
Loop Detection Module
|
||||
|
||||
Detects and prevents action loops in agent behavior.
|
||||
"""
|
||||
|
||||
from .detector import (
|
||||
ActionSignature,
|
||||
LoopBreaker,
|
||||
LoopDetector,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ActionSignature",
|
||||
"LoopBreaker",
|
||||
"LoopDetector",
|
||||
]
|
||||
269
backend/app/services/safety/loops/detector.py
Normal file
269
backend/app/services/safety/loops/detector.py
Normal file
@@ -0,0 +1,269 @@
|
||||
"""
|
||||
Loop Detector
|
||||
|
||||
Detects and prevents action loops in agent behavior.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from collections import Counter, deque
|
||||
from typing import Any
|
||||
|
||||
from ..config import get_safety_config
|
||||
from ..exceptions import LoopDetectedError
|
||||
from ..models import ActionRequest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ActionSignature:
|
||||
"""Signature of an action for comparison."""
|
||||
|
||||
def __init__(self, action: ActionRequest) -> None:
|
||||
self.action_type = action.action_type.value
|
||||
self.tool_name = action.tool_name
|
||||
self.resource = action.resource
|
||||
self.args_hash = self._hash_args(action.arguments)
|
||||
|
||||
def _hash_args(self, args: dict[str, Any]) -> str:
|
||||
"""Create a hash of the arguments."""
|
||||
try:
|
||||
serialized = json.dumps(args, sort_keys=True, default=str)
|
||||
return hashlib.sha256(serialized.encode()).hexdigest()[:8]
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
def exact_key(self) -> str:
|
||||
"""Key for exact match detection."""
|
||||
return f"{self.action_type}:{self.tool_name}:{self.resource}:{self.args_hash}"
|
||||
|
||||
def semantic_key(self) -> str:
|
||||
"""Key for semantic (similar) match detection."""
|
||||
return f"{self.action_type}:{self.tool_name}:{self.resource}"
|
||||
|
||||
def type_key(self) -> str:
|
||||
"""Key for action type only."""
|
||||
return f"{self.action_type}"
|
||||
|
||||
|
||||
class LoopDetector:
|
||||
"""
|
||||
Detects action loops and repetitive behavior.
|
||||
|
||||
Loop Types:
|
||||
- Exact: Same action with same arguments
|
||||
- Semantic: Similar actions (same type/tool/resource, different args)
|
||||
- Oscillation: A→B→A→B patterns
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
history_size: int | None = None,
|
||||
max_exact_repetitions: int | None = None,
|
||||
max_semantic_repetitions: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the LoopDetector.
|
||||
|
||||
Args:
|
||||
history_size: Size of action history to track
|
||||
max_exact_repetitions: Max allowed exact repetitions
|
||||
max_semantic_repetitions: Max allowed semantic repetitions
|
||||
"""
|
||||
config = get_safety_config()
|
||||
|
||||
self._history_size = history_size or config.loop_history_size
|
||||
self._max_exact = max_exact_repetitions or config.max_repeated_actions
|
||||
self._max_semantic = max_semantic_repetitions or config.max_similar_actions
|
||||
|
||||
# Per-agent history
|
||||
self._histories: dict[str, deque[ActionSignature]] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def check(self, action: ActionRequest) -> tuple[bool, str | None]:
|
||||
"""
|
||||
Check if an action would create a loop.
|
||||
|
||||
Args:
|
||||
action: The action to check
|
||||
|
||||
Returns:
|
||||
Tuple of (is_loop, loop_type)
|
||||
"""
|
||||
agent_id = action.metadata.agent_id
|
||||
signature = ActionSignature(action)
|
||||
|
||||
async with self._lock:
|
||||
history = self._get_history(agent_id)
|
||||
|
||||
# Check exact repetition
|
||||
exact_key = signature.exact_key()
|
||||
exact_count = sum(1 for h in history if h.exact_key() == exact_key)
|
||||
if exact_count >= self._max_exact:
|
||||
return True, "exact"
|
||||
|
||||
# Check semantic repetition
|
||||
semantic_key = signature.semantic_key()
|
||||
semantic_count = sum(1 for h in history if h.semantic_key() == semantic_key)
|
||||
if semantic_count >= self._max_semantic:
|
||||
return True, "semantic"
|
||||
|
||||
# Check oscillation (A→B→A→B pattern)
|
||||
if len(history) >= 3:
|
||||
pattern = self._detect_oscillation(history, signature)
|
||||
if pattern:
|
||||
return True, "oscillation"
|
||||
|
||||
return False, None
|
||||
|
||||
async def check_and_raise(self, action: ActionRequest) -> None:
|
||||
"""
|
||||
Check for loops and raise if detected.
|
||||
|
||||
Args:
|
||||
action: The action to check
|
||||
|
||||
Raises:
|
||||
LoopDetectedError: If loop is detected
|
||||
"""
|
||||
is_loop, loop_type = await self.check(action)
|
||||
if is_loop:
|
||||
signature = ActionSignature(action)
|
||||
raise LoopDetectedError(
|
||||
f"Loop detected: {loop_type}",
|
||||
loop_type=loop_type or "unknown",
|
||||
repetition_count=self._max_exact
|
||||
if loop_type == "exact"
|
||||
else self._max_semantic,
|
||||
action_pattern=[signature.semantic_key()],
|
||||
agent_id=action.metadata.agent_id,
|
||||
action_id=action.id,
|
||||
)
|
||||
|
||||
async def record(self, action: ActionRequest) -> None:
|
||||
"""
|
||||
Record an action in history.
|
||||
|
||||
Args:
|
||||
action: The action to record
|
||||
"""
|
||||
agent_id = action.metadata.agent_id
|
||||
signature = ActionSignature(action)
|
||||
|
||||
async with self._lock:
|
||||
history = self._get_history(agent_id)
|
||||
history.append(signature)
|
||||
|
||||
async def clear_history(self, agent_id: str) -> None:
|
||||
"""
|
||||
Clear history for an agent.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent
|
||||
"""
|
||||
async with self._lock:
|
||||
if agent_id in self._histories:
|
||||
self._histories[agent_id].clear()
|
||||
|
||||
async def get_stats(self, agent_id: str) -> dict[str, Any]:
|
||||
"""
|
||||
Get loop detection stats for an agent.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent
|
||||
|
||||
Returns:
|
||||
Stats dictionary
|
||||
"""
|
||||
async with self._lock:
|
||||
history = self._get_history(agent_id)
|
||||
|
||||
# Count action types
|
||||
type_counts = Counter(h.type_key() for h in history)
|
||||
semantic_counts = Counter(h.semantic_key() for h in history)
|
||||
|
||||
return {
|
||||
"history_size": len(history),
|
||||
"max_history": self._history_size,
|
||||
"action_type_counts": dict(type_counts),
|
||||
"top_semantic_patterns": semantic_counts.most_common(5),
|
||||
}
|
||||
|
||||
def _get_history(self, agent_id: str) -> deque[ActionSignature]:
|
||||
"""Get or create history for an agent."""
|
||||
if agent_id not in self._histories:
|
||||
self._histories[agent_id] = deque(maxlen=self._history_size)
|
||||
return self._histories[agent_id]
|
||||
|
||||
def _detect_oscillation(
|
||||
self,
|
||||
history: deque[ActionSignature],
|
||||
current: ActionSignature,
|
||||
) -> bool:
|
||||
"""
|
||||
Detect A→B→A→B oscillation pattern.
|
||||
|
||||
Looks at last 4+ actions including current.
|
||||
"""
|
||||
if len(history) < 3:
|
||||
return False
|
||||
|
||||
# Get last 3 actions + current
|
||||
recent = [*list(history)[-3:], current]
|
||||
|
||||
# Check for A→B→A→B pattern
|
||||
if len(recent) >= 4:
|
||||
# Get semantic keys
|
||||
keys = [a.semantic_key() for a in recent[-4:]]
|
||||
|
||||
# Pattern: k[0]==k[2] and k[1]==k[3] and k[0]!=k[1]
|
||||
if keys[0] == keys[2] and keys[1] == keys[3] and keys[0] != keys[1]:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class LoopBreaker:
|
||||
"""
|
||||
Strategies for breaking detected loops.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
async def suggest_alternatives(
|
||||
action: ActionRequest,
|
||||
loop_type: str,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Suggest alternative actions when loop is detected.
|
||||
|
||||
Args:
|
||||
action: The looping action
|
||||
loop_type: Type of loop detected
|
||||
|
||||
Returns:
|
||||
List of suggestions
|
||||
"""
|
||||
suggestions = []
|
||||
|
||||
if loop_type == "exact":
|
||||
suggestions.append(
|
||||
"The same action with identical arguments has been repeated too many times. "
|
||||
"Consider: (1) Verify the action succeeded, (2) Try a different approach, "
|
||||
"(3) Escalate for human review"
|
||||
)
|
||||
elif loop_type == "semantic":
|
||||
suggestions.append(
|
||||
"Similar actions have been repeated too many times. "
|
||||
"Consider: (1) Review if the approach is working, (2) Try an alternative method, "
|
||||
"(3) Request clarification on the goal"
|
||||
)
|
||||
elif loop_type == "oscillation":
|
||||
suggestions.append(
|
||||
"An oscillating pattern was detected (A→B→A→B). "
|
||||
"This usually indicates conflicting goals or a stuck state. "
|
||||
"Consider: (1) Step back and reassess, (2) Request human guidance"
|
||||
)
|
||||
|
||||
return suggestions
|
||||
17
backend/app/services/safety/mcp/__init__.py
Normal file
17
backend/app/services/safety/mcp/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""MCP safety integration."""
|
||||
|
||||
from .integration import (
|
||||
MCPSafetyWrapper,
|
||||
MCPToolCall,
|
||||
MCPToolResult,
|
||||
SafeToolExecutor,
|
||||
create_mcp_wrapper,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MCPSafetyWrapper",
|
||||
"MCPToolCall",
|
||||
"MCPToolResult",
|
||||
"SafeToolExecutor",
|
||||
"create_mcp_wrapper",
|
||||
]
|
||||
409
backend/app/services/safety/mcp/integration.py
Normal file
409
backend/app/services/safety/mcp/integration.py
Normal file
@@ -0,0 +1,409 @@
|
||||
"""
|
||||
MCP Safety Integration
|
||||
|
||||
Provides safety-aware wrappers for MCP tool execution.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any, ClassVar, TypeVar
|
||||
|
||||
from ..audit import AuditLogger
|
||||
from ..emergency import EmergencyControls, get_emergency_controls
|
||||
from ..exceptions import (
|
||||
EmergencyStopError,
|
||||
SafetyError,
|
||||
)
|
||||
from ..guardian import SafetyGuardian, get_safety_guardian
|
||||
from ..models import (
|
||||
ActionMetadata,
|
||||
ActionRequest,
|
||||
ActionType,
|
||||
AutonomyLevel,
|
||||
SafetyDecision,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass
|
||||
class MCPToolCall:
|
||||
"""Represents an MCP tool call."""
|
||||
|
||||
tool_name: str
|
||||
arguments: dict[str, Any]
|
||||
server_name: str | None = None
|
||||
project_id: str | None = None
|
||||
context: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MCPToolResult:
|
||||
"""Result of an MCP tool execution."""
|
||||
|
||||
success: bool
|
||||
result: Any = None
|
||||
error: str | None = None
|
||||
safety_decision: SafetyDecision = SafetyDecision.ALLOW
|
||||
execution_time_ms: float = 0.0
|
||||
approval_id: str | None = None
|
||||
checkpoint_id: str | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class MCPSafetyWrapper:
|
||||
"""
|
||||
Wraps MCP tool execution with safety checks.
|
||||
|
||||
Features:
|
||||
- Pre-execution validation via SafetyGuardian
|
||||
- Permission checking per tool/resource
|
||||
- Budget and rate limit enforcement
|
||||
- Audit logging of all MCP calls
|
||||
- Emergency stop integration
|
||||
- Checkpoint creation for destructive operations
|
||||
"""
|
||||
|
||||
# Tool categories for automatic classification
|
||||
DESTRUCTIVE_TOOLS: ClassVar[set[str]] = {
|
||||
"file_write",
|
||||
"file_delete",
|
||||
"database_mutate",
|
||||
"shell_execute",
|
||||
"git_push",
|
||||
"git_commit",
|
||||
"deploy",
|
||||
}
|
||||
|
||||
READ_ONLY_TOOLS: ClassVar[set[str]] = {
|
||||
"file_read",
|
||||
"database_query",
|
||||
"git_status",
|
||||
"git_log",
|
||||
"list_files",
|
||||
"search",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
guardian: SafetyGuardian | None = None,
|
||||
audit_logger: AuditLogger | None = None,
|
||||
emergency_controls: EmergencyControls | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize MCPSafetyWrapper.
|
||||
|
||||
Args:
|
||||
guardian: SafetyGuardian instance (uses singleton if not provided)
|
||||
audit_logger: AuditLogger instance
|
||||
emergency_controls: EmergencyControls instance
|
||||
"""
|
||||
self._guardian = guardian
|
||||
self._audit_logger = audit_logger
|
||||
self._emergency_controls = emergency_controls
|
||||
self._tool_handlers: dict[str, Callable[..., Any]] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def _get_guardian(self) -> SafetyGuardian:
|
||||
"""Get or create SafetyGuardian."""
|
||||
if self._guardian is None:
|
||||
self._guardian = await get_safety_guardian()
|
||||
return self._guardian
|
||||
|
||||
async def _get_emergency_controls(self) -> EmergencyControls:
|
||||
"""Get or create EmergencyControls."""
|
||||
if self._emergency_controls is None:
|
||||
self._emergency_controls = await get_emergency_controls()
|
||||
return self._emergency_controls
|
||||
|
||||
def register_tool_handler(
|
||||
self,
|
||||
tool_name: str,
|
||||
handler: Callable[..., Any],
|
||||
) -> None:
|
||||
"""
|
||||
Register a handler for a tool.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool
|
||||
handler: Async function to handle the tool call
|
||||
"""
|
||||
self._tool_handlers[tool_name] = handler
|
||||
logger.debug("Registered handler for tool: %s", tool_name)
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
tool_call: MCPToolCall,
|
||||
agent_id: str,
|
||||
autonomy_level: AutonomyLevel = AutonomyLevel.MILESTONE,
|
||||
bypass_safety: bool = False,
|
||||
) -> MCPToolResult:
|
||||
"""
|
||||
Execute an MCP tool call with safety checks.
|
||||
|
||||
Args:
|
||||
tool_call: The tool call to execute
|
||||
agent_id: ID of the calling agent
|
||||
autonomy_level: Agent's autonomy level
|
||||
bypass_safety: Bypass safety checks (emergency only)
|
||||
|
||||
Returns:
|
||||
MCPToolResult with execution outcome
|
||||
"""
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
# Check emergency controls first
|
||||
emergency = await self._get_emergency_controls()
|
||||
scope = f"agent:{agent_id}"
|
||||
if tool_call.project_id:
|
||||
scope = f"project:{tool_call.project_id}"
|
||||
|
||||
try:
|
||||
await emergency.check_allowed(scope=scope, raise_if_blocked=True)
|
||||
except EmergencyStopError as e:
|
||||
return MCPToolResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
safety_decision=SafetyDecision.DENY,
|
||||
metadata={"emergency_stop": True},
|
||||
)
|
||||
|
||||
# Build action request
|
||||
action = self._build_action_request(
|
||||
tool_call=tool_call,
|
||||
agent_id=agent_id,
|
||||
autonomy_level=autonomy_level,
|
||||
)
|
||||
|
||||
# Skip safety checks if bypass is enabled
|
||||
if bypass_safety:
|
||||
logger.warning(
|
||||
"Safety bypass enabled for tool: %s (agent: %s)",
|
||||
tool_call.tool_name,
|
||||
agent_id,
|
||||
)
|
||||
return await self._execute_tool(tool_call, action, start_time)
|
||||
|
||||
# Run safety validation
|
||||
guardian = await self._get_guardian()
|
||||
try:
|
||||
guardian_result = await guardian.validate(action)
|
||||
except SafetyError as e:
|
||||
return MCPToolResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
safety_decision=SafetyDecision.DENY,
|
||||
execution_time_ms=self._elapsed_ms(start_time),
|
||||
)
|
||||
|
||||
# Handle safety decision
|
||||
if guardian_result.decision == SafetyDecision.DENY:
|
||||
return MCPToolResult(
|
||||
success=False,
|
||||
error="; ".join(guardian_result.reasons),
|
||||
safety_decision=SafetyDecision.DENY,
|
||||
execution_time_ms=self._elapsed_ms(start_time),
|
||||
)
|
||||
|
||||
if guardian_result.decision == SafetyDecision.REQUIRE_APPROVAL:
|
||||
# For now, just return that approval is required
|
||||
# The caller should handle the approval flow
|
||||
return MCPToolResult(
|
||||
success=False,
|
||||
error="Action requires human approval",
|
||||
safety_decision=SafetyDecision.REQUIRE_APPROVAL,
|
||||
approval_id=guardian_result.approval_id,
|
||||
execution_time_ms=self._elapsed_ms(start_time),
|
||||
)
|
||||
|
||||
# Execute the tool
|
||||
result = await self._execute_tool(
|
||||
tool_call,
|
||||
action,
|
||||
start_time,
|
||||
checkpoint_id=guardian_result.checkpoint_id,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def _execute_tool(
|
||||
self,
|
||||
tool_call: MCPToolCall,
|
||||
action: ActionRequest,
|
||||
start_time: datetime,
|
||||
checkpoint_id: str | None = None,
|
||||
) -> MCPToolResult:
|
||||
"""Execute the actual tool call."""
|
||||
handler = self._tool_handlers.get(tool_call.tool_name)
|
||||
|
||||
if handler is None:
|
||||
return MCPToolResult(
|
||||
success=False,
|
||||
error=f"No handler registered for tool: {tool_call.tool_name}",
|
||||
safety_decision=SafetyDecision.ALLOW,
|
||||
execution_time_ms=self._elapsed_ms(start_time),
|
||||
)
|
||||
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(handler):
|
||||
result = await handler(**tool_call.arguments)
|
||||
else:
|
||||
result = handler(**tool_call.arguments)
|
||||
|
||||
return MCPToolResult(
|
||||
success=True,
|
||||
result=result,
|
||||
safety_decision=SafetyDecision.ALLOW,
|
||||
execution_time_ms=self._elapsed_ms(start_time),
|
||||
checkpoint_id=checkpoint_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Tool execution failed: %s - %s", tool_call.tool_name, e)
|
||||
return MCPToolResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
safety_decision=SafetyDecision.ALLOW,
|
||||
execution_time_ms=self._elapsed_ms(start_time),
|
||||
checkpoint_id=checkpoint_id,
|
||||
)
|
||||
|
||||
def _build_action_request(
|
||||
self,
|
||||
tool_call: MCPToolCall,
|
||||
agent_id: str,
|
||||
autonomy_level: AutonomyLevel,
|
||||
) -> ActionRequest:
|
||||
"""Build an ActionRequest from an MCP tool call."""
|
||||
action_type = self._classify_tool(tool_call.tool_name)
|
||||
|
||||
metadata = ActionMetadata(
|
||||
agent_id=agent_id,
|
||||
session_id=tool_call.context.get("session_id", ""),
|
||||
project_id=tool_call.project_id or "",
|
||||
autonomy_level=autonomy_level,
|
||||
)
|
||||
|
||||
return ActionRequest(
|
||||
action_type=action_type,
|
||||
tool_name=tool_call.tool_name,
|
||||
arguments=tool_call.arguments,
|
||||
resource=tool_call.arguments.get(
|
||||
"path", tool_call.arguments.get("resource")
|
||||
),
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def _classify_tool(self, tool_name: str) -> ActionType:
|
||||
"""Classify a tool into an action type."""
|
||||
tool_lower = tool_name.lower()
|
||||
|
||||
# Check destructive patterns
|
||||
if any(
|
||||
d in tool_lower for d in ["write", "create", "delete", "remove", "update"]
|
||||
):
|
||||
if "file" in tool_lower:
|
||||
if "delete" in tool_lower or "remove" in tool_lower:
|
||||
return ActionType.FILE_DELETE
|
||||
return ActionType.FILE_WRITE
|
||||
if "database" in tool_lower or "db" in tool_lower:
|
||||
return ActionType.DATABASE_MUTATE
|
||||
|
||||
# Check read patterns
|
||||
if any(r in tool_lower for r in ["read", "get", "list", "search", "query"]):
|
||||
if "file" in tool_lower:
|
||||
return ActionType.FILE_READ
|
||||
if "database" in tool_lower or "db" in tool_lower:
|
||||
return ActionType.DATABASE_QUERY
|
||||
|
||||
# Check specific types
|
||||
if "shell" in tool_lower or "exec" in tool_lower or "bash" in tool_lower:
|
||||
return ActionType.SHELL_COMMAND
|
||||
|
||||
if "git" in tool_lower:
|
||||
return ActionType.GIT_OPERATION
|
||||
|
||||
if "http" in tool_lower or "fetch" in tool_lower or "request" in tool_lower:
|
||||
return ActionType.NETWORK_REQUEST
|
||||
|
||||
if "llm" in tool_lower or "ai" in tool_lower or "claude" in tool_lower:
|
||||
return ActionType.LLM_CALL
|
||||
|
||||
# Default to tool call
|
||||
return ActionType.TOOL_CALL
|
||||
|
||||
def _elapsed_ms(self, start_time: datetime) -> float:
|
||||
"""Calculate elapsed time in milliseconds."""
|
||||
return (datetime.utcnow() - start_time).total_seconds() * 1000
|
||||
|
||||
|
||||
class SafeToolExecutor:
|
||||
"""
|
||||
Context manager for safe tool execution with automatic cleanup.
|
||||
|
||||
Usage:
|
||||
async with SafeToolExecutor(wrapper, tool_call, agent_id) as executor:
|
||||
result = await executor.execute()
|
||||
if result.success:
|
||||
# Use result
|
||||
else:
|
||||
# Handle error or approval required
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
wrapper: MCPSafetyWrapper,
|
||||
tool_call: MCPToolCall,
|
||||
agent_id: str,
|
||||
autonomy_level: AutonomyLevel = AutonomyLevel.MILESTONE,
|
||||
) -> None:
|
||||
self._wrapper = wrapper
|
||||
self._tool_call = tool_call
|
||||
self._agent_id = agent_id
|
||||
self._autonomy_level = autonomy_level
|
||||
self._result: MCPToolResult | None = None
|
||||
|
||||
async def __aenter__(self) -> "SafeToolExecutor":
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[Exception] | None,
|
||||
exc_val: Exception | None,
|
||||
exc_tb: Any,
|
||||
) -> bool:
|
||||
# Could trigger rollback here if needed
|
||||
return False
|
||||
|
||||
async def execute(self) -> MCPToolResult:
|
||||
"""Execute the tool call."""
|
||||
self._result = await self._wrapper.execute(
|
||||
self._tool_call,
|
||||
self._agent_id,
|
||||
self._autonomy_level,
|
||||
)
|
||||
return self._result
|
||||
|
||||
@property
|
||||
def result(self) -> MCPToolResult | None:
|
||||
"""Get the execution result."""
|
||||
return self._result
|
||||
|
||||
|
||||
# Factory function
|
||||
async def create_mcp_wrapper(
|
||||
guardian: SafetyGuardian | None = None,
|
||||
) -> MCPSafetyWrapper:
|
||||
"""Create an MCPSafetyWrapper with default configuration."""
|
||||
if guardian is None:
|
||||
guardian = await get_safety_guardian()
|
||||
|
||||
return MCPSafetyWrapper(
|
||||
guardian=guardian,
|
||||
emergency_controls=await get_emergency_controls(),
|
||||
)
|
||||
19
backend/app/services/safety/metrics/__init__.py
Normal file
19
backend/app/services/safety/metrics/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Safety metrics collection and export."""
|
||||
|
||||
from .collector import (
|
||||
MetricType,
|
||||
MetricValue,
|
||||
SafetyMetrics,
|
||||
get_safety_metrics,
|
||||
record_mcp_call,
|
||||
record_validation,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MetricType",
|
||||
"MetricValue",
|
||||
"SafetyMetrics",
|
||||
"get_safety_metrics",
|
||||
"record_mcp_call",
|
||||
"record_validation",
|
||||
]
|
||||
430
backend/app/services/safety/metrics/collector.py
Normal file
430
backend/app/services/safety/metrics/collector.py
Normal file
@@ -0,0 +1,430 @@
|
||||
"""
|
||||
Safety Metrics Collector
|
||||
|
||||
Collects and exposes metrics for the safety framework.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import Counter, defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MetricType(str, Enum):
|
||||
"""Types of metrics."""
|
||||
|
||||
COUNTER = "counter"
|
||||
GAUGE = "gauge"
|
||||
HISTOGRAM = "histogram"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MetricValue:
|
||||
"""A single metric value."""
|
||||
|
||||
name: str
|
||||
metric_type: MetricType
|
||||
value: float
|
||||
labels: dict[str, str] = field(default_factory=dict)
|
||||
timestamp: datetime = field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HistogramBucket:
|
||||
"""Histogram bucket for distribution metrics."""
|
||||
|
||||
le: float # Less than or equal
|
||||
count: int = 0
|
||||
|
||||
|
||||
class SafetyMetrics:
    """
    Collects safety framework metrics.

    Metrics tracked:
    - Action validation counts (by decision type)
    - Approval request counts and latencies
    - Budget usage and remaining
    - Rate limit hits
    - Loop detections
    - Emergency events
    - Content filter matches

    All mutating and exporting methods are async and serialize access
    through a single ``asyncio.Lock``, so one instance may be shared
    across tasks. Label sets are stored internally as unquoted
    ``"k=v[,k=v...]"`` strings and converted to the appropriate form at
    export time.
    """

    def __init__(self) -> None:
        """Initialize SafetyMetrics."""
        # counter name -> (labels string -> monotonically increasing count)
        self._counters: dict[str, Counter[str]] = defaultdict(Counter)
        # gauge name -> (labels string -> latest value)
        self._gauges: dict[str, dict[str, float]] = defaultdict(dict)
        # histogram name -> raw observations (unbounded until reset())
        self._histograms: dict[str, list[float]] = defaultdict(list)
        self._histogram_buckets: dict[str, list[HistogramBucket]] = {}
        self._lock = asyncio.Lock()

        # Initialize histogram buckets
        self._init_histogram_buckets()

    def _init_histogram_buckets(self) -> None:
        """Initialize histogram buckets for latency metrics."""
        # Upper bounds in seconds; +Inf catches everything above 10s.
        latency_buckets = [
            0.01,
            0.05,
            0.1,
            0.25,
            0.5,
            1.0,
            2.5,
            5.0,
            10.0,
            float("inf"),
        ]

        for name in [
            "validation_latency_seconds",
            "approval_latency_seconds",
            "mcp_execution_latency_seconds",
        ]:
            self._histogram_buckets[name] = [
                HistogramBucket(le=b) for b in latency_buckets
            ]

    # Counter methods

    async def inc_validations(
        self,
        decision: str,
        agent_id: str | None = None,
    ) -> None:
        """Increment validation counter, labelled by decision (and agent)."""
        async with self._lock:
            labels = f"decision={decision}"
            if agent_id:
                labels += f",agent_id={agent_id}"
            self._counters["safety_validations_total"][labels] += 1

    async def inc_approvals_requested(self, urgency: str = "normal") -> None:
        """Increment approval requests counter."""
        async with self._lock:
            labels = f"urgency={urgency}"
            self._counters["safety_approvals_requested_total"][labels] += 1

    async def inc_approvals_granted(self) -> None:
        """Increment approvals granted counter (no labels)."""
        async with self._lock:
            self._counters["safety_approvals_granted_total"][""] += 1

    async def inc_approvals_denied(self, reason: str = "manual") -> None:
        """Increment approvals denied counter."""
        async with self._lock:
            labels = f"reason={reason}"
            self._counters["safety_approvals_denied_total"][labels] += 1

    async def inc_rate_limit_exceeded(self, limit_type: str) -> None:
        """Increment rate limit exceeded counter."""
        async with self._lock:
            labels = f"limit_type={limit_type}"
            self._counters["safety_rate_limit_exceeded_total"][labels] += 1

    async def inc_budget_exceeded(self, budget_type: str) -> None:
        """Increment budget exceeded counter."""
        async with self._lock:
            labels = f"budget_type={budget_type}"
            self._counters["safety_budget_exceeded_total"][labels] += 1

    async def inc_loops_detected(self, loop_type: str) -> None:
        """Increment loop detection counter."""
        async with self._lock:
            labels = f"loop_type={loop_type}"
            self._counters["safety_loops_detected_total"][labels] += 1

    async def inc_emergency_events(self, event_type: str, scope: str) -> None:
        """Increment emergency events counter."""
        async with self._lock:
            labels = f"event_type={event_type},scope={scope}"
            self._counters["safety_emergency_events_total"][labels] += 1

    async def inc_content_filtered(self, category: str, action: str) -> None:
        """Increment content filter counter."""
        async with self._lock:
            labels = f"category={category},action={action}"
            self._counters["safety_content_filtered_total"][labels] += 1

    async def inc_checkpoints_created(self) -> None:
        """Increment checkpoints created counter (no labels)."""
        async with self._lock:
            self._counters["safety_checkpoints_created_total"][""] += 1

    async def inc_rollbacks_executed(self, success: bool) -> None:
        """Increment rollbacks counter, labelled by outcome."""
        async with self._lock:
            labels = f"success={str(success).lower()}"
            self._counters["safety_rollbacks_total"][labels] += 1

    async def inc_mcp_calls(self, tool_name: str, success: bool) -> None:
        """Increment MCP tool calls counter, labelled by tool and outcome."""
        async with self._lock:
            labels = f"tool_name={tool_name},success={str(success).lower()}"
            self._counters["safety_mcp_calls_total"][labels] += 1

    # Gauge methods

    async def set_budget_remaining(
        self,
        scope: str,
        budget_type: str,
        remaining: float,
    ) -> None:
        """Set remaining budget gauge."""
        async with self._lock:
            labels = f"scope={scope},budget_type={budget_type}"
            self._gauges["safety_budget_remaining"][labels] = remaining

    async def set_rate_limit_remaining(
        self,
        scope: str,
        limit_type: str,
        remaining: int,
    ) -> None:
        """Set remaining rate limit gauge."""
        async with self._lock:
            labels = f"scope={scope},limit_type={limit_type}"
            self._gauges["safety_rate_limit_remaining"][labels] = float(remaining)

    async def set_pending_approvals(self, count: int) -> None:
        """Set pending approvals gauge."""
        async with self._lock:
            self._gauges["safety_pending_approvals"][""] = float(count)

    async def set_active_checkpoints(self, count: int) -> None:
        """Set active checkpoints gauge."""
        async with self._lock:
            self._gauges["safety_active_checkpoints"][""] = float(count)

    async def set_emergency_state(self, scope: str, state: str) -> None:
        """Set emergency state gauge (0=normal, 1=paused, 2=stopped).

        Unknown states map to -1 rather than raising, so a bad caller
        cannot break metrics collection.
        """
        async with self._lock:
            state_value = {"normal": 0, "paused": 1, "stopped": 2}.get(state, -1)
            labels = f"scope={scope}"
            self._gauges["safety_emergency_state"][labels] = float(state_value)

    # Histogram methods

    async def observe_validation_latency(self, latency_seconds: float) -> None:
        """Observe validation latency."""
        async with self._lock:
            self._observe_histogram("validation_latency_seconds", latency_seconds)

    async def observe_approval_latency(self, latency_seconds: float) -> None:
        """Observe approval latency."""
        async with self._lock:
            self._observe_histogram("approval_latency_seconds", latency_seconds)

    async def observe_mcp_execution_latency(self, latency_seconds: float) -> None:
        """Observe MCP execution latency."""
        async with self._lock:
            self._observe_histogram("mcp_execution_latency_seconds", latency_seconds)

    def _observe_histogram(self, name: str, value: float) -> None:
        """Record a value in a histogram (caller must hold the lock)."""
        self._histograms[name].append(value)

        # Update buckets: Prometheus buckets are cumulative, so every
        # bucket whose upper bound covers the value is incremented.
        if name in self._histogram_buckets:
            for bucket in self._histogram_buckets[name]:
                if value <= bucket.le:
                    bucket.count += 1

    # Export methods

    async def get_all_metrics(self) -> list[MetricValue]:
        """Get all metrics as MetricValue objects.

        Histograms are summarised as ``<name>_count`` and ``<name>_sum``
        counters; per-bucket counts are only exposed via the Prometheus
        export.
        """
        metrics: list[MetricValue] = []

        async with self._lock:
            # Export counters
            for name, counter in self._counters.items():
                for labels_str, value in counter.items():
                    metrics.append(
                        MetricValue(
                            name=name,
                            metric_type=MetricType.COUNTER,
                            value=float(value),
                            labels=self._parse_labels(labels_str),
                        )
                    )

            # Export gauges
            for name, gauge_dict in self._gauges.items():
                for labels_str, gauge_value in gauge_dict.items():
                    metrics.append(
                        MetricValue(
                            name=name,
                            metric_type=MetricType.GAUGE,
                            value=gauge_value,
                            labels=self._parse_labels(labels_str),
                        )
                    )

            # Export histogram summaries
            for name, values in self._histograms.items():
                if values:
                    metrics.append(
                        MetricValue(
                            name=f"{name}_count",
                            metric_type=MetricType.COUNTER,
                            value=float(len(values)),
                        )
                    )
                    metrics.append(
                        MetricValue(
                            name=f"{name}_sum",
                            metric_type=MetricType.COUNTER,
                            value=sum(values),
                        )
                    )

        return metrics

    async def get_prometheus_format(self) -> str:
        """Export metrics in Prometheus text exposition format.

        Label values are quoted (``name{key="value"}``) as required by the
        exposition format; internal storage keeps them unquoted. This
        matches the histogram-bucket lines, which were already quoted.
        """
        lines: list[str] = []

        async with self._lock:
            # Export counters
            for name, counter in self._counters.items():
                lines.append(f"# TYPE {name} counter")
                for labels_str, value in counter.items():
                    prom_labels = self._format_prometheus_labels(labels_str)
                    if prom_labels:
                        lines.append(f"{name}{{{prom_labels}}} {value}")
                    else:
                        lines.append(f"{name} {value}")

            # Export gauges
            for name, gauge_dict in self._gauges.items():
                lines.append(f"# TYPE {name} gauge")
                for labels_str, gauge_value in gauge_dict.items():
                    prom_labels = self._format_prometheus_labels(labels_str)
                    if prom_labels:
                        lines.append(f"{name}{{{prom_labels}}} {gauge_value}")
                    else:
                        lines.append(f"{name} {gauge_value}")

            # Export histograms (cumulative buckets + count and sum)
            for name, buckets in self._histogram_buckets.items():
                lines.append(f"# TYPE {name} histogram")
                for bucket in buckets:
                    le_str = "+Inf" if bucket.le == float("inf") else str(bucket.le)
                    lines.append(f'{name}_bucket{{le="{le_str}"}} {bucket.count}')

                if name in self._histograms:
                    values = self._histograms[name]
                    lines.append(f"{name}_count {len(values)}")
                    lines.append(f"{name}_sum {sum(values)}")

        return "\n".join(lines)

    async def get_summary(self) -> dict[str, Any]:
        """Get a summary of key metrics as a plain dict."""
        async with self._lock:
            total_validations = sum(self._counters["safety_validations_total"].values())
            # Parse labels instead of substring matching so a decision whose
            # name merely starts with "deny" can never be miscounted.
            denied_validations = sum(
                v
                for k, v in self._counters["safety_validations_total"].items()
                if self._parse_labels(k).get("decision") == "deny"
            )

            return {
                "total_validations": total_validations,
                "denied_validations": denied_validations,
                "approval_requests": sum(
                    self._counters["safety_approvals_requested_total"].values()
                ),
                "approvals_granted": sum(
                    self._counters["safety_approvals_granted_total"].values()
                ),
                "approvals_denied": sum(
                    self._counters["safety_approvals_denied_total"].values()
                ),
                "rate_limit_hits": sum(
                    self._counters["safety_rate_limit_exceeded_total"].values()
                ),
                "budget_exceeded": sum(
                    self._counters["safety_budget_exceeded_total"].values()
                ),
                "loops_detected": sum(
                    self._counters["safety_loops_detected_total"].values()
                ),
                "emergency_events": sum(
                    self._counters["safety_emergency_events_total"].values()
                ),
                "content_filtered": sum(
                    self._counters["safety_content_filtered_total"].values()
                ),
                "checkpoints_created": sum(
                    self._counters["safety_checkpoints_created_total"].values()
                ),
                "rollbacks_executed": sum(
                    self._counters["safety_rollbacks_total"].values()
                ),
                "mcp_calls": sum(self._counters["safety_mcp_calls_total"].values()),
                "pending_approvals": self._gauges.get(
                    "safety_pending_approvals", {}
                ).get("", 0),
                "active_checkpoints": self._gauges.get(
                    "safety_active_checkpoints", {}
                ).get("", 0),
            }

    async def reset(self) -> None:
        """Reset all metrics (counters, gauges, histograms and buckets)."""
        async with self._lock:
            self._counters.clear()
            self._gauges.clear()
            self._histograms.clear()
            self._init_histogram_buckets()

    def _parse_labels(self, labels_str: str) -> dict[str, str]:
        """Parse an internal 'k=v,k=v' labels string into a dictionary."""
        if not labels_str:
            return {}

        labels = {}
        for pair in labels_str.split(","):
            if "=" in pair:
                key, value = pair.split("=", 1)
                labels[key.strip()] = value.strip()

        return labels

    @staticmethod
    def _format_prometheus_labels(labels_str: str) -> str:
        """Convert an internal 'k=v,k=v' labels string to Prometheus form.

        The exposition format requires quoted label values
        (``key="value"``); malformed pairs without '=' are dropped,
        mirroring ``_parse_labels``.
        """
        if not labels_str:
            return ""

        pairs = []
        for pair in labels_str.split(","):
            if "=" in pair:
                key, value = pair.split("=", 1)
                pairs.append(f'{key.strip()}="{value.strip()}"')
        return ",".join(pairs)
|
||||
|
||||
|
||||
# Module-level singleton state. The lock guards first-time creation so two
# concurrent callers can never construct two collectors.
_metrics: SafetyMetrics | None = None
_lock = asyncio.Lock()


async def get_safety_metrics() -> SafetyMetrics:
    """Return the process-wide SafetyMetrics instance, creating it lazily."""
    global _metrics

    async with _lock:
        if _metrics is None:
            _metrics = SafetyMetrics()

    # Once assigned, the singleton is never replaced, so reading it outside
    # the lock is safe.
    return _metrics
|
||||
|
||||
|
||||
# Convenience functions
async def record_validation(decision: str, agent_id: str | None = None) -> None:
    """Record a single validation outcome.

    Looks up the process-wide SafetyMetrics singleton and increments the
    validation counter for ``decision``, optionally tagged with ``agent_id``.
    """
    collector = await get_safety_metrics()
    await collector.inc_validations(decision, agent_id)
|
||||
|
||||
|
||||
async def record_mcp_call(tool_name: str, success: bool, latency_ms: float) -> None:
    """Record one MCP tool invocation.

    Increments the MCP call counter and feeds the execution-latency
    histogram. The input is milliseconds; the histogram stores seconds,
    hence the division by 1000.
    """
    collector = await get_safety_metrics()
    await collector.inc_mcp_calls(tool_name, success)
    await collector.observe_mcp_execution_latency(latency_ms / 1000)
|
||||
470
backend/app/services/safety/models.py
Normal file
470
backend/app/services/safety/models.py
Normal file
@@ -0,0 +1,470 @@
|
||||
"""
|
||||
Safety Framework Models
|
||||
|
||||
Core Pydantic models for actions, events, policies, and safety decisions.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# ============================================================================
|
||||
# Enums
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class ActionType(str, Enum):
    """Types of actions that can be performed.

    String-valued so members serialize directly in Pydantic models/JSON.
    """

    TOOL_CALL = "tool_call"
    FILE_READ = "file_read"
    FILE_WRITE = "file_write"
    FILE_DELETE = "file_delete"
    API_CALL = "api_call"
    DATABASE_QUERY = "database_query"
    DATABASE_MUTATE = "database_mutate"
    GIT_OPERATION = "git_operation"
    SHELL_COMMAND = "shell_command"
    LLM_CALL = "llm_call"
    NETWORK_REQUEST = "network_request"
    CUSTOM = "custom"  # escape hatch for actions not modeled above
|
||||
|
||||
|
||||
class ResourceType(str, Enum):
    """Types of resources that can be accessed.

    String-valued so members serialize directly in Pydantic models/JSON.
    """

    FILE = "file"
    DATABASE = "database"
    API = "api"
    NETWORK = "network"
    GIT = "git"
    SHELL = "shell"
    LLM = "llm"
    MEMORY = "memory"
    CUSTOM = "custom"  # escape hatch for resources not modeled above
|
||||
|
||||
|
||||
class PermissionLevel(str, Enum):
    """Permission levels for resource access.

    NOTE(review): no ordering/hierarchy is encoded here; if levels are meant
    to imply one another (e.g. ADMIN implies WRITE), that logic must live in
    the consumer — confirm against the permission manager.
    """

    NONE = "none"
    READ = "read"
    WRITE = "write"
    EXECUTE = "execute"
    DELETE = "delete"
    ADMIN = "admin"
|
||||
|
||||
|
||||
class AutonomyLevel(str, Enum):
    """Autonomy levels for agent operation.

    Controls how often a human is placed in the loop.
    """

    FULL_CONTROL = "full_control"  # Approve every action
    MILESTONE = "milestone"  # Approve at milestones
    AUTONOMOUS = "autonomous"  # Only major decisions
|
||||
|
||||
|
||||
class SafetyDecision(str, Enum):
    """Result of safety validation."""

    ALLOW = "allow"  # proceed immediately
    DENY = "deny"  # block the action
    REQUIRE_APPROVAL = "require_approval"  # pause for human sign-off
    DELAY = "delay"  # retry later (e.g. when rate limited)
    SANDBOX = "sandbox"  # run only in an isolated sandbox
|
||||
|
||||
|
||||
class ApprovalStatus(str, Enum):
    """Status of approval request."""

    PENDING = "pending"  # awaiting a human decision
    APPROVED = "approved"
    DENIED = "denied"
    TIMEOUT = "timeout"  # no decision before the request's timeout
    CANCELLED = "cancelled"  # withdrawn before a decision was made
|
||||
|
||||
|
||||
class AuditEventType(str, Enum):
    """Types of audit events."""

    # Action lifecycle
    ACTION_REQUESTED = "action_requested"
    ACTION_VALIDATED = "action_validated"
    ACTION_DENIED = "action_denied"
    ACTION_EXECUTED = "action_executed"
    ACTION_FAILED = "action_failed"
    # Approval (human-in-the-loop) lifecycle
    APPROVAL_REQUESTED = "approval_requested"
    APPROVAL_GRANTED = "approval_granted"
    APPROVAL_DENIED = "approval_denied"
    APPROVAL_TIMEOUT = "approval_timeout"
    # Checkpoint / rollback lifecycle
    CHECKPOINT_CREATED = "checkpoint_created"
    ROLLBACK_STARTED = "rollback_started"
    ROLLBACK_COMPLETED = "rollback_completed"
    ROLLBACK_FAILED = "rollback_failed"
    # Limit and safety triggers
    BUDGET_WARNING = "budget_warning"
    BUDGET_EXCEEDED = "budget_exceeded"
    RATE_LIMITED = "rate_limited"
    LOOP_DETECTED = "loop_detected"
    EMERGENCY_STOP = "emergency_stop"
    POLICY_VIOLATION = "policy_violation"
    CONTENT_FILTERED = "content_filtered"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Action Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class ActionMetadata(BaseModel):
    """Metadata associated with an action.

    Identifies who is acting (agent/user), in what context
    (project/session/task), and under which autonomy level; used for
    validation, auditing, and tracing.
    """

    agent_id: str = Field(..., description="ID of the agent performing the action")
    project_id: str | None = Field(None, description="ID of the project context")
    session_id: str | None = Field(None, description="ID of the current session")
    task_id: str | None = Field(None, description="ID of the current task")
    parent_action_id: str | None = Field(None, description="ID of the parent action")
    correlation_id: str | None = Field(None, description="Correlation ID for tracing")
    user_id: str | None = Field(None, description="ID of the user who initiated")
    autonomy_level: AutonomyLevel = Field(
        default=AutonomyLevel.MILESTONE,
        description="Current autonomy level",
    )
    context: dict[str, Any] = Field(
        default_factory=dict,
        description="Additional context",
    )
|
||||
|
||||
|
||||
class ActionRequest(BaseModel):
    """Request to perform an action.

    The unit the safety framework validates: what is being done
    (action_type/tool/resource/arguments), by whom (metadata), at what
    estimated cost, and whether it is destructive or reversible.
    """

    id: str = Field(default_factory=lambda: str(uuid4()))
    action_type: ActionType = Field(..., description="Type of action to perform")
    tool_name: str | None = Field(None, description="Name of the tool to call")
    resource: str | None = Field(None, description="Resource being accessed")
    resource_type: ResourceType | None = Field(None, description="Type of resource")
    arguments: dict[str, Any] = Field(
        default_factory=dict,
        description="Action arguments",
    )
    metadata: ActionMetadata = Field(..., description="Action metadata")
    estimated_cost_tokens: int = Field(0, description="Estimated token cost")
    estimated_cost_usd: float = Field(0.0, description="Estimated USD cost")
    is_destructive: bool = Field(False, description="Whether action is destructive")
    is_reversible: bool = Field(True, description="Whether action can be rolled back")
    # NOTE(review): datetime.utcnow is naive and deprecated since Python 3.12;
    # the whole module uses it consistently — migrate together if changed.
    timestamp: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class ActionResult(BaseModel):
    """Result of an executed action.

    Records outcome, actual (vs. estimated) cost, timing, and the
    checkpoint created for the action, if any.
    """

    action_id: str = Field(..., description="ID of the action")
    success: bool = Field(..., description="Whether action succeeded")
    data: Any = Field(None, description="Action result data")
    error: str | None = Field(None, description="Error message if failed")
    error_code: str | None = Field(None, description="Error code if failed")
    execution_time_ms: float = Field(0.0, description="Execution time in ms")
    actual_cost_tokens: int = Field(0, description="Actual token cost")
    actual_cost_usd: float = Field(0.0, description="Actual USD cost")
    checkpoint_id: str | None = Field(None, description="Checkpoint ID if created")
    timestamp: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Validation Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class ValidationRule(BaseModel):
    """A single validation rule.

    A rule pairs matching conditions (action types, tool/resource
    patterns, agent IDs — all optional) with the decision to apply when
    they match. Higher ``priority`` rules are evaluated first.
    """

    id: str = Field(default_factory=lambda: str(uuid4()))
    name: str = Field(..., description="Rule name")
    description: str | None = Field(None, description="Rule description")
    priority: int = Field(0, description="Rule priority (higher = evaluated first)")
    enabled: bool = Field(True, description="Whether rule is enabled")

    # Rule conditions — None means "applies to all" for that dimension.
    action_types: list[ActionType] | None = Field(
        None, description="Action types this rule applies to"
    )
    tool_patterns: list[str] | None = Field(
        None, description="Tool name patterns (supports wildcards)"
    )
    resource_patterns: list[str] | None = Field(
        None, description="Resource patterns (supports wildcards)"
    )
    agent_ids: list[str] | None = Field(
        None, description="Agent IDs this rule applies to"
    )

    # Rule decision
    decision: SafetyDecision = Field(..., description="Decision when rule matches")
    reason: str | None = Field(None, description="Reason for decision")
|
||||
|
||||
|
||||
class ValidationResult(BaseModel):
    """Result of action validation.

    Captures the decision, which rules produced it, and any follow-up
    needed (approval request, retry delay).
    """

    action_id: str = Field(..., description="ID of the validated action")
    decision: SafetyDecision = Field(..., description="Validation decision")
    applied_rules: list[str] = Field(
        default_factory=list, description="IDs of applied rules"
    )
    reasons: list[str] = Field(default_factory=list, description="Reasons for decision")
    approval_id: str | None = Field(None, description="Approval request ID if needed")
    retry_after_seconds: float | None = Field(
        None, description="Retry delay if rate limited"
    )
    timestamp: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Budget Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class BudgetScope(str, Enum):
    """Scope of a budget limit.

    Time-based scopes (daily/weekly/monthly) and entity-based scopes
    (session/project/agent) for token and cost budgets.
    """

    SESSION = "session"
    DAILY = "daily"
    WEEKLY = "weekly"
    MONTHLY = "monthly"
    PROJECT = "project"
    AGENT = "agent"
|
||||
|
||||
|
||||
class BudgetStatus(BaseModel):
    """Current budget status.

    Tracks both token and USD budgets for one (scope, scope_id) pair,
    with warning/exceeded flags and an optional reset time.
    """

    scope: BudgetScope = Field(..., description="Budget scope")
    scope_id: str = Field(..., description="ID within scope (session/agent/project)")
    tokens_used: int = Field(0, description="Tokens used in this scope")
    tokens_limit: int = Field(100000, description="Token limit for this scope")
    cost_used_usd: float = Field(0.0, description="USD spent in this scope")
    cost_limit_usd: float = Field(10.0, description="USD limit for this scope")
    tokens_remaining: int = Field(0, description="Remaining tokens")
    cost_remaining_usd: float = Field(0.0, description="Remaining USD budget")
    warning_threshold: float = Field(0.8, description="Warn at this usage fraction")
    is_warning: bool = Field(False, description="Whether at warning level")
    is_exceeded: bool = Field(False, description="Whether budget exceeded")
    reset_at: datetime | None = Field(None, description="When budget resets")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Rate Limit Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class RateLimitConfig(BaseModel):
    """Configuration for a rate limit.

    A sliding/fixed window of ``window_seconds`` allowing up to ``limit``
    events, with an optional burst allowance and a soft-slowdown
    threshold expressed as a fraction of the limit.
    """

    name: str = Field(..., description="Rate limit name")
    limit: int = Field(..., description="Maximum allowed in window")
    window_seconds: int = Field(60, description="Time window in seconds")
    burst_limit: int | None = Field(None, description="Burst allowance")
    slowdown_threshold: float = Field(0.8, description="Start slowing at this fraction")
|
||||
|
||||
|
||||
class RateLimitStatus(BaseModel):
    """Current rate limit status.

    Snapshot of one named limit: current count, remaining allowance,
    when the window resets, and the retry delay if currently limited.
    """

    name: str = Field(..., description="Rate limit name")
    current_count: int = Field(0, description="Current count in window")
    limit: int = Field(..., description="Maximum allowed")
    window_seconds: int = Field(..., description="Time window")
    remaining: int = Field(..., description="Remaining in window")
    reset_at: datetime = Field(..., description="When window resets")
    is_limited: bool = Field(False, description="Whether currently limited")
    retry_after_seconds: float = Field(0.0, description="Seconds until retry")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Approval Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class ApprovalRequest(BaseModel):
    """Request for human approval.

    Wraps the action awaiting sign-off together with why approval is
    required and how long the request stays open.
    """

    id: str = Field(default_factory=lambda: str(uuid4()))
    action: ActionRequest = Field(..., description="Action requiring approval")
    reason: str = Field(..., description="Why approval is required")
    # NOTE(review): urgency is a free-form string (default "normal");
    # consider an Enum if the set of levels is closed.
    urgency: str = Field("normal", description="Urgency level")
    timeout_seconds: int = Field(300, description="Timeout for approval")
    created_at: datetime = Field(default_factory=datetime.utcnow)
    expires_at: datetime | None = Field(None, description="When request expires")
    suggested_action: str | None = Field(None, description="Suggested response")
    context: dict[str, Any] = Field(default_factory=dict, description="Extra context")
|
||||
|
||||
|
||||
class ApprovalResponse(BaseModel):
    """Response to an approval request.

    Records who decided, the outcome, and optional modifications the
    approver requested to the action.
    """

    request_id: str = Field(..., description="ID of the approval request")
    status: ApprovalStatus = Field(..., description="Approval status")
    decided_by: str | None = Field(None, description="Who made the decision")
    reason: str | None = Field(None, description="Reason for decision")
    modifications: dict[str, Any] | None = Field(
        None, description="Modifications to action"
    )
    decided_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Checkpoint/Rollback Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class CheckpointType(str, Enum):
    """Types of checkpoints."""

    FILE = "file"
    DATABASE = "database"
    GIT = "git"
    COMPOSITE = "composite"  # combines multiple checkpoint kinds
|
||||
|
||||
|
||||
class Checkpoint(BaseModel):
    """A rollback checkpoint.

    State captured before an action runs so it can be undone; the
    type-specific payload lives in ``data``.
    """

    id: str = Field(default_factory=lambda: str(uuid4()))
    checkpoint_type: CheckpointType = Field(..., description="Type of checkpoint")
    action_id: str = Field(..., description="Action this checkpoint is for")
    created_at: datetime = Field(default_factory=datetime.utcnow)
    expires_at: datetime | None = Field(None, description="When checkpoint expires")
    data: dict[str, Any] = Field(default_factory=dict, description="Checkpoint data")
    description: str | None = Field(None, description="Description of checkpoint")
    is_valid: bool = Field(True, description="Whether checkpoint is still valid")
|
||||
|
||||
|
||||
class RollbackResult(BaseModel):
    """Result of a rollback operation.

    Partial failure is representable: some actions may roll back while
    others land in ``failed_actions``.
    """

    checkpoint_id: str = Field(..., description="ID of checkpoint rolled back to")
    success: bool = Field(..., description="Whether rollback succeeded")
    actions_rolled_back: list[str] = Field(
        default_factory=list, description="IDs of rolled back actions"
    )
    failed_actions: list[str] = Field(
        default_factory=list, description="IDs of actions that failed to rollback"
    )
    error: str | None = Field(None, description="Error message if failed")
    timestamp: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Audit Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class AuditEvent(BaseModel):
    """An audit log event.

    One entry in the safety audit trail; all contextual IDs are optional
    since not every event type has every dimension.
    """

    id: str = Field(default_factory=lambda: str(uuid4()))
    event_type: AuditEventType = Field(..., description="Type of audit event")
    timestamp: datetime = Field(default_factory=datetime.utcnow)
    agent_id: str | None = Field(None, description="Agent ID if applicable")
    action_id: str | None = Field(None, description="Action ID if applicable")
    project_id: str | None = Field(None, description="Project ID if applicable")
    session_id: str | None = Field(None, description="Session ID if applicable")
    user_id: str | None = Field(None, description="User ID if applicable")
    decision: SafetyDecision | None = Field(None, description="Safety decision")
    details: dict[str, Any] = Field(default_factory=dict, description="Event details")
    correlation_id: str | None = Field(None, description="Correlation ID for tracing")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Policy Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class SafetyPolicy(BaseModel):
    """A complete safety policy configuration.

    Bundles every tunable safety knob — budgets, rate limits,
    allow/deny permission patterns, human-in-the-loop triggers, loop
    detection, sandboxing, and custom validation rules — into one
    versioned, toggleable document.
    """

    name: str = Field(..., description="Policy name")
    description: str | None = Field(None, description="Policy description")
    version: str = Field("1.0.0", description="Policy version")
    enabled: bool = Field(True, description="Whether policy is enabled")

    # Cost controls
    max_tokens_per_session: int = Field(100_000, description="Max tokens per session")
    max_tokens_per_day: int = Field(1_000_000, description="Max tokens per day")
    max_cost_per_session_usd: float = Field(10.0, description="Max USD per session")
    max_cost_per_day_usd: float = Field(100.0, description="Max USD per day")

    # Rate limits
    max_actions_per_minute: int = Field(60, description="Max actions per minute")
    max_llm_calls_per_minute: int = Field(20, description="Max LLM calls per minute")
    max_file_operations_per_minute: int = Field(
        100, description="Max file ops per minute"
    )

    # Permissions — deny patterns are expected to win over allow patterns
    # (e.g. .env and secrets are denied even though "**/*" is allowed);
    # NOTE(review): confirm precedence against the enforcing validator.
    allowed_tools: list[str] = Field(
        default_factory=lambda: ["*"],
        description="Allowed tool patterns",
    )
    denied_tools: list[str] = Field(
        default_factory=list,
        description="Denied tool patterns",
    )
    allowed_file_patterns: list[str] = Field(
        default_factory=lambda: ["**/*"],
        description="Allowed file patterns",
    )
    denied_file_patterns: list[str] = Field(
        default_factory=lambda: ["**/.env", "**/secrets/**"],
        description="Denied file patterns",
    )

    # HITL (human-in-the-loop)
    require_approval_for: list[str] = Field(
        default_factory=lambda: [
            "delete_file",
            "push_to_remote",
            "deploy_to_production",
            "modify_critical_config",
        ],
        description="Actions requiring approval",
    )

    # Loop detection
    max_repeated_actions: int = Field(5, description="Max exact repetitions")
    max_similar_actions: int = Field(10, description="Max similar actions")

    # Sandbox
    require_sandbox: bool = Field(False, description="Require sandbox execution")
    sandbox_timeout_seconds: int = Field(300, description="Sandbox timeout")
    sandbox_memory_mb: int = Field(1024, description="Sandbox memory limit")

    # Validation rules
    validation_rules: list[ValidationRule] = Field(
        default_factory=list,
        description="Custom validation rules",
    )
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Guardian Result Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class GuardianResult(BaseModel):
    """Result of SafetyGuardian evaluation."""

    # Identity of the evaluated action and the verdict that was reached.
    action_id: str = Field(..., description="ID of the action")
    allowed: bool = Field(..., description="Whether action is allowed")
    decision: SafetyDecision = Field(..., description="Safety decision")
    # Human-readable explanations for why the decision was made.
    reasons: list[str] = Field(default_factory=list, description="Decision reasons")
    # The following fields are only populated when the decision requires
    # follow-up: an approval workflow, a created checkpoint, or a
    # rate-limit back-off delay.
    approval_id: str | None = Field(None, description="Approval ID if needed")
    checkpoint_id: str | None = Field(None, description="Checkpoint ID if created")
    retry_after_seconds: float | None = Field(None, description="Retry delay")
    # Set when the guardian rewrote the action instead of rejecting it.
    modified_action: ActionRequest | None = Field(
        None, description="Modified action if changed"
    )
    # Audit trail entries produced while evaluating this action.
    audit_events: list[AuditEvent] = Field(
        default_factory=list, description="Generated audit events"
    )
|
||||
15
backend/app/services/safety/permissions/__init__.py
Normal file
15
backend/app/services/safety/permissions/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
Permission Management Module
|
||||
|
||||
Agent permissions for resource access.
|
||||
"""
|
||||
|
||||
from .manager import (
|
||||
PermissionGrant,
|
||||
PermissionManager,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"PermissionGrant",
|
||||
"PermissionManager",
|
||||
]
|
||||
384
backend/app/services/safety/permissions/manager.py
Normal file
384
backend/app/services/safety/permissions/manager.py
Normal file
@@ -0,0 +1,384 @@
|
||||
"""
|
||||
Permission Manager
|
||||
|
||||
Manages permissions for agent actions on resources.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import fnmatch
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from uuid import uuid4
|
||||
|
||||
from ..exceptions import PermissionDeniedError
|
||||
from ..models import (
|
||||
ActionRequest,
|
||||
ActionType,
|
||||
PermissionLevel,
|
||||
ResourceType,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PermissionGrant:
    """A permission grant for an agent on a resource."""

    def __init__(
        self,
        agent_id: str,
        resource_pattern: str,
        resource_type: ResourceType,
        level: PermissionLevel,
        *,
        expires_at: datetime | None = None,
        granted_by: str | None = None,
        reason: str | None = None,
    ) -> None:
        # Every grant carries its own id so it can be revoked individually.
        self.id = str(uuid4())
        self.agent_id = agent_id
        self.resource_pattern = resource_pattern
        self.resource_type = resource_type
        self.level = level
        # Optional audit/lifetime metadata.
        self.expires_at = expires_at
        self.granted_by = granted_by
        self.reason = reason
        self.created_at = datetime.utcnow()

    def is_expired(self) -> bool:
        """Return True when the grant carries an expiry that has passed."""
        # A grant with no expiry never expires.
        return self.expires_at is not None and datetime.utcnow() > self.expires_at

    def matches(self, resource: str, resource_type: ResourceType) -> bool:
        """Return True when this grant covers the given resource."""
        # Resource type must match exactly; the resource string is then
        # matched with shell-style wildcards against the grant's pattern.
        return self.resource_type == resource_type and fnmatch.fnmatch(
            resource, self.resource_pattern
        )

    def allows(self, required_level: PermissionLevel) -> bool:
        """Return True when the granted level is at least *required_level*."""
        # Ordering of levels from least to most privileged.
        rank = {
            PermissionLevel.NONE: 0,
            PermissionLevel.READ: 1,
            PermissionLevel.WRITE: 2,
            PermissionLevel.EXECUTE: 3,
            PermissionLevel.DELETE: 4,
            PermissionLevel.ADMIN: 5,
        }
        return rank[self.level] >= rank[required_level]
|
||||
|
||||
|
||||
class PermissionManager:
    """
    Manages permissions for agent access to resources.

    Features:
    - Permission grants by agent/resource pattern
    - Permission inheritance (project → agent → action)
    - Temporary permissions with expiration
    - Least-privilege defaults
    - Permission escalation logging
    """

    def __init__(
        self,
        default_deny: bool = True,
    ) -> None:
        """
        Initialize the PermissionManager.

        Args:
            default_deny: If True, deny access unless explicitly granted
        """
        self._grants: list[PermissionGrant] = []
        self._default_deny = default_deny
        # Serializes all mutations of the grant list.
        self._lock = asyncio.Lock()

        # Default permissions for common resources; consulted only when
        # default_deny is False.
        self._default_permissions: dict[ResourceType, PermissionLevel] = {
            ResourceType.FILE: PermissionLevel.READ,
            ResourceType.DATABASE: PermissionLevel.READ,
            ResourceType.API: PermissionLevel.READ,
            ResourceType.GIT: PermissionLevel.READ,
            ResourceType.LLM: PermissionLevel.EXECUTE,
            ResourceType.SHELL: PermissionLevel.NONE,
            ResourceType.NETWORK: PermissionLevel.READ,
        }

    @staticmethod
    def _level_rank(level: PermissionLevel) -> int:
        """Map a permission level to its rank in the privilege hierarchy.

        Single source of truth for level ordering; previously this literal
        was duplicated inline in check(). Raises KeyError for unknown levels.
        """
        return {
            PermissionLevel.NONE: 0,
            PermissionLevel.READ: 1,
            PermissionLevel.WRITE: 2,
            PermissionLevel.EXECUTE: 3,
            PermissionLevel.DELETE: 4,
            PermissionLevel.ADMIN: 5,
        }[level]

    async def grant(
        self,
        agent_id: str,
        resource_pattern: str,
        resource_type: ResourceType,
        level: PermissionLevel,
        *,
        duration_seconds: int | None = None,
        granted_by: str | None = None,
        reason: str | None = None,
    ) -> PermissionGrant:
        """
        Grant a permission to an agent.

        Args:
            agent_id: ID of the agent
            resource_pattern: Pattern for matching resources (supports wildcards)
            resource_type: Type of resource
            level: Permission level to grant
            duration_seconds: Optional duration for temporary permission
            granted_by: Who granted the permission
            reason: Reason for granting

        Returns:
            The created permission grant
        """
        # A duration turns this into a temporary grant with an expiry.
        expires_at = None
        if duration_seconds:
            expires_at = datetime.utcnow() + timedelta(seconds=duration_seconds)

        grant = PermissionGrant(
            agent_id=agent_id,
            resource_pattern=resource_pattern,
            resource_type=resource_type,
            level=level,
            expires_at=expires_at,
            granted_by=granted_by,
            reason=reason,
        )

        async with self._lock:
            self._grants.append(grant)

        logger.info(
            "Permission granted: agent=%s, resource=%s, type=%s, level=%s",
            agent_id,
            resource_pattern,
            resource_type.value,
            level.value,
        )

        return grant

    async def revoke(self, grant_id: str) -> bool:
        """
        Revoke a permission grant.

        Args:
            grant_id: ID of the grant to revoke

        Returns:
            True if grant was found and revoked
        """
        async with self._lock:
            for i, grant in enumerate(self._grants):
                if grant.id == grant_id:
                    del self._grants[i]
                    logger.info("Permission revoked: %s", grant_id)
                    return True
        return False

    async def revoke_all(self, agent_id: str) -> int:
        """
        Revoke all permissions for an agent.

        Args:
            agent_id: ID of the agent

        Returns:
            Number of grants revoked
        """
        async with self._lock:
            original_count = len(self._grants)
            self._grants = [g for g in self._grants if g.agent_id != agent_id]
            revoked = original_count - len(self._grants)

        if revoked:
            logger.info("Revoked %d permissions for agent %s", revoked, agent_id)

        return revoked

    async def check(
        self,
        agent_id: str,
        resource: str,
        resource_type: ResourceType,
        required_level: PermissionLevel,
    ) -> bool:
        """
        Check if an agent has permission to access a resource.

        Args:
            agent_id: ID of the agent
            resource: Resource to access
            resource_type: Type of resource
            required_level: Required permission level

        Returns:
            True if access is allowed
        """
        # Clean up expired grants first so they cannot satisfy the check.
        await self._cleanup_expired()

        async with self._lock:
            for grant in self._grants:
                if grant.agent_id != agent_id:
                    continue

                # Re-check expiry defensively: the lock was released between
                # cleanup and this scan, so a grant may have expired since.
                if grant.is_expired():
                    continue

                if grant.matches(resource, resource_type):
                    if grant.allows(required_level):
                        return True

            # Fall back to per-resource-type defaults only when not running
            # in default-deny (least-privilege) mode.
            if not self._default_deny:
                default_level = self._default_permissions.get(
                    resource_type, PermissionLevel.NONE
                )
                if self._level_rank(default_level) >= self._level_rank(
                    required_level
                ):
                    return True

            return False

    async def check_action(self, action: ActionRequest) -> bool:
        """
        Check if an action is permitted.

        Args:
            action: The action to check

        Returns:
            True if action is allowed
        """
        # Determine required permission level from action type; unknown
        # action types conservatively require EXECUTE.
        level_map = {
            ActionType.FILE_READ: PermissionLevel.READ,
            ActionType.FILE_WRITE: PermissionLevel.WRITE,
            ActionType.FILE_DELETE: PermissionLevel.DELETE,
            ActionType.DATABASE_QUERY: PermissionLevel.READ,
            ActionType.DATABASE_MUTATE: PermissionLevel.WRITE,
            ActionType.SHELL_COMMAND: PermissionLevel.EXECUTE,
            ActionType.API_CALL: PermissionLevel.EXECUTE,
            ActionType.GIT_OPERATION: PermissionLevel.WRITE,
            ActionType.LLM_CALL: PermissionLevel.EXECUTE,
            ActionType.NETWORK_REQUEST: PermissionLevel.READ,
            ActionType.TOOL_CALL: PermissionLevel.EXECUTE,
        }

        required_level = level_map.get(action.action_type, PermissionLevel.EXECUTE)

        # Determine resource type from action; unmapped types fall back to
        # CUSTOM so they can still be covered by explicit grants.
        resource_type_map = {
            ActionType.FILE_READ: ResourceType.FILE,
            ActionType.FILE_WRITE: ResourceType.FILE,
            ActionType.FILE_DELETE: ResourceType.FILE,
            ActionType.DATABASE_QUERY: ResourceType.DATABASE,
            ActionType.DATABASE_MUTATE: ResourceType.DATABASE,
            ActionType.SHELL_COMMAND: ResourceType.SHELL,
            ActionType.API_CALL: ResourceType.API,
            ActionType.GIT_OPERATION: ResourceType.GIT,
            ActionType.LLM_CALL: ResourceType.LLM,
            ActionType.NETWORK_REQUEST: ResourceType.NETWORK,
        }

        resource_type = resource_type_map.get(action.action_type, ResourceType.CUSTOM)
        # Prefer the explicit resource; fall back to the tool name, then a
        # wildcard so pattern grants can still apply.
        resource = action.resource or action.tool_name or "*"

        return await self.check(
            agent_id=action.metadata.agent_id,
            resource=resource,
            resource_type=resource_type,
            required_level=required_level,
        )

    async def require_permission(
        self,
        agent_id: str,
        resource: str,
        resource_type: ResourceType,
        required_level: PermissionLevel,
    ) -> None:
        """
        Require permission or raise exception.

        Args:
            agent_id: ID of the agent
            resource: Resource to access
            resource_type: Type of resource
            required_level: Required permission level

        Raises:
            PermissionDeniedError: If permission is denied
        """
        if not await self.check(agent_id, resource, resource_type, required_level):
            raise PermissionDeniedError(
                f"Permission denied: {resource}",
                action_type=None,
                resource=resource,
                required_permission=required_level.value,
                agent_id=agent_id,
            )

    async def list_grants(
        self,
        agent_id: str | None = None,
        resource_type: ResourceType | None = None,
    ) -> list[PermissionGrant]:
        """
        List permission grants.

        Args:
            agent_id: Optional filter by agent
            resource_type: Optional filter by resource type

        Returns:
            List of matching grants
        """
        await self._cleanup_expired()

        async with self._lock:
            # Copy under the lock; filtering happens on the snapshot.
            grants = list(self._grants)

        if agent_id:
            grants = [g for g in grants if g.agent_id == agent_id]

        if resource_type:
            grants = [g for g in grants if g.resource_type == resource_type]

        return grants

    def set_default_permission(
        self,
        resource_type: ResourceType,
        level: PermissionLevel,
    ) -> None:
        """
        Set the default permission level for a resource type.

        Args:
            resource_type: Type of resource
            level: Default permission level
        """
        self._default_permissions[resource_type] = level

    async def _cleanup_expired(self) -> None:
        """Remove expired grants."""
        async with self._lock:
            original_count = len(self._grants)
            self._grants = [g for g in self._grants if not g.is_expired()]
            removed = original_count - len(self._grants)

        if removed:
            logger.debug("Cleaned up %d expired permission grants", removed)
|
||||
1
backend/app/services/safety/policies/__init__.py
Normal file
1
backend/app/services/safety/policies/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""${dir} module."""
|
||||
5
backend/app/services/safety/rollback/__init__.py
Normal file
5
backend/app/services/safety/rollback/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Rollback management for agent actions."""
|
||||
|
||||
from .manager import RollbackManager, TransactionContext
|
||||
|
||||
__all__ = ["RollbackManager", "TransactionContext"]
|
||||
417
backend/app/services/safety/rollback/manager.py
Normal file
417
backend/app/services/safety/rollback/manager.py
Normal file
@@ -0,0 +1,417 @@
|
||||
"""
|
||||
Rollback Manager
|
||||
|
||||
Manages checkpoints and rollback operations for agent actions.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from ..config import get_safety_config
|
||||
from ..exceptions import RollbackError
|
||||
from ..models import (
|
||||
ActionRequest,
|
||||
Checkpoint,
|
||||
CheckpointType,
|
||||
RollbackResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileCheckpoint:
    """Stores file state for rollback."""

    def __init__(
        self,
        checkpoint_id: str,
        file_path: str,
        original_content: bytes | None,
        existed: bool,
    ) -> None:
        # Identity: which checkpoint this snapshot belongs to and which file
        # it captured.
        self.checkpoint_id = checkpoint_id
        self.file_path = file_path
        # Snapshot payload: original_content is the file's bytes at capture
        # time, or None when the file did not exist (existed=False).
        self.existed = existed
        self.original_content = original_content
        # Timestamp of when the snapshot object was created.
        self.created_at = datetime.utcnow()
|
||||
|
||||
|
||||
class RollbackManager:
    """
    Manages checkpoints and rollback operations.

    Features:
    - File system checkpoints
    - Transaction wrapping for actions
    - Automatic checkpoint for destructive actions
    - Rollback triggers on failure
    - Checkpoint expiration and cleanup
    """

    def __init__(
        self,
        checkpoint_dir: str | None = None,
        retention_hours: int | None = None,
    ) -> None:
        """
        Initialize the RollbackManager.

        Args:
            checkpoint_dir: Directory for storing checkpoint data
            retention_hours: Hours to retain checkpoints
        """
        config = get_safety_config()

        # Constructor arguments override the global safety config defaults.
        self._checkpoint_dir = Path(checkpoint_dir or config.checkpoint_dir)
        self._retention_hours = retention_hours or config.checkpoint_retention_hours

        # Checkpoint metadata and per-checkpoint file snapshots, both kept
        # in memory and keyed by checkpoint id; _lock serializes mutations.
        # NOTE(review): _checkpoint_dir is created below but snapshots are
        # held only in memory in this class — nothing visible here persists
        # them to disk; confirm that is intended.
        self._checkpoints: dict[str, Checkpoint] = {}
        self._file_checkpoints: dict[str, list[FileCheckpoint]] = {}
        self._lock = asyncio.Lock()

        # Ensure checkpoint directory exists
        self._checkpoint_dir.mkdir(parents=True, exist_ok=True)

    async def create_checkpoint(
        self,
        action: ActionRequest,
        checkpoint_type: CheckpointType = CheckpointType.COMPOSITE,
        description: str | None = None,
    ) -> Checkpoint:
        """
        Create a checkpoint before an action.

        Args:
            action: The action to checkpoint for
            checkpoint_type: Type of checkpoint
            description: Optional description

        Returns:
            The created checkpoint
        """
        checkpoint_id = str(uuid4())

        checkpoint = Checkpoint(
            id=checkpoint_id,
            checkpoint_type=checkpoint_type,
            action_id=action.id,
            created_at=datetime.utcnow(),
            # Checkpoints auto-expire after the configured retention window.
            expires_at=datetime.utcnow() + timedelta(hours=self._retention_hours),
            data={
                "action_type": action.action_type.value,
                "tool_name": action.tool_name,
                "resource": action.resource,
            },
            description=description or f"Checkpoint for {action.tool_name}",
        )

        async with self._lock:
            self._checkpoints[checkpoint_id] = checkpoint
            # Start with an empty snapshot list; files are added later via
            # checkpoint_file().
            self._file_checkpoints[checkpoint_id] = []

        logger.info(
            "Created checkpoint %s for action %s",
            checkpoint_id,
            action.id,
        )

        return checkpoint

    async def checkpoint_file(
        self,
        checkpoint_id: str,
        file_path: str,
    ) -> None:
        """
        Store current state of a file for checkpoint.

        Args:
            checkpoint_id: ID of the checkpoint
            file_path: Path to the file
        """
        path = Path(file_path)

        # Capture the file's current bytes, or record its absence so a
        # rollback can delete a file that did not exist before.
        if path.exists():
            content = path.read_bytes()
            existed = True
        else:
            content = None
            existed = False

        file_checkpoint = FileCheckpoint(
            checkpoint_id=checkpoint_id,
            file_path=file_path,
            original_content=content,
            existed=existed,
        )

        async with self._lock:
            # Tolerates snapshots for checkpoint ids not created via
            # create_checkpoint() by lazily creating the list.
            if checkpoint_id not in self._file_checkpoints:
                self._file_checkpoints[checkpoint_id] = []
            self._file_checkpoints[checkpoint_id].append(file_checkpoint)

        logger.debug(
            "Stored file state for checkpoint %s: %s (existed=%s)",
            checkpoint_id,
            file_path,
            existed,
        )

    async def checkpoint_files(
        self,
        checkpoint_id: str,
        file_paths: list[str],
    ) -> None:
        """
        Store current state of multiple files.

        Args:
            checkpoint_id: ID of the checkpoint
            file_paths: Paths to the files
        """
        # Sequentially snapshot each file; each call locks independently.
        for path in file_paths:
            await self.checkpoint_file(checkpoint_id, path)

    async def rollback(
        self,
        checkpoint_id: str,
    ) -> RollbackResult:
        """
        Rollback to a checkpoint.

        Args:
            checkpoint_id: ID of the checkpoint

        Returns:
            Result of the rollback operation

        Raises:
            RollbackError: If the checkpoint is missing or no longer valid.
        """
        # Validate and read the snapshot list under the lock, but release it
        # before doing file I/O so other operations are not blocked.
        async with self._lock:
            checkpoint = self._checkpoints.get(checkpoint_id)
            if not checkpoint:
                raise RollbackError(
                    f"Checkpoint not found: {checkpoint_id}",
                    checkpoint_id=checkpoint_id,
                )

            if not checkpoint.is_valid:
                raise RollbackError(
                    f"Checkpoint is no longer valid: {checkpoint_id}",
                    checkpoint_id=checkpoint_id,
                )

            file_checkpoints = self._file_checkpoints.get(checkpoint_id, [])

        actions_rolled_back: list[str] = []
        failed_actions: list[str] = []

        # Rollback file changes; a failure on one file does not stop the
        # others — partial results are reported in the RollbackResult.
        for fc in file_checkpoints:
            try:
                await self._rollback_file(fc)
                actions_rolled_back.append(f"file:{fc.file_path}")
            except Exception as e:
                logger.error("Failed to rollback file %s: %s", fc.file_path, e)
                failed_actions.append(f"file:{fc.file_path}")

        success = len(failed_actions) == 0

        # Mark checkpoint as used — a checkpoint can be rolled back at most
        # once; subsequent rollback attempts raise RollbackError above.
        async with self._lock:
            if checkpoint_id in self._checkpoints:
                self._checkpoints[checkpoint_id].is_valid = False

        result = RollbackResult(
            checkpoint_id=checkpoint_id,
            success=success,
            actions_rolled_back=actions_rolled_back,
            failed_actions=failed_actions,
            error=None
            if success
            else f"Failed to rollback {len(failed_actions)} items",
        )

        if success:
            logger.info("Rollback successful for checkpoint %s", checkpoint_id)
        else:
            logger.error(
                "Rollback partially failed for checkpoint %s: %d failures",
                checkpoint_id,
                len(failed_actions),
            )

        return result

    async def discard_checkpoint(self, checkpoint_id: str) -> bool:
        """
        Discard a checkpoint without rolling back.

        Args:
            checkpoint_id: ID of the checkpoint

        Returns:
            True if checkpoint was found and discarded
        """
        async with self._lock:
            if checkpoint_id in self._checkpoints:
                del self._checkpoints[checkpoint_id]
                # Drop the associated file snapshots as well.
                if checkpoint_id in self._file_checkpoints:
                    del self._file_checkpoints[checkpoint_id]
                logger.debug("Discarded checkpoint %s", checkpoint_id)
                return True
        return False

    async def get_checkpoint(self, checkpoint_id: str) -> Checkpoint | None:
        """Get a checkpoint by ID."""
        async with self._lock:
            return self._checkpoints.get(checkpoint_id)

    async def list_checkpoints(
        self,
        action_id: str | None = None,
        include_expired: bool = False,
    ) -> list[Checkpoint]:
        """
        List checkpoints.

        Args:
            action_id: Optional filter by action ID
            include_expired: Include expired checkpoints

        Returns:
            List of checkpoints
        """
        now = datetime.utcnow()

        async with self._lock:
            # Snapshot the values so filtering happens outside the dict.
            checkpoints = list(self._checkpoints.values())

        if action_id:
            checkpoints = [c for c in checkpoints if c.action_id == action_id]

        if not include_expired:
            checkpoints = [
                c for c in checkpoints if c.expires_at is None or c.expires_at > now
            ]

        return checkpoints

    async def cleanup_expired(self) -> int:
        """
        Clean up expired checkpoints.

        Returns:
            Number of checkpoints cleaned up
        """
        now = datetime.utcnow()
        to_remove: list[str] = []

        async with self._lock:
            # Collect first, then delete — avoids mutating the dict while
            # iterating it.
            for checkpoint_id, checkpoint in self._checkpoints.items():
                if checkpoint.expires_at and checkpoint.expires_at < now:
                    to_remove.append(checkpoint_id)

            for checkpoint_id in to_remove:
                del self._checkpoints[checkpoint_id]
                if checkpoint_id in self._file_checkpoints:
                    del self._file_checkpoints[checkpoint_id]

        if to_remove:
            logger.info("Cleaned up %d expired checkpoints", len(to_remove))

        return len(to_remove)

    async def _rollback_file(self, fc: FileCheckpoint) -> None:
        """Rollback a single file to its checkpoint state."""
        path = Path(fc.file_path)

        if fc.existed:
            # Restore original content
            if fc.original_content is not None:
                # Recreate parent directories in case they were removed.
                path.parent.mkdir(parents=True, exist_ok=True)
                path.write_bytes(fc.original_content)
                logger.debug("Restored file: %s", fc.file_path)
        else:
            # File didn't exist before - delete it
            if path.exists():
                path.unlink()
                logger.debug("Deleted file (didn't exist before): %s", fc.file_path)
|
||||
|
||||
|
||||
class TransactionContext:
    """
    Context manager for transactional action execution.

    Usage:
        async with TransactionContext(rollback_manager, action) as tx:
            tx.checkpoint_file("/path/to/file")
            # Do work...
            # If exception occurs, automatic rollback
    """

    def __init__(
        self,
        manager: RollbackManager,
        action: ActionRequest,
        auto_rollback: bool = True,
    ) -> None:
        # auto_rollback=False disables the automatic rollback on exception;
        # the checkpoint can then still be rolled back manually.
        self._manager = manager
        self._action = action
        self._auto_rollback = auto_rollback
        # Created lazily on __aenter__.
        self._checkpoint: Checkpoint | None = None
        self._committed = False

    async def __aenter__(self) -> "TransactionContext":
        # Entering the context creates the checkpoint for this action.
        self._checkpoint = await self._manager.create_checkpoint(self._action)
        return self

    async def __aexit__(
        self,
        exc_type: type | None,
        exc_val: Exception | None,
        exc_tb: Any,
    ) -> bool:
        # NOTE(review): exiting cleanly WITHOUT calling commit() leaves the
        # checkpoint registered in the manager until it expires — neither
        # rolled back nor discarded. Confirm that retention is intended.
        if exc_val is not None and self._auto_rollback and not self._committed:
            # Exception occurred - rollback
            if self._checkpoint:
                try:
                    await self._manager.rollback(self._checkpoint.id)
                    logger.info(
                        "Auto-rollback completed for action %s",
                        self._action.id,
                    )
                except Exception as e:
                    # A failed rollback is logged but never masks the
                    # original exception raised inside the context.
                    logger.error("Auto-rollback failed: %s", e)
        elif self._committed and self._checkpoint:
            # Committed - discard checkpoint
            await self._manager.discard_checkpoint(self._checkpoint.id)

        return False  # Don't suppress the exception

    @property
    def checkpoint_id(self) -> str | None:
        """Get the checkpoint ID."""
        return self._checkpoint.id if self._checkpoint else None

    async def checkpoint_file(self, file_path: str) -> None:
        """Checkpoint a file for this transaction."""
        # No-op if called before __aenter__ created the checkpoint.
        if self._checkpoint:
            await self._manager.checkpoint_file(self._checkpoint.id, file_path)

    async def checkpoint_files(self, file_paths: list[str]) -> None:
        """Checkpoint multiple files for this transaction."""
        if self._checkpoint:
            await self._manager.checkpoint_files(self._checkpoint.id, file_paths)

    def commit(self) -> None:
        """Mark transaction as committed (no rollback on exit)."""
        self._committed = True

    async def rollback(self) -> RollbackResult | None:
        """Manually trigger rollback."""
        if self._checkpoint:
            return await self._manager.rollback(self._checkpoint.id)
        return None
|
||||
1
backend/app/services/safety/sandbox/__init__.py
Normal file
1
backend/app/services/safety/sandbox/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""${dir} module."""
|
||||
21
backend/app/services/safety/validation/__init__.py
Normal file
21
backend/app/services/safety/validation/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
Action Validation Module
|
||||
|
||||
Pre-execution validation with rule engine.
|
||||
"""
|
||||
|
||||
from .validator import (
|
||||
ActionValidator,
|
||||
ValidationCache,
|
||||
create_allow_rule,
|
||||
create_approval_rule,
|
||||
create_deny_rule,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ActionValidator",
|
||||
"ValidationCache",
|
||||
"create_allow_rule",
|
||||
"create_approval_rule",
|
||||
"create_deny_rule",
|
||||
]
|
||||
441
backend/app/services/safety/validation/validator.py
Normal file
441
backend/app/services/safety/validation/validator.py
Normal file
@@ -0,0 +1,441 @@
|
||||
"""
|
||||
Action Validator
|
||||
|
||||
Pre-execution validation with rule engine for action requests.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import fnmatch
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
|
||||
from ..config import get_safety_config
|
||||
from ..models import (
|
||||
ActionRequest,
|
||||
ActionType,
|
||||
SafetyDecision,
|
||||
SafetyPolicy,
|
||||
ValidationResult,
|
||||
ValidationRule,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ValidationCache:
    """LRU cache with TTL expiry for validation results."""

    def __init__(self, max_size: int = 1000, ttl_seconds: int = 60) -> None:
        """
        Initialize the cache.

        Args:
            max_size: Maximum number of entries before LRU eviction
            ttl_seconds: Seconds after which an entry is considered stale
        """
        # Maps key -> (result, insertion timestamp). OrderedDict ordering is
        # used for LRU: least-recently-used entry sits at the front.
        self._cache: OrderedDict[str, tuple[ValidationResult, float]] = OrderedDict()
        self._max_size = max_size
        self._ttl = ttl_seconds
        self._lock = asyncio.Lock()

    async def get(self, key: str) -> ValidationResult | None:
        """Return the cached result for *key*, or None if absent or expired."""
        import time

        async with self._lock:
            if key not in self._cache:
                return None

            result, timestamp = self._cache[key]
            # Use a monotonic clock for TTL so wall-clock adjustments (NTP,
            # DST) can neither prematurely expire nor over-retain entries.
            if time.monotonic() - timestamp > self._ttl:
                # Expired entries are evicted lazily on access.
                del self._cache[key]
                return None

            # Move to end (LRU)
            self._cache.move_to_end(key)
            return result

    async def set(self, key: str, result: ValidationResult) -> None:
        """Cache *result* under *key*, evicting the LRU entry when full."""
        import time

        async with self._lock:
            if key in self._cache:
                self._cache.move_to_end(key)
            else:
                if len(self._cache) >= self._max_size:
                    # Evict the least-recently-used entry (front of dict).
                    self._cache.popitem(last=False)
            self._cache[key] = (result, time.monotonic())

    async def clear(self) -> None:
        """Clear the cache."""
        async with self._lock:
            self._cache.clear()
|
||||
|
||||
|
||||
class ActionValidator:
|
||||
"""
|
||||
Validates actions against safety rules before execution.
|
||||
|
||||
Features:
|
||||
- Rule-based validation engine
|
||||
- Allow/deny/require-approval rules
|
||||
- Pattern matching for tools and resources
|
||||
- Validation result caching
|
||||
- Bypass capability for emergencies
|
||||
"""
|
||||
|
||||
    def __init__(
        self,
        cache_enabled: bool = True,
        cache_size: int = 1000,
        cache_ttl: int = 60,
    ) -> None:
        """
        Initialize the ActionValidator.

        Args:
            cache_enabled: Whether to cache validation results
            cache_size: Maximum cache entries
            cache_ttl: Cache TTL in seconds
        """
        # Rules are kept sorted by priority (highest first) via add_rule().
        self._rules: list[ValidationRule] = []
        self._cache_enabled = cache_enabled
        # The cache itself is sized from the constructor arguments.
        self._cache = ValidationCache(max_size=cache_size, ttl_seconds=cache_ttl)
        self._bypass_enabled = False
        self._bypass_reason: str | None = None

        config = get_safety_config()
        # NOTE(review): _cache_enabled is assigned a second time with the
        # same value, and _cache_ttl/_cache_size below come from the global
        # config rather than the constructor args used to build _cache —
        # looks like leftover wiring; confirm whether these fields are read
        # elsewhere or should be removed.
        self._cache_enabled = cache_enabled
        self._cache_ttl = config.validation_cache_ttl
        self._cache_size = config.validation_cache_size
|
||||
|
||||
def add_rule(self, rule: ValidationRule) -> None:
|
||||
"""
|
||||
Add a validation rule.
|
||||
|
||||
Args:
|
||||
rule: The rule to add
|
||||
"""
|
||||
self._rules.append(rule)
|
||||
# Re-sort by priority (higher first)
|
||||
self._rules.sort(key=lambda r: r.priority, reverse=True)
|
||||
logger.debug(
|
||||
"Added validation rule: %s (priority %d)", rule.name, rule.priority
|
||||
)
|
||||
|
||||
def remove_rule(self, rule_id: str) -> bool:
|
||||
"""
|
||||
Remove a validation rule by ID.
|
||||
|
||||
Args:
|
||||
rule_id: ID of the rule to remove
|
||||
|
||||
Returns:
|
||||
True if rule was found and removed
|
||||
"""
|
||||
for i, rule in enumerate(self._rules):
|
||||
if rule.id == rule_id:
|
||||
del self._rules[i]
|
||||
logger.debug("Removed validation rule: %s", rule_id)
|
||||
return True
|
||||
return False
|
||||
|
||||
def clear_rules(self) -> None:
|
||||
"""Remove all validation rules."""
|
||||
self._rules.clear()
|
||||
|
||||
def load_rules_from_policy(self, policy: SafetyPolicy) -> None:
    """
    Replace the current rule set with rules derived from *policy*.

    Explicit rules come straight from ``policy.validation_rules``;
    deny and require-approval rules are then synthesized from the
    policy's pattern lists.

    Args:
        policy: The policy to load rules from
    """
    self.clear_rules()

    # Explicit rules declared on the policy itself.
    for explicit_rule in policy.validation_rules:
        self.add_rule(explicit_rule)

    # Synthesize deny rules from the policy's denied-tool patterns.
    for index, pattern in enumerate(policy.denied_tools):
        deny_rule = ValidationRule(
            name=f"deny_tool_{index}",
            description=f"Deny tool pattern: {pattern}",
            priority=100,  # High priority for denials
            tool_patterns=[pattern],
            decision=SafetyDecision.DENY,
            reason=f"Tool matches denied pattern: {pattern}",
        )
        self.add_rule(deny_rule)

    # Synthesize require-approval rules; "*" means every action type.
    for index, pattern in enumerate(policy.require_approval_for):
        if pattern == "*":
            approval_rule = ValidationRule(
                name="require_approval_all",
                description="All actions require approval",
                priority=50,
                action_types=list(ActionType),
                decision=SafetyDecision.REQUIRE_APPROVAL,
                reason="All actions require human approval",
            )
        else:
            approval_rule = ValidationRule(
                name=f"require_approval_{index}",
                description=f"Require approval for: {pattern}",
                priority=50,
                tool_patterns=[pattern],
                decision=SafetyDecision.REQUIRE_APPROVAL,
                reason=f"Action matches approval-required pattern: {pattern}",
            )
        self.add_rule(approval_rule)

    logger.info("Loaded %d rules from policy: %s", len(self._rules), policy.name)
async def validate(
    self,
    action: ActionRequest,
    policy: SafetyPolicy | None = None,
) -> ValidationResult:
    """
    Evaluate *action* against all configured rules.

    Evaluation order: emergency bypass, then the result cache, then
    the rule list (already sorted by descending priority).  A DENY
    rule short-circuits; REQUIRE_APPROVAL upgrades an ALLOW outcome.

    Args:
        action: The action to validate
        policy: Optional policy override

    Returns:
        ValidationResult with decision and details
    """
    # Emergency bypass short-circuits everything.
    if self._bypass_enabled:
        logger.warning(
            "Validation bypass active: %s - allowing action %s",
            self._bypass_reason,
            action.id,
        )
        return ValidationResult(
            action_id=action.id,
            decision=SafetyDecision.ALLOW,
            applied_rules=[],
            reasons=[f"Validation bypassed: {self._bypass_reason}"],
        )

    # Serve a cached result when caching is on and a hit exists.
    cache_key = self._get_cache_key(action) if self._cache_enabled else None
    if cache_key is not None:
        cached = await self._cache.get(cache_key)
        if cached:
            logger.debug("Using cached validation for action %s", action.id)
            return cached

    # Lazily adopt the supplied policy when no rules exist yet.
    if policy and not self._rules:
        self.load_rules_from_policy(policy)

    matched_rule_ids: list[str] = []
    reasons: list[str] = []
    decision = SafetyDecision.ALLOW
    approval_id: str | None = None  # Never set here; reserved for callers.

    for rule in self._rules:
        if not rule.enabled or not self._rule_matches(rule, action):
            continue

        matched_rule_ids.append(rule.id)
        if rule.reason:
            reasons.append(rule.reason)

        if rule.decision == SafetyDecision.DENY:
            decision = SafetyDecision.DENY
            break  # Deny takes precedence; stop immediately.
        if rule.decision == SafetyDecision.REQUIRE_APPROVAL:
            # Upgrade, but never downgrade an earlier DENY.
            if decision != SafetyDecision.DENY:
                decision = SafetyDecision.REQUIRE_APPROVAL

    if not matched_rule_ids:
        reasons.append("No matching rules - default allow")

    result = ValidationResult(
        action_id=action.id,
        decision=decision,
        applied_rules=matched_rule_ids,
        reasons=reasons,
        approval_id=approval_id,
    )

    if cache_key is not None:
        await self._cache.set(cache_key, result)

    return result
async def validate_batch(
    self,
    actions: list[ActionRequest],
    policy: SafetyPolicy | None = None,
) -> list[ValidationResult]:
    """
    Validate several actions concurrently.

    Args:
        actions: Actions to validate
        policy: Optional policy override

    Returns:
        List of validation results, in input order
    """
    return await asyncio.gather(
        *(self.validate(single, policy) for single in actions)
    )
def enable_bypass(self, reason: str) -> None:
    """
    Turn validation off entirely (emergency use only).

    Args:
        reason: Reason for enabling bypass
    """
    logger.critical("Validation bypass enabled: %s", reason)
    self._bypass_enabled, self._bypass_reason = True, reason
def disable_bypass(self) -> None:
    """Restore normal validation after an emergency bypass."""
    logger.info("Validation bypass disabled")
    self._bypass_enabled, self._bypass_reason = False, None
async def clear_cache(self) -> None:
    """Purge every cached validation result."""
    await self._cache.clear()
def _rule_matches(self, rule: ValidationRule, action: ActionRequest) -> bool:
    """Return True when every constraint on *rule* accepts *action*.

    Empty constraint lists act as wildcards (no restriction).
    """
    # Action-type restriction.
    if rule.action_types and action.action_type not in rule.action_types:
        return False

    # Tool-name patterns: the action must name a tool and match one.
    if rule.tool_patterns:
        if not action.tool_name:
            return False
        if not any(
            self._matches_pattern(action.tool_name, candidate)
            for candidate in rule.tool_patterns
        ):
            return False

    # Resource patterns: the action must name a resource and match one.
    if rule.resource_patterns:
        if not action.resource:
            return False
        if not any(
            self._matches_pattern(action.resource, candidate)
            for candidate in rule.resource_patterns
        ):
            return False

    # Agent allow-list.
    if rule.agent_ids and action.metadata.agent_id not in rule.agent_ids:
        return False

    return True
def _matches_pattern(self, value: str, pattern: str) -> bool:
    """Glob-style match; a bare "*" always matches."""
    return pattern == "*" or fnmatch.fnmatch(value, pattern)
def _get_cache_key(self, action: ActionRequest) -> str:
    """Build a colon-joined cache key from the action fields that
    influence validation (type, tool, resource, agent, autonomy)."""
    return ":".join(
        (
            action.action_type.value,
            action.tool_name or "",
            action.resource or "",
            action.metadata.agent_id,
            action.metadata.autonomy_level.value,
        )
    )
||||
# Module-level convenience functions
|
||||
|
||||
|
||||
def create_allow_rule(
    name: str,
    tool_patterns: list[str] | None = None,
    resource_patterns: list[str] | None = None,
    action_types: list[ActionType] | None = None,
    priority: int = 0,
) -> ValidationRule:
    """Build a rule that explicitly allows matching actions."""
    return ValidationRule(
        name=name,
        decision=SafetyDecision.ALLOW,
        tool_patterns=tool_patterns,
        resource_patterns=resource_patterns,
        action_types=action_types,
        priority=priority,
    )
def create_deny_rule(
    name: str,
    tool_patterns: list[str] | None = None,
    resource_patterns: list[str] | None = None,
    action_types: list[ActionType] | None = None,
    reason: str | None = None,
    priority: int = 100,
) -> ValidationRule:
    """Build a rule that denies matching actions (high priority by default)."""
    return ValidationRule(
        name=name,
        decision=SafetyDecision.DENY,
        tool_patterns=tool_patterns,
        resource_patterns=resource_patterns,
        action_types=action_types,
        reason=reason,
        priority=priority,
    )
def create_approval_rule(
    name: str,
    tool_patterns: list[str] | None = None,
    resource_patterns: list[str] | None = None,
    action_types: list[ActionType] | None = None,
    reason: str | None = None,
    priority: int = 50,
) -> ValidationRule:
    """Build a rule that routes matching actions to human approval."""
    return ValidationRule(
        name=name,
        decision=SafetyDecision.REQUIRE_APPROVAL,
        tool_patterns=tool_patterns,
        resource_patterns=resource_patterns,
        action_types=action_types,
        reason=reason,
        priority=priority,
    )
@@ -1,11 +1,11 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# Ensure the project's virtualenv binaries are on PATH so commands like
|
||||
# 'uvicorn' work even when not prefixed by 'uv run'. This matches how uv
|
||||
# installs the env into /app/.venv in our containers.
|
||||
if [ -d "/app/.venv/bin" ]; then
|
||||
export PATH="/app/.venv/bin:$PATH"
|
||||
# Ensure the virtualenv binaries are on PATH. Dependencies are installed
|
||||
# to /opt/venv (not /app/.venv) to survive bind mounts in development.
|
||||
if [ -d "/opt/venv/bin" ]; then
|
||||
export PATH="/opt/venv/bin:$PATH"
|
||||
export VIRTUAL_ENV="/opt/venv"
|
||||
fi
|
||||
|
||||
# Only the backend service should run migrations and init_db
|
||||
|
||||
466
backend/tests/api/routes/test_context.py
Normal file
466
backend/tests/api/routes/test_context.py
Normal file
@@ -0,0 +1,466 @@
|
||||
"""
|
||||
Tests for Context Management API Routes.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import status
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.main import app
|
||||
from app.models.user import User
|
||||
from app.services.context import (
|
||||
AssembledContext,
|
||||
AssemblyTimeoutError,
|
||||
BudgetExceededError,
|
||||
ContextEngine,
|
||||
TokenBudget,
|
||||
)
|
||||
from app.services.mcp import MCPClientManager
|
||||
|
||||
|
||||
@pytest.fixture
def mock_mcp_client():
    """A MagicMock standing in for an initialized MCPClientManager."""
    manager = MagicMock(spec=MCPClientManager)
    manager.is_initialized = True
    return manager
@pytest.fixture
def mock_context_engine(mock_mcp_client):
    """A MagicMock ContextEngine wired to the mock MCP client."""
    fake_engine = MagicMock(spec=ContextEngine)
    fake_engine._mcp = mock_mcp_client
    return fake_engine
@pytest.fixture
def mock_superuser():
    """A MagicMock superuser for bypassing permission checks."""
    admin = MagicMock(spec=User)
    admin.id = "00000000-0000-0000-0000-000000000001"
    admin.is_superuser = True
    admin.email = "admin@example.com"
    return admin
@pytest.fixture
def client(mock_mcp_client, mock_context_engine, mock_superuser):
    """TestClient with MCP, context-engine and auth dependencies overridden."""
    from app.api.dependencies.permissions import require_superuser
    from app.api.routes.context import get_context_engine
    from app.services.mcp import get_mcp_client

    async def _mcp_override():
        return mock_mcp_client

    async def _engine_override():
        return mock_context_engine

    async def _superuser_override():
        return mock_superuser

    app.dependency_overrides[get_mcp_client] = _mcp_override
    app.dependency_overrides[get_context_engine] = _engine_override
    app.dependency_overrides[require_superuser] = _superuser_override

    # The app health check would otherwise touch a real database.
    with patch("app.main.check_database_health", return_value=True):
        yield TestClient(app)

    app.dependency_overrides.clear()
class TestContextHealth:
    """Tests for GET /context/health endpoint."""

    def test_health_check_success(self, client, mock_context_engine, mock_mcp_client):
        """Health endpoint reports healthy plus connection/cache flags."""
        mock_context_engine.get_stats = AsyncMock(
            return_value={
                "cache": {"hits": 10, "misses": 5},
                "settings": {"cache_enabled": True},
            }
        )

        resp = client.get("/api/v1/context/health")

        assert resp.status_code == status.HTTP_200_OK
        body = resp.json()
        assert body["status"] == "healthy"
        assert "mcp_connected" in body
        assert "cache_enabled" in body
class TestAssembleContext:
    """Tests for POST /context/assemble endpoint."""

    @staticmethod
    def _budget():
        """Small shared TokenBudget used by the assembly tests."""
        return TokenBudget(
            total=4000,
            system=500,
            knowledge=1500,
            conversation=1000,
            tools=500,
            response_reserve=500,
        )

    @staticmethod
    def _assembled(content, total_tokens, context_count, assembly_time_ms):
        """Build a MagicMock standing in for an AssembledContext result."""
        result = MagicMock(spec=AssembledContext)
        result.content = content
        result.total_tokens = total_tokens
        result.context_count = context_count
        result.excluded_count = 0
        result.assembly_time_ms = assembly_time_ms
        result.metadata = {}
        return result

    def test_assemble_context_success(self, client, mock_context_engine):
        """A plain assembly request succeeds and echoes the result fields."""
        mock_context_engine.assemble_context = AsyncMock(
            return_value=self._assembled("Assembled context content", 500, 2, 50.5)
        )
        mock_context_engine.get_budget_for_model = AsyncMock(
            return_value=self._budget()
        )

        resp = client.post(
            "/api/v1/context/assemble",
            json={
                "project_id": "test-project",
                "agent_id": "test-agent",
                "query": "What is the auth flow?",
                "model": "claude-3-sonnet",
                "system_prompt": "You are a helpful assistant.",
            },
        )

        assert resp.status_code == status.HTTP_200_OK
        body = resp.json()
        assert body["content"] == "Assembled context content"
        assert body["total_tokens"] == 500
        assert body["context_count"] == 2
        assert body["compressed"] is False
        assert "budget_used_percent" in body

    def test_assemble_context_with_conversation(self, client, mock_context_engine):
        """Conversation history is forwarded to the engine unchanged."""
        mock_context_engine.assemble_context = AsyncMock(
            return_value=self._assembled("Context with history", 800, 1, 30.0)
        )
        mock_context_engine.get_budget_for_model = AsyncMock(
            return_value=self._budget()
        )

        resp = client.post(
            "/api/v1/context/assemble",
            json={
                "project_id": "test-project",
                "agent_id": "test-agent",
                "query": "Continue the discussion",
                "conversation_history": [
                    {"role": "user", "content": "Hello"},
                    {"role": "assistant", "content": "Hi there!"},
                ],
            },
        )

        assert resp.status_code == status.HTTP_200_OK
        forwarded = mock_context_engine.assemble_context.call_args.kwargs
        assert forwarded["conversation_history"] == [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]

    def test_assemble_context_with_tool_results(self, client, mock_context_engine):
        """Tool results are forwarded to the engine."""
        mock_context_engine.assemble_context = AsyncMock(
            return_value=self._assembled("Context with tools", 600, 1, 25.0)
        )
        mock_context_engine.get_budget_for_model = AsyncMock(
            return_value=self._budget()
        )

        resp = client.post(
            "/api/v1/context/assemble",
            json={
                "project_id": "test-project",
                "agent_id": "test-agent",
                "query": "What did the search find?",
                "tool_results": [
                    {
                        "tool_name": "search_knowledge",
                        "content": {"results": ["item1", "item2"]},
                        "status": "success",
                    }
                ],
            },
        )

        assert resp.status_code == status.HTTP_200_OK
        forwarded = mock_context_engine.assemble_context.call_args.kwargs
        assert len(forwarded["tool_results"]) == 1

    def test_assemble_context_timeout(self, client, mock_context_engine):
        """An AssemblyTimeoutError surfaces as 504."""
        mock_context_engine.assemble_context = AsyncMock(
            side_effect=AssemblyTimeoutError("Assembly exceeded 5000ms limit")
        )

        resp = client.post(
            "/api/v1/context/assemble",
            json={
                "project_id": "test-project",
                "agent_id": "test-agent",
                "query": "test",
            },
        )

        assert resp.status_code == status.HTTP_504_GATEWAY_TIMEOUT

    def test_assemble_context_budget_exceeded(self, client, mock_context_engine):
        """A BudgetExceededError surfaces as 413."""
        mock_context_engine.assemble_context = AsyncMock(
            side_effect=BudgetExceededError("Token budget exceeded: 5000 > 4000")
        )

        resp = client.post(
            "/api/v1/context/assemble",
            json={
                "project_id": "test-project",
                "agent_id": "test-agent",
                "query": "test",
            },
        )

        assert resp.status_code == status.HTTP_413_REQUEST_ENTITY_TOO_LARGE

    def test_assemble_context_validation_error(self, client):
        """Missing required fields yield 422."""
        resp = client.post(
            "/api/v1/context/assemble",
            json={},  # Missing required fields
        )

        assert resp.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
class TestCountTokens:
    """Tests for POST /context/count-tokens endpoint."""

    def test_count_tokens_success(self, client, mock_context_engine):
        """Counting tokens with an explicit model echoes both back."""
        mock_context_engine.count_tokens = AsyncMock(return_value=42)

        resp = client.post(
            "/api/v1/context/count-tokens",
            json={
                "content": "This is some test content.",
                "model": "claude-3-sonnet",
            },
        )

        assert resp.status_code == status.HTTP_200_OK
        body = resp.json()
        assert body["token_count"] == 42
        assert body["model"] == "claude-3-sonnet"

    def test_count_tokens_without_model(self, client, mock_context_engine):
        """Counting tokens with no model returns a null model field."""
        mock_context_engine.count_tokens = AsyncMock(return_value=100)

        resp = client.post(
            "/api/v1/context/count-tokens",
            json={"content": "Some content to count."},
        )

        assert resp.status_code == status.HTTP_200_OK
        body = resp.json()
        assert body["token_count"] == 100
        assert body["model"] is None
class TestGetBudget:
    """Tests for GET /context/budget/{model} endpoint."""

    def test_get_budget_success(self, client, mock_context_engine):
        """The budget breakdown for a model is returned field by field."""
        mock_context_engine.get_budget_for_model = AsyncMock(
            return_value=TokenBudget(
                total=100000,
                system=10000,
                knowledge=40000,
                conversation=30000,
                tools=10000,
                response_reserve=10000,
            )
        )

        resp = client.get("/api/v1/context/budget/claude-3-opus")

        assert resp.status_code == status.HTTP_200_OK
        body = resp.json()
        assert body["model"] == "claude-3-opus"
        assert body["total_tokens"] == 100000
        assert body["system_tokens"] == 10000
        assert body["knowledge_tokens"] == 40000

    def test_get_budget_with_max_tokens(self, client, mock_context_engine):
        """A max_tokens query parameter yields a correspondingly small budget."""
        mock_context_engine.get_budget_for_model = AsyncMock(
            return_value=TokenBudget(
                total=2000,
                system=200,
                knowledge=800,
                conversation=600,
                tools=200,
                response_reserve=200,
            )
        )

        resp = client.get("/api/v1/context/budget/gpt-4?max_tokens=2000")

        assert resp.status_code == status.HTTP_200_OK
        assert resp.json()["total_tokens"] == 2000
class TestGetStats:
    """Tests for GET /context/stats endpoint."""

    def test_get_stats_success(self, client, mock_context_engine):
        """Engine cache/settings statistics are surfaced verbatim."""
        mock_context_engine.get_stats = AsyncMock(
            return_value={
                "cache": {
                    "hits": 100,
                    "misses": 25,
                    "hit_rate": 0.8,
                },
                "settings": {
                    "compression_threshold": 0.9,
                    "max_assembly_time_ms": 5000,
                    "cache_enabled": True,
                },
            }
        )

        resp = client.get("/api/v1/context/stats")

        assert resp.status_code == status.HTTP_200_OK
        body = resp.json()
        assert body["cache"]["hits"] == 100
        assert body["settings"]["cache_enabled"] is True
class TestInvalidateCache:
    """Tests for POST /context/cache/invalidate endpoint."""

    def test_invalidate_cache_by_project(self, client, mock_context_engine):
        """Invalidation can be scoped to a single project."""
        mock_context_engine.invalidate_cache = AsyncMock(return_value=5)

        resp = client.post(
            "/api/v1/context/cache/invalidate?project_id=test-project"
        )

        assert resp.status_code == status.HTTP_204_NO_CONTENT
        mock_context_engine.invalidate_cache.assert_called_once()
        assert (
            mock_context_engine.invalidate_cache.call_args.kwargs["project_id"]
            == "test-project"
        )

    def test_invalidate_cache_by_pattern(self, client, mock_context_engine):
        """Invalidation can be scoped by a key pattern."""
        mock_context_engine.invalidate_cache = AsyncMock(return_value=10)

        resp = client.post("/api/v1/context/cache/invalidate?pattern=*auth*")

        assert resp.status_code == status.HTTP_204_NO_CONTENT
        mock_context_engine.invalidate_cache.assert_called_once()
        assert (
            mock_context_engine.invalidate_cache.call_args.kwargs["pattern"]
            == "*auth*"
        )

    def test_invalidate_cache_all(self, client, mock_context_engine):
        """Omitting filters invalidates every entry."""
        mock_context_engine.invalidate_cache = AsyncMock(return_value=100)

        resp = client.post("/api/v1/context/cache/invalidate")

        assert resp.status_code == status.HTTP_204_NO_CONTENT
class TestContextEndpointsEdgeCases:
    """Edge case tests for Context endpoints."""

    def test_context_content_type(self, client, mock_context_engine):
        """Endpoints respond with a JSON content type."""
        mock_context_engine.get_stats = AsyncMock(
            return_value={"cache": {}, "settings": {}}
        )

        resp = client.get("/api/v1/context/health")

        assert "application/json" in resp.headers["content-type"]

    def test_assemble_context_with_knowledge_query(self, client, mock_context_engine):
        """Knowledge-base query parameters are forwarded to the engine."""
        assembled = MagicMock(spec=AssembledContext)
        assembled.content = "Context with knowledge"
        assembled.total_tokens = 1000
        assembled.context_count = 3
        assembled.excluded_count = 0
        assembled.assembly_time_ms = 100.0
        # Non-empty metadata indicates compression happened.
        assembled.metadata = {"compressed_contexts": 1}

        mock_context_engine.assemble_context = AsyncMock(return_value=assembled)
        mock_context_engine.get_budget_for_model = AsyncMock(
            return_value=TokenBudget(
                total=4000,
                system=500,
                knowledge=1500,
                conversation=1000,
                tools=500,
                response_reserve=500,
            )
        )

        resp = client.post(
            "/api/v1/context/assemble",
            json={
                "project_id": "test-project",
                "agent_id": "test-agent",
                "query": "How does authentication work?",
                "knowledge_query": "authentication flow implementation",
                "knowledge_limit": 5,
            },
        )

        assert resp.status_code == status.HTTP_200_OK
        forwarded = mock_context_engine.assemble_context.call_args.kwargs
        assert forwarded["knowledge_query"] == "authentication flow implementation"
        assert forwarded["knowledge_limit"] == 5
@@ -44,8 +44,8 @@ def mock_superuser():
|
||||
@pytest.fixture
|
||||
def client(mock_mcp_client, mock_superuser):
|
||||
"""Create a FastAPI test client with mocked dependencies."""
|
||||
from app.api.routes.mcp import get_mcp_client
|
||||
from app.api.dependencies.permissions import require_superuser
|
||||
from app.api.routes.mcp import get_mcp_client
|
||||
|
||||
# Override dependencies
|
||||
async def override_get_mcp_client():
|
||||
|
||||
646
backend/tests/e2e/test_agent_workflows.py
Normal file
646
backend/tests/e2e/test_agent_workflows.py
Normal file
@@ -0,0 +1,646 @@
|
||||
"""
|
||||
Agent E2E Workflow Tests.
|
||||
|
||||
Tests complete workflows for AI agents including:
|
||||
- Agent type management (admin-only)
|
||||
- Agent instance spawning and lifecycle
|
||||
- Agent status transitions (pause/resume/terminate)
|
||||
- Authorization and access control
|
||||
|
||||
Usage:
|
||||
make test-e2e # Run all E2E tests
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
# Every test in this module is an async E2E test that needs Postgres.
pytestmark = [pytest.mark.e2e, pytest.mark.postgres, pytest.mark.asyncio]
class TestAgentTypesAdminWorkflows:
    """Test agent type management (admin-only operations)."""

    @staticmethod
    def _auth(token: str) -> dict[str, str]:
        """Build an Authorization header for *token*."""
        return {"Authorization": f"Bearer {token}"}

    async def test_create_agent_type_requires_superuser(self, e2e_client):
        """Creating agent types is rejected for non-superusers."""
        email = f"regular-{uuid4().hex[:8]}@example.com"
        password = "RegularPass123!"

        await e2e_client.post(
            "/api/v1/auth/register",
            json={
                "email": email,
                "password": password,
                "first_name": "Regular",
                "last_name": "User",
            },
        )
        login_resp = await e2e_client.post(
            "/api/v1/auth/login",
            json={"email": email, "password": password},
        )
        tokens = login_resp.json()

        # A regular user must not be able to create an agent type.
        create_resp = await e2e_client.post(
            "/api/v1/agent-types",
            headers=self._auth(tokens["access_token"]),
            json={
                "name": "Test Agent",
                "slug": f"test-agent-{uuid4().hex[:8]}",
                "personality_prompt": "You are a helpful assistant.",
                "primary_model": "claude-3-sonnet",
            },
        )

        assert create_resp.status_code == 403

    async def test_superuser_can_create_agent_type(self, e2e_client, e2e_superuser):
        """A superuser can create a fully specified agent type."""
        slug = f"test-type-{uuid4().hex[:8]}"
        admin_headers = self._auth(e2e_superuser["tokens"]["access_token"])

        create_resp = await e2e_client.post(
            "/api/v1/agent-types",
            headers=admin_headers,
            json={
                "name": "Product Owner Agent",
                "slug": slug,
                "description": "A product owner agent for requirements gathering",
                "expertise": ["requirements", "user_stories", "prioritization"],
                "personality_prompt": "You are a product owner focused on delivering value.",
                "primary_model": "claude-3-opus",
                "fallback_models": ["claude-3-sonnet"],
                "model_params": {"temperature": 0.7, "max_tokens": 4000},
                "mcp_servers": ["knowledge-base"],
                "is_active": True,
            },
        )

        assert create_resp.status_code == 201, f"Failed: {create_resp.text}"
        agent_type = create_resp.json()

        assert agent_type["name"] == "Product Owner Agent"
        assert agent_type["slug"] == slug
        assert agent_type["primary_model"] == "claude-3-opus"
        assert agent_type["is_active"] is True
        assert "requirements" in agent_type["expertise"]

    async def test_list_agent_types_public(self, e2e_client, e2e_superuser):
        """Any authenticated user can list agent types."""
        # Seed one agent type as superuser.
        slug = f"list-test-{uuid4().hex[:8]}"
        await e2e_client.post(
            "/api/v1/agent-types",
            headers=self._auth(e2e_superuser["tokens"]["access_token"]),
            json={
                "name": f"List Test Agent {slug}",
                "slug": slug,
                "personality_prompt": "Test agent.",
                "primary_model": "claude-3-sonnet",
            },
        )

        # Register and log in a regular user.
        email = f"lister-{uuid4().hex[:8]}@example.com"
        password = "ListerPass123!"

        await e2e_client.post(
            "/api/v1/auth/register",
            json={
                "email": email,
                "password": password,
                "first_name": "List",
                "last_name": "User",
            },
        )
        login_resp = await e2e_client.post(
            "/api/v1/auth/login",
            json={"email": email, "password": password},
        )
        tokens = login_resp.json()

        list_resp = await e2e_client.get(
            "/api/v1/agent-types",
            headers=self._auth(tokens["access_token"]),
        )

        assert list_resp.status_code == 200
        listing = list_resp.json()
        assert "data" in listing
        assert "pagination" in listing
        assert listing["pagination"]["total"] >= 1

    async def test_get_agent_type_by_slug(self, e2e_client, e2e_superuser):
        """Agent types are retrievable via the slug route."""
        slug = f"slug-test-{uuid4().hex[:8]}"
        admin_headers = self._auth(e2e_superuser["tokens"]["access_token"])

        await e2e_client.post(
            "/api/v1/agent-types",
            headers=admin_headers,
            json={
                "name": f"Slug Test {slug}",
                "slug": slug,
                "personality_prompt": "Test agent.",
                "primary_model": "claude-3-sonnet",
            },
        )

        # Route is /slug/{slug}, not /by-slug/{slug}.
        get_resp = await e2e_client.get(
            f"/api/v1/agent-types/slug/{slug}",
            headers=admin_headers,
        )

        assert get_resp.status_code == 200
        assert get_resp.json()["slug"] == slug

    async def test_update_agent_type(self, e2e_client, e2e_superuser):
        """PATCH updates the given fields and leaves the rest untouched."""
        slug = f"update-test-{uuid4().hex[:8]}"
        admin_headers = self._auth(e2e_superuser["tokens"]["access_token"])

        create_resp = await e2e_client.post(
            "/api/v1/agent-types",
            headers=admin_headers,
            json={
                "name": "Original Name",
                "slug": slug,
                "personality_prompt": "Original prompt.",
                "primary_model": "claude-3-sonnet",
            },
        )
        agent_type_id = create_resp.json()["id"]

        update_resp = await e2e_client.patch(
            f"/api/v1/agent-types/{agent_type_id}",
            headers=admin_headers,
            json={
                "name": "Updated Name",
                "description": "Added description",
            },
        )

        assert update_resp.status_code == 200
        updated = update_resp.json()
        assert updated["name"] == "Updated Name"
        assert updated["description"] == "Added description"
        assert updated["personality_prompt"] == "Original prompt."  # Unchanged
class TestAgentInstanceWorkflows:
    """Test agent instance spawning and lifecycle."""

    async def test_spawn_agent_workflow(self, e2e_client, e2e_superuser):
        """Full workflow: create an agent type, create a project, spawn an agent."""
        auth = {
            "Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
        }

        # 1. Create the agent type as superuser.
        type_slug = f"spawn-test-type-{uuid4().hex[:8]}"
        type_resp = await e2e_client.post(
            "/api/v1/agent-types",
            headers=auth,
            json={
                "name": "Spawn Test Agent",
                "slug": type_slug,
                "personality_prompt": "You are a helpful agent.",
                "primary_model": "claude-3-sonnet",
            },
        )
        assert type_resp.status_code == 201
        agent_type_id = type_resp.json()["id"]

        # 2. Create a project (superuser can create projects too).
        project_slug = f"spawn-test-project-{uuid4().hex[:8]}"
        project_resp = await e2e_client.post(
            "/api/v1/projects",
            headers=auth,
            json={"name": "Spawn Test Project", "slug": project_slug},
        )
        assert project_resp.status_code == 201
        project_id = project_resp.json()["id"]

        # 3. Spawn an agent instance inside the project.
        spawn_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/agents",
            headers=auth,
            json={
                "agent_type_id": agent_type_id,
                "project_id": project_id,
                "name": "My PO Agent",
            },
        )
        assert spawn_resp.status_code == 201, f"Failed: {spawn_resp.text}"

        agent = spawn_resp.json()
        assert agent["name"] == "My PO Agent"
        assert agent["status"] == "idle"
        assert agent["project_id"] == project_id
        assert agent["agent_type_id"] == agent_type_id

    async def test_list_project_agents(self, e2e_client, e2e_superuser):
        """Listing a project's agents returns every spawned instance."""
        auth = {
            "Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
        }

        # Setup: one agent type and one project to spawn into.
        type_slug = f"list-agents-type-{uuid4().hex[:8]}"
        type_resp = await e2e_client.post(
            "/api/v1/agent-types",
            headers=auth,
            json={
                "name": "List Agents Type",
                "slug": type_slug,
                "personality_prompt": "Test agent.",
                "primary_model": "claude-3-sonnet",
            },
        )
        agent_type_id = type_resp.json()["id"]

        project_slug = f"list-agents-project-{uuid4().hex[:8]}"
        project_resp = await e2e_client.post(
            "/api/v1/projects",
            headers=auth,
            json={"name": "List Agents Project", "slug": project_slug},
        )
        project_id = project_resp.json()["id"]

        # Spawn three agents so pagination totals are meaningful.
        for i in range(3):
            await e2e_client.post(
                f"/api/v1/projects/{project_id}/agents",
                headers=auth,
                json={
                    "agent_type_id": agent_type_id,
                    "project_id": project_id,
                    "name": f"Agent {i + 1}",
                },
            )

        # List and verify all three come back.
        list_resp = await e2e_client.get(
            f"/api/v1/projects/{project_id}/agents",
            headers=auth,
        )

        assert list_resp.status_code == 200
        data = list_resp.json()
        assert data["pagination"]["total"] == 3
        assert len(data["data"]) == 3
|
||||
|
||||
|
||||
class TestAgentLifecycle:
    """Test agent lifecycle operations (pause/resume/terminate)."""

    async def test_agent_pause_and_resume(self, e2e_client, e2e_superuser):
        """An idle agent can be paused and then resumed back to idle."""
        auth = {
            "Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
        }

        # Setup: agent type, project, and one spawned agent.
        type_slug = f"pause-test-type-{uuid4().hex[:8]}"
        type_resp = await e2e_client.post(
            "/api/v1/agent-types",
            headers=auth,
            json={
                "name": "Pause Test Type",
                "slug": type_slug,
                "personality_prompt": "Test agent.",
                "primary_model": "claude-3-sonnet",
            },
        )
        agent_type_id = type_resp.json()["id"]

        project_slug = f"pause-test-project-{uuid4().hex[:8]}"
        project_resp = await e2e_client.post(
            "/api/v1/projects",
            headers=auth,
            json={"name": "Pause Test Project", "slug": project_slug},
        )
        project_id = project_resp.json()["id"]

        spawn_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/agents",
            headers=auth,
            json={
                "agent_type_id": agent_type_id,
                "project_id": project_id,
                "name": "Pausable Agent",
            },
        )
        agent_id = spawn_resp.json()["id"]
        assert spawn_resp.json()["status"] == "idle"

        # Pause: idle -> paused.
        pause_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/agents/{agent_id}/pause",
            headers=auth,
        )
        assert pause_resp.status_code == 200, f"Failed: {pause_resp.text}"
        assert pause_resp.json()["status"] == "paused"

        # Resume: paused -> idle.
        resume_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/agents/{agent_id}/resume",
            headers=auth,
        )
        assert resume_resp.status_code == 200, f"Failed: {resume_resp.text}"
        assert resume_resp.json()["status"] == "idle"

    async def test_agent_terminate(self, e2e_client, e2e_superuser):
        """A terminated agent is gone for good and cannot be resumed."""
        auth = {
            "Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
        }

        # Setup: agent type, project, and a victim agent.
        type_slug = f"terminate-type-{uuid4().hex[:8]}"
        type_resp = await e2e_client.post(
            "/api/v1/agent-types",
            headers=auth,
            json={
                "name": "Terminate Type",
                "slug": type_slug,
                "personality_prompt": "Test agent.",
                "primary_model": "claude-3-sonnet",
            },
        )
        agent_type_id = type_resp.json()["id"]

        project_slug = f"terminate-project-{uuid4().hex[:8]}"
        project_resp = await e2e_client.post(
            "/api/v1/projects",
            headers=auth,
            json={"name": "Terminate Project", "slug": project_slug},
        )
        project_id = project_resp.json()["id"]

        spawn_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/agents",
            headers=auth,
            json={
                "agent_type_id": agent_type_id,
                "project_id": project_id,
                "name": "To Be Terminated",
            },
        )
        agent_id = spawn_resp.json()["id"]

        # Terminate (DELETE returns a MessageResponse, not an agent payload).
        terminate_resp = await e2e_client.delete(
            f"/api/v1/projects/{project_id}/agents/{agent_id}",
            headers=auth,
        )

        assert terminate_resp.status_code == 200
        assert "message" in terminate_resp.json()

        # A terminated agent cannot be resumed (invalid state transition).
        resume_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/agents/{agent_id}/resume",
            headers=auth,
        )
        assert resume_resp.status_code in (400, 422)
|
||||
|
||||
|
||||
class TestAgentAccessControl:
    """Test agent access control and authorization."""

    async def test_user_cannot_access_other_project_agents(
        self, e2e_client, e2e_superuser
    ):
        """Users must not see agents in projects they don't own."""
        superuser_auth = {
            "Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
        }

        # Superuser creates an agent type.
        type_slug = f"access-type-{uuid4().hex[:8]}"
        type_resp = await e2e_client.post(
            "/api/v1/agent-types",
            headers=superuser_auth,
            json={
                "name": "Access Type",
                "slug": type_slug,
                "personality_prompt": "Test agent.",
                "primary_model": "claude-3-sonnet",
            },
        )
        agent_type_id = type_resp.json()["id"]

        # Superuser creates a project and spawns an agent into it.
        project_slug = f"protected-project-{uuid4().hex[:8]}"
        project_resp = await e2e_client.post(
            "/api/v1/projects",
            headers=superuser_auth,
            json={"name": "Protected Project", "slug": project_slug},
        )
        project_id = project_resp.json()["id"]

        spawn_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/agents",
            headers=superuser_auth,
            json={
                "agent_type_id": agent_type_id,
                "project_id": project_id,
                "name": "Protected Agent",
            },
        )
        agent_id = spawn_resp.json()["id"]

        # Register and log in a second, unrelated user.
        email = f"other-user-{uuid4().hex[:8]}@example.com"
        password = "OtherPass123!"

        await e2e_client.post(
            "/api/v1/auth/register",
            json={
                "email": email,
                "password": password,
                "first_name": "Other",
                "last_name": "User",
            },
        )
        login_resp = await e2e_client.post(
            "/api/v1/auth/login",
            json={"email": email, "password": password},
        )
        other_tokens = login_resp.json()

        # The other user attempts to read the protected agent.
        get_resp = await e2e_client.get(
            f"/api/v1/projects/{project_id}/agents/{agent_id}",
            headers={"Authorization": f"Bearer {other_tokens['access_token']}"},
        )

        # Either forbidden or hidden entirely is acceptable.
        assert get_resp.status_code in (403, 404)

    async def test_cannot_spawn_with_inactive_agent_type(
        self, e2e_client, e2e_superuser
    ):
        """Spawning from a deactivated agent type must be rejected."""
        auth = {
            "Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
        }

        # Create an (initially active) agent type.
        type_slug = f"inactive-type-{uuid4().hex[:8]}"
        type_resp = await e2e_client.post(
            "/api/v1/agent-types",
            headers=auth,
            json={
                "name": "Inactive Type",
                "slug": type_slug,
                "personality_prompt": "Test agent.",
                "primary_model": "claude-3-sonnet",
                "is_active": True,
            },
        )
        agent_type_id = type_resp.json()["id"]

        # Deactivate it.
        await e2e_client.patch(
            f"/api/v1/agent-types/{agent_type_id}",
            headers=auth,
            json={"is_active": False},
        )

        # Create a project to spawn into.
        project_slug = f"inactive-spawn-project-{uuid4().hex[:8]}"
        project_resp = await e2e_client.post(
            "/api/v1/projects",
            headers=auth,
            json={"name": "Inactive Spawn Project", "slug": project_slug},
        )
        project_id = project_resp.json()["id"]

        # Attempt the spawn with the now-inactive type.
        spawn_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/agents",
            headers=auth,
            json={
                "agent_type_id": agent_type_id,
                "project_id": project_id,
                "name": "Should Fail",
            },
        )

        # 422 is correct for validation errors per REST conventions.
        assert spawn_resp.status_code == 422
|
||||
|
||||
|
||||
class TestAgentMetrics:
    """Test agent metrics endpoint."""

    async def test_get_agent_metrics(self, e2e_client, e2e_superuser):
        """The metrics endpoint returns the AgentInstanceMetrics shape."""
        auth = {
            "Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
        }

        # Setup: agent type, project, and a spawned agent.
        type_slug = f"metrics-type-{uuid4().hex[:8]}"
        type_resp = await e2e_client.post(
            "/api/v1/agent-types",
            headers=auth,
            json={
                "name": "Metrics Type",
                "slug": type_slug,
                "personality_prompt": "Test agent.",
                "primary_model": "claude-3-sonnet",
            },
        )
        agent_type_id = type_resp.json()["id"]

        project_slug = f"metrics-project-{uuid4().hex[:8]}"
        project_resp = await e2e_client.post(
            "/api/v1/projects",
            headers=auth,
            json={"name": "Metrics Project", "slug": project_slug},
        )
        project_id = project_resp.json()["id"]

        spawn_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/agents",
            headers=auth,
            json={
                "agent_type_id": agent_type_id,
                "project_id": project_id,
                "name": "Metrics Agent",
            },
        )
        agent_id = spawn_resp.json()["id"]

        # Fetch the metrics for that agent.
        metrics_resp = await e2e_client.get(
            f"/api/v1/projects/{project_id}/agents/{agent_id}/metrics",
            headers=auth,
        )

        assert metrics_resp.status_code == 200
        metrics = metrics_resp.json()

        # Verify the AgentInstanceMetrics structure.
        for field in (
            "total_instances",
            "active_instances",
            "idle_instances",
            "total_tasks_completed",
            "total_tokens_used",
            "total_cost_incurred",
        ):
            assert field in metrics
|
||||
460
backend/tests/e2e/test_mcp_workflows.py
Normal file
460
backend/tests/e2e/test_mcp_workflows.py
Normal file
@@ -0,0 +1,460 @@
|
||||
"""
|
||||
MCP and Context Engine E2E Workflow Tests.
|
||||
|
||||
Tests complete workflows involving MCP servers and the Context Engine
|
||||
against real PostgreSQL. These tests verify:
|
||||
- MCP server listing and tool discovery
|
||||
- Context engine operations
|
||||
- Admin-only MCP operations with proper authentication
|
||||
- Error handling for MCP operations
|
||||
|
||||
Usage:
|
||||
make test-e2e # Run all E2E tests
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = [
|
||||
pytest.mark.e2e,
|
||||
pytest.mark.postgres,
|
||||
pytest.mark.asyncio,
|
||||
]
|
||||
|
||||
|
||||
class TestMCPServerDiscovery:
    """Test MCP server listing and discovery workflows."""

    async def test_list_mcp_servers(self, e2e_client):
        """The server list includes the two core MCP servers."""
        response = await e2e_client.get("/api/v1/mcp/servers")

        assert response.status_code == 200, f"Failed: {response.text}"
        payload = response.json()

        # Envelope shape: a list of servers plus a total count.
        assert "servers" in payload
        assert "total" in payload
        assert isinstance(payload["servers"], list)

        # At minimum, llm-gateway and knowledge-base must be configured.
        names = [server["name"] for server in payload["servers"]]
        assert "llm-gateway" in names
        assert "knowledge-base" in names

    async def test_list_all_mcp_tools(self, e2e_client):
        """The aggregate tool listing returns tools from every server."""
        response = await e2e_client.get("/api/v1/mcp/tools")

        assert response.status_code == 200, f"Failed: {response.text}"
        payload = response.json()

        assert "tools" in payload
        assert "total" in payload
        assert isinstance(payload["tools"], list)

    async def test_mcp_health_check(self, e2e_client):
        """The health check reports per-server status plus counts."""
        response = await e2e_client.get("/api/v1/mcp/health")

        assert response.status_code == 200, f"Failed: {response.text}"
        payload = response.json()

        for field in ("servers", "healthy_count", "unhealthy_count", "total"):
            assert field in payload

    async def test_list_circuit_breakers(self, e2e_client):
        """Circuit breaker status is exposed as a list."""
        response = await e2e_client.get("/api/v1/mcp/circuit-breakers")

        assert response.status_code == 200, f"Failed: {response.text}"
        payload = response.json()

        assert "circuit_breakers" in payload
        assert isinstance(payload["circuit_breakers"], list)
|
||||
|
||||
|
||||
class TestMCPServerTools:
    """Test MCP server tool listing."""

    async def test_list_llm_gateway_tools(self, e2e_client):
        """Tool listing for the LLM Gateway server."""
        response = await e2e_client.get("/api/v1/mcp/servers/llm-gateway/tools")

        # 200 when the server is connected; 404/502 when it is not.
        assert response.status_code in (200, 404, 502)

        if response.status_code == 200:
            payload = response.json()
            assert "tools" in payload
            assert "total" in payload

    async def test_list_knowledge_base_tools(self, e2e_client):
        """Tool listing for the Knowledge Base server."""
        response = await e2e_client.get("/api/v1/mcp/servers/knowledge-base/tools")

        # 200 when the server is connected; 404/502 when it is not.
        assert response.status_code in (200, 404, 502)

        if response.status_code == 200:
            payload = response.json()
            assert "tools" in payload
            assert "total" in payload

    async def test_invalid_server_returns_404(self, e2e_client):
        """An unknown server name yields 404."""
        response = await e2e_client.get("/api/v1/mcp/servers/nonexistent-server/tools")

        assert response.status_code == 404
|
||||
|
||||
|
||||
class TestContextEngineWorkflows:
    """Test Context Engine operations."""

    async def test_context_engine_health(self, e2e_client):
        """Health endpoint reports status plus connection/cache flags."""
        response = await e2e_client.get("/api/v1/context/health")

        assert response.status_code == 200, f"Failed: {response.text}"
        payload = response.json()

        assert payload["status"] == "healthy"
        assert "mcp_connected" in payload
        assert "cache_enabled" in payload

    async def test_get_token_budget_claude_sonnet(self, e2e_client):
        """The default token budget for Claude 3 Sonnet is internally consistent."""
        response = await e2e_client.get("/api/v1/context/budget/claude-3-sonnet")

        assert response.status_code == 200, f"Failed: {response.text}"
        budget = response.json()

        assert budget["model"] == "claude-3-sonnet"
        allocation_fields = (
            "system_tokens",
            "knowledge_tokens",
            "conversation_tokens",
            "tool_tokens",
            "response_reserve",
        )
        assert "total_tokens" in budget
        for field in allocation_fields:
            assert field in budget

        # The individual allocations must fit inside the overall budget.
        assert budget["total_tokens"] > 0
        allocated = sum(budget[field] for field in allocation_fields)
        assert allocated <= budget["total_tokens"]

    async def test_get_token_budget_with_custom_max(self, e2e_client):
        """A caller-supplied max_tokens caps the returned budget."""
        response = await e2e_client.get(
            "/api/v1/context/budget/claude-3-sonnet",
            params={"max_tokens": 50000},
        )

        assert response.status_code == 200, f"Failed: {response.text}"
        budget = response.json()

        assert budget["model"] == "claude-3-sonnet"
        # The custom maximum must be respected or capped.
        assert budget["total_tokens"] <= 50000

    async def test_count_tokens(self, e2e_client):
        """Token counting returns a positive count for non-empty content."""
        response = await e2e_client.post(
            "/api/v1/context/count-tokens",
            json={
                "content": "Hello, this is a test message for token counting.",
                "model": "claude-3-sonnet",
            },
        )

        assert response.status_code == 200, f"Failed: {response.text}"
        payload = response.json()

        assert "token_count" in payload
        assert payload["token_count"] > 0
        assert payload["model"] == "claude-3-sonnet"
|
||||
|
||||
|
||||
class TestAdminMCPOperations:
    """Test admin-only MCP operations require authentication."""

    async def test_tool_call_requires_auth(self, e2e_client):
        """Unauthenticated tool execution is rejected."""
        response = await e2e_client.post(
            "/api/v1/mcp/call",
            json={
                "server": "llm-gateway",
                "tool": "count_tokens",
                "arguments": {"text": "test"},
            },
        )

        # Must demand authentication.
        assert response.status_code in (401, 403)

    async def test_circuit_reset_requires_auth(self, e2e_client):
        """Unauthenticated circuit breaker reset is rejected."""
        response = await e2e_client.post(
            "/api/v1/mcp/circuit-breakers/llm-gateway/reset"
        )

        assert response.status_code in (401, 403)

    async def test_server_reconnect_requires_auth(self, e2e_client):
        """Unauthenticated server reconnect is rejected."""
        response = await e2e_client.post("/api/v1/mcp/servers/llm-gateway/reconnect")

        assert response.status_code in (401, 403)

    async def test_context_stats_requires_auth(self, e2e_client):
        """Unauthenticated access to context stats is rejected."""
        response = await e2e_client.get("/api/v1/context/stats")

        assert response.status_code in (401, 403)

    async def test_context_assemble_requires_auth(self, e2e_client):
        """Unauthenticated context assembly is rejected."""
        response = await e2e_client.post(
            "/api/v1/context/assemble",
            json={
                "project_id": "test-project",
                "agent_id": "test-agent",
                "query": "test query",
                "model": "claude-3-sonnet",
            },
        )

        assert response.status_code in (401, 403)

    async def test_cache_invalidate_requires_auth(self, e2e_client):
        """Unauthenticated cache invalidation is rejected."""
        response = await e2e_client.post("/api/v1/context/cache/invalidate")

        assert response.status_code in (401, 403)
|
||||
|
||||
|
||||
class TestAdminMCPWithAuthentication:
    """Test admin MCP operations with superuser authentication."""

    async def test_superuser_can_get_context_stats(self, e2e_client, e2e_superuser):
        """A superuser token grants access to context engine stats."""
        auth = {
            "Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
        }
        response = await e2e_client.get("/api/v1/context/stats", headers=auth)

        assert response.status_code == 200, f"Failed: {response.text}"
        payload = response.json()

        assert "cache" in payload
        assert "settings" in payload

    @pytest.mark.skip(
        reason="Requires MCP servers (llm-gateway, knowledge-base) to be running"
    )
    async def test_superuser_can_assemble_context(self, e2e_client, e2e_superuser):
        """A superuser can run a full context assembly."""
        auth = {
            "Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
        }
        response = await e2e_client.post(
            "/api/v1/context/assemble",
            headers=auth,
            json={
                "project_id": f"test-project-{uuid4().hex[:8]}",
                "agent_id": f"test-agent-{uuid4().hex[:8]}",
                "query": "What is the status of the project?",
                "model": "claude-3-sonnet",
                "system_prompt": "You are a helpful assistant.",
                "compress": True,
                "use_cache": False,
            },
        )

        assert response.status_code == 200, f"Failed: {response.text}"
        payload = response.json()

        for field in (
            "content",
            "total_tokens",
            "context_count",
            "budget_used_percent",
            "metadata",
        ):
            assert field in payload

    async def test_superuser_can_invalidate_cache(self, e2e_client, e2e_superuser):
        """A superuser can invalidate the per-project cache."""
        auth = {
            "Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
        }
        response = await e2e_client.post(
            "/api/v1/context/cache/invalidate",
            headers=auth,
            params={"project_id": "test-project"},
        )

        assert response.status_code == 204

    async def test_regular_user_cannot_access_admin_operations(self, e2e_client):
        """A non-superuser is forbidden from admin operations."""
        email = f"regular-{uuid4().hex[:8]}@example.com"
        password = "RegularUser123!"

        # Register a fresh regular user.
        await e2e_client.post(
            "/api/v1/auth/register",
            json={
                "email": email,
                "password": password,
                "first_name": "Regular",
                "last_name": "User",
            },
        )

        # Log them in.
        login_resp = await e2e_client.post(
            "/api/v1/auth/login",
            json={"email": email, "password": password},
        )
        tokens = login_resp.json()

        # Attempt an admin-only endpoint with the regular token.
        response = await e2e_client.get(
            "/api/v1/context/stats",
            headers={"Authorization": f"Bearer {tokens['access_token']}"},
        )

        # Non-superusers must get 403.
        assert response.status_code == 403
|
||||
|
||||
|
||||
class TestMCPInputValidation:
    """Test input validation for MCP endpoints."""

    async def test_server_name_max_length(self, e2e_client):
        """Server names longer than the 64-char limit are rejected."""
        long_name = "a" * 100  # Exceeds 64 char limit

        response = await e2e_client.get(f"/api/v1/mcp/servers/{long_name}/tools")

        assert response.status_code == 422

    async def test_server_name_invalid_characters(self, e2e_client):
        """Server names with disallowed characters are rejected."""
        invalid_name = "server@name!invalid"

        response = await e2e_client.get(f"/api/v1/mcp/servers/{invalid_name}/tools")

        assert response.status_code == 422

    async def test_token_count_empty_content(self, e2e_client):
        """Empty content either counts as zero tokens or is rejected."""
        response = await e2e_client.post(
            "/api/v1/context/count-tokens",
            json={"content": ""},
        )

        if response.status_code == 200:
            # Empty content is valid and yields zero tokens.
            assert response.json()["token_count"] == 0
        else:
            # Or the API may treat it as a validation error.
            assert response.status_code == 422
|
||||
|
||||
|
||||
class TestMCPWorkflowIntegration:
    """Test complete MCP workflows end-to-end."""

    async def test_discovery_to_budget_workflow(self, e2e_client):
        """Full readiness check: discover servers -> health -> token budget."""
        # 1. Discover available servers.
        servers_resp = await e2e_client.get("/api/v1/mcp/servers")
        assert servers_resp.status_code == 200
        servers = servers_resp.json()["servers"]
        assert len(servers) > 0

        # 2. Confirm the context engine reports healthy.
        health_resp = await e2e_client.get("/api/v1/context/health")
        assert health_resp.status_code == 200
        health = health_resp.json()
        assert health["status"] == "healthy"

        # 3. Fetch the token budget for a model.
        budget_resp = await e2e_client.get("/api/v1/context/budget/claude-3-sonnet")
        assert budget_resp.status_code == 200
        budget = budget_resp.json()

        # 4. Everything in place for context assembly.
        assert budget["total_tokens"] > 0
        assert health["mcp_connected"] is True

    @pytest.mark.skip(
        reason="Requires MCP servers (llm-gateway, knowledge-base) to be running"
    )
    async def test_full_context_assembly_workflow(self, e2e_client, e2e_superuser):
        """Complete context assembly workflow driven by a superuser."""
        auth = {
            "Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
        }
        project_id = f"e2e-project-{uuid4().hex[:8]}"
        agent_id = f"e2e-agent-{uuid4().hex[:8]}"

        # 1. Budget check before assembling anything.
        budget_resp = await e2e_client.get("/api/v1/context/budget/claude-3-sonnet")
        assert budget_resp.status_code == 200
        _ = budget_resp.json()  # Verify valid response

        # 2. Count tokens in sample content.
        count_resp = await e2e_client.post(
            "/api/v1/context/count-tokens",
            json={"content": "This is a test message for context assembly."},
        )
        assert count_resp.status_code == 200
        token_count = count_resp.json()["token_count"]
        assert token_count > 0

        # 3. Assemble a context with history and compression.
        assemble_resp = await e2e_client.post(
            "/api/v1/context/assemble",
            headers=auth,
            json={
                "project_id": project_id,
                "agent_id": agent_id,
                "query": "Summarize the current project status",
                "model": "claude-3-sonnet",
                "system_prompt": "You are a project management assistant.",
                "task_description": "Generate a status report",
                "conversation_history": [
                    {"role": "user", "content": "What's the project status?"},
                    {
                        "role": "assistant",
                        "content": "Let me check the current status.",
                    },
                ],
                "compress": True,
                "use_cache": False,
            },
        )
        assert assemble_resp.status_code == 200
        assembled = assemble_resp.json()

        # 4. Sanity-check the assembly result.
        assert assembled["total_tokens"] > 0
        assert assembled["context_count"] > 0
        assert assembled["budget_used_percent"] > 0
        assert assembled["budget_used_percent"] <= 100

        # 5. Stats endpoint still works after the operation.
        stats_resp = await e2e_client.get("/api/v1/context/stats", headers=auth)
        assert stats_resp.status_code == 200
|
||||
684
backend/tests/e2e/test_project_workflows.py
Normal file
684
backend/tests/e2e/test_project_workflows.py
Normal file
@@ -0,0 +1,684 @@
|
||||
"""
|
||||
Project and Agent E2E Workflow Tests.
|
||||
|
||||
Tests complete project management workflows with real PostgreSQL:
|
||||
- Project CRUD and lifecycle management
|
||||
- Agent spawning and lifecycle
|
||||
- Issue management within projects
|
||||
- Sprint planning and execution
|
||||
|
||||
Usage:
|
||||
make test-e2e # Run all E2E tests
|
||||
"""
|
||||
|
||||
from datetime import date, timedelta
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
# Applied to every test in this module: E2E suite, requires PostgreSQL,
# and all test coroutines run under pytest-asyncio.
pytestmark = [pytest.mark.e2e, pytest.mark.postgres, pytest.mark.asyncio]
|
||||
|
||||
|
||||
class TestProjectCRUDWorkflows:
    """Test complete project CRUD workflows.

    Each test provisions its own user (unique random email) so tests are
    independent of each other and can run against a shared PostgreSQL
    instance without cross-test interference.
    """

    async def _register_and_login(self, e2e_client, prefix, first_name, last_name):
        """Register a fresh user and log them in.

        Consolidates the register/login boilerplate shared by every test
        in this class.

        Returns:
            dict with the user's ``email`` and auth ``tokens``.
        """
        email = f"{prefix}-{uuid4().hex[:8]}@example.com"
        password = "SecurePass123!"

        await e2e_client.post(
            "/api/v1/auth/register",
            json={
                "email": email,
                "password": password,
                "first_name": first_name,
                "last_name": last_name,
            },
        )
        login_resp = await e2e_client.post(
            "/api/v1/auth/login",
            json={"email": email, "password": password},
        )
        return {"email": email, "tokens": login_resp.json()}

    async def _create_project(self, e2e_client, tokens, name, slug, **extra_fields):
        """Create a project owned by the given user; returns the raw response."""
        return await e2e_client.post(
            "/api/v1/projects",
            headers={"Authorization": f"Bearer {tokens['access_token']}"},
            json={"name": name, "slug": slug, **extra_fields},
        )

    async def test_create_project_workflow(self, e2e_client):
        """Test creating a project as authenticated user."""
        user = await self._register_and_login(
            e2e_client, "project-owner", "Project", "Owner"
        )
        tokens = user["tokens"]

        project_slug = f"test-project-{uuid4().hex[:8]}"
        create_resp = await self._create_project(
            e2e_client,
            tokens,
            "E2E Test Project",
            project_slug,
            description="A project for E2E testing",
            autonomy_level="milestone",
        )

        assert create_resp.status_code == 201, f"Failed: {create_resp.text}"
        project = create_resp.json()
        assert project["name"] == "E2E Test Project"
        assert project["slug"] == project_slug
        # Fresh projects start active with no agents or issues attached.
        assert project["status"] == "active"
        assert project["agent_count"] == 0
        assert project["issue_count"] == 0

    async def test_list_projects_only_shows_owned(self, e2e_client):
        """Test that users only see their own projects."""
        # Create two users, each owning one project.
        users = []
        for i in range(2):
            user = await self._register_and_login(
                e2e_client, f"user-{i}", f"User{i}", "Test"
            )
            project_slug = f"user{i}-project-{uuid4().hex[:8]}"
            await self._create_project(
                e2e_client, user["tokens"], f"User {i} Project", project_slug
            )
            users.append(
                {"email": user["email"], "tokens": user["tokens"], "slug": project_slug}
            )

        # User 0 should only see their own project in the listing.
        list_resp = await e2e_client.get(
            "/api/v1/projects",
            headers={"Authorization": f"Bearer {users[0]['tokens']['access_token']}"},
        )
        assert list_resp.status_code == 200
        data = list_resp.json()
        slugs = [p["slug"] for p in data["data"]]
        assert users[0]["slug"] in slugs
        assert users[1]["slug"] not in slugs

    async def test_project_lifecycle_pause_resume(self, e2e_client):
        """Test pausing and resuming a project."""
        user = await self._register_and_login(
            e2e_client, "lifecycle", "Lifecycle", "Test"
        )
        tokens = user["tokens"]

        project_slug = f"lifecycle-project-{uuid4().hex[:8]}"
        create_resp = await self._create_project(
            e2e_client, tokens, "Lifecycle Project", project_slug
        )
        project_id = create_resp.json()["id"]

        # Pause the project
        pause_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/pause",
            headers={"Authorization": f"Bearer {tokens['access_token']}"},
        )
        assert pause_resp.status_code == 200
        assert pause_resp.json()["status"] == "paused"

        # Resume the project
        resume_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/resume",
            headers={"Authorization": f"Bearer {tokens['access_token']}"},
        )
        assert resume_resp.status_code == 200
        assert resume_resp.json()["status"] == "active"

    async def test_project_archive(self, e2e_client):
        """Test archiving a project (soft delete)."""
        user = await self._register_and_login(e2e_client, "archive", "Archive", "Test")
        tokens = user["tokens"]

        project_slug = f"archive-project-{uuid4().hex[:8]}"
        create_resp = await self._create_project(
            e2e_client, tokens, "Archive Project", project_slug
        )
        project_id = create_resp.json()["id"]

        # Archive the project (DELETE performs a soft delete).
        archive_resp = await e2e_client.delete(
            f"/api/v1/projects/{project_id}",
            headers={"Authorization": f"Bearer {tokens['access_token']}"},
        )
        assert archive_resp.status_code == 200
        assert archive_resp.json()["success"] is True

        # Verify the project is still retrievable, but marked archived.
        get_resp = await e2e_client.get(
            f"/api/v1/projects/{project_id}",
            headers={"Authorization": f"Bearer {tokens['access_token']}"},
        )
        assert get_resp.status_code == 200
        assert get_resp.json()["status"] == "archived"
|
||||
|
||||
|
||||
class TestIssueWorkflows:
    """Test issue management workflows within projects."""

    async def _setup_user_and_project(
        self, e2e_client, user_prefix, first_name, project_prefix, project_name
    ):
        """Register a fresh user and create a project they own.

        Consolidates the auth + project boilerplate shared by every test
        in this class.

        Returns:
            (tokens, project_id) tuple.
        """
        email = f"{user_prefix}-{uuid4().hex[:8]}@example.com"
        password = "SecurePass123!"

        await e2e_client.post(
            "/api/v1/auth/register",
            json={
                "email": email,
                "password": password,
                "first_name": first_name,
                "last_name": "Tester",
            },
        )
        login_resp = await e2e_client.post(
            "/api/v1/auth/login",
            json={"email": email, "password": password},
        )
        tokens = login_resp.json()

        create_resp = await e2e_client.post(
            "/api/v1/projects",
            headers={"Authorization": f"Bearer {tokens['access_token']}"},
            json={"name": project_name, "slug": f"{project_prefix}-{uuid4().hex[:8]}"},
        )
        return tokens, create_resp.json()["id"]

    async def test_create_and_list_issues(self, e2e_client):
        """Test creating and listing issues in a project."""
        tokens, project_id = await self._setup_user_and_project(
            e2e_client, "issue-test", "Issue", "issue-project", "Issue Test Project"
        )

        # Create multiple issues with ascending priorities
        issues = []
        for i in range(3):
            issue_resp = await e2e_client.post(
                f"/api/v1/projects/{project_id}/issues",
                headers={"Authorization": f"Bearer {tokens['access_token']}"},
                json={
                    "project_id": project_id,
                    "title": f"Test Issue {i + 1}",
                    "body": f"Description for issue {i + 1}",
                    "priority": ["low", "medium", "high"][i],
                },
            )
            assert issue_resp.status_code == 201, f"Failed: {issue_resp.text}"
            issues.append(issue_resp.json())

        # List issues and verify all three are reported
        list_resp = await e2e_client.get(
            f"/api/v1/projects/{project_id}/issues",
            headers={"Authorization": f"Bearer {tokens['access_token']}"},
        )
        assert list_resp.status_code == 200
        data = list_resp.json()
        assert data["pagination"]["total"] == 3

    async def test_issue_status_transitions(self, e2e_client):
        """Test issue status workflow transitions."""
        tokens, project_id = await self._setup_user_and_project(
            e2e_client, "status-test", "Status", "status-project", "Status Test Project"
        )

        # Create issue; new issues start in "open"
        issue_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/issues",
            headers={"Authorization": f"Bearer {tokens['access_token']}"},
            json={
                "project_id": project_id,
                "title": "Status Workflow Issue",
                "body": "Testing status transitions",
            },
        )
        issue = issue_resp.json()
        issue_id = issue["id"]
        assert issue["status"] == "open"

        # Walk the issue through the full status workflow
        for new_status in ["in_progress", "in_review", "closed"]:
            update_resp = await e2e_client.patch(
                f"/api/v1/projects/{project_id}/issues/{issue_id}",
                headers={"Authorization": f"Bearer {tokens['access_token']}"},
                json={"status": new_status},
            )
            assert update_resp.status_code == 200, f"Failed: {update_resp.text}"
            assert update_resp.json()["status"] == new_status

    async def test_issue_filtering(self, e2e_client):
        """Test issue filtering by status and priority."""
        tokens, project_id = await self._setup_user_and_project(
            e2e_client, "filter-test", "Filter", "filter-project", "Filter Test Project"
        )

        # Create one issue per priority level
        for priority in ["low", "medium", "high"]:
            await e2e_client.post(
                f"/api/v1/projects/{project_id}/issues",
                headers={"Authorization": f"Bearer {tokens['access_token']}"},
                json={
                    "project_id": project_id,
                    "title": f"{priority.title()} Priority Issue",
                    "priority": priority,
                },
            )

        # Filter by high priority; exactly one issue should match
        filter_resp = await e2e_client.get(
            f"/api/v1/projects/{project_id}/issues",
            headers={"Authorization": f"Bearer {tokens['access_token']}"},
            params={"priority": "high"},
        )
        assert filter_resp.status_code == 200
        data = filter_resp.json()
        assert data["pagination"]["total"] == 1
        assert data["data"][0]["priority"] == "high"
|
||||
|
||||
|
||||
class TestSprintWorkflows:
    """Test sprint planning and execution workflows."""

    async def _setup_user_and_project(
        self, e2e_client, user_prefix, first_name, project_prefix, project_name
    ):
        """Register a fresh user and create a project they own.

        Returns:
            (tokens, project_id) tuple.
        """
        email = f"{user_prefix}-{uuid4().hex[:8]}@example.com"
        password = "SecurePass123!"

        await e2e_client.post(
            "/api/v1/auth/register",
            json={
                "email": email,
                "password": password,
                "first_name": first_name,
                "last_name": "Tester",
            },
        )
        login_resp = await e2e_client.post(
            "/api/v1/auth/login",
            json={"email": email, "password": password},
        )
        tokens = login_resp.json()

        create_resp = await e2e_client.post(
            "/api/v1/projects",
            headers={"Authorization": f"Bearer {tokens['access_token']}"},
            json={"name": project_name, "slug": f"{project_prefix}-{uuid4().hex[:8]}"},
        )
        return tokens, create_resp.json()["id"]

    async def _create_sprint(self, e2e_client, tokens, project_id, **extra_fields):
        """Create a 14-day sprint starting today; returns the sprint payload."""
        today = date.today()
        sprint_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/sprints",
            headers={"Authorization": f"Bearer {tokens['access_token']}"},
            json={
                "project_id": project_id,
                "name": "Sprint 1",
                "number": 1,
                "start_date": today.isoformat(),
                "end_date": (today + timedelta(days=14)).isoformat(),
                **extra_fields,
            },
        )
        assert sprint_resp.status_code == 201, f"Failed: {sprint_resp.text}"
        return sprint_resp.json()

    async def test_sprint_lifecycle(self, e2e_client):
        """Test complete sprint lifecycle: plan -> start -> complete."""
        tokens, project_id = await self._setup_user_and_project(
            e2e_client, "sprint-test", "Sprint", "sprint-project", "Sprint Test Project"
        )

        # Create sprint; newly created sprints are "planned"
        sprint = await self._create_sprint(
            e2e_client, tokens, project_id, goal="Complete initial features"
        )
        sprint_id = sprint["id"]
        assert sprint["status"] == "planned"

        # Start sprint
        start_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/sprints/{sprint_id}/start",
            headers={"Authorization": f"Bearer {tokens['access_token']}"},
        )
        assert start_resp.status_code == 200, f"Failed: {start_resp.text}"
        assert start_resp.json()["status"] == "active"

        # Complete sprint
        complete_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/sprints/{sprint_id}/complete",
            headers={"Authorization": f"Bearer {tokens['access_token']}"},
        )
        assert complete_resp.status_code == 200, f"Failed: {complete_resp.text}"
        assert complete_resp.json()["status"] == "completed"

    async def test_add_issues_to_sprint(self, e2e_client):
        """Test adding issues to a sprint."""
        tokens, project_id = await self._setup_user_and_project(
            e2e_client,
            "sprint-issues",
            "SprintIssues",
            "sprint-issues-project",
            "Sprint Issues Project",
        )

        sprint_id = (await self._create_sprint(e2e_client, tokens, project_id))["id"]

        # Create issue
        issue_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/issues",
            headers={"Authorization": f"Bearer {tokens['access_token']}"},
            json={
                "project_id": project_id,
                "title": "Sprint Issue",
                "story_points": 5,
            },
        )
        issue_id = issue_resp.json()["id"]

        # Add issue to sprint (issue id is passed as a query parameter)
        add_resp = await e2e_client.post(
            f"/api/v1/projects/{project_id}/sprints/{sprint_id}/issues",
            headers={"Authorization": f"Bearer {tokens['access_token']}"},
            params={"issue_id": issue_id},
        )
        assert add_resp.status_code == 200, f"Failed: {add_resp.text}"

        # Verify issue is in sprint
        issue_check = await e2e_client.get(
            f"/api/v1/projects/{project_id}/issues/{issue_id}",
            headers={"Authorization": f"Bearer {tokens['access_token']}"},
        )
        assert issue_check.json()["sprint_id"] == sprint_id
|
||||
|
||||
|
||||
class TestCrossEntityValidation:
    """Test validation across related entities."""

    async def _register_and_login(self, e2e_client, email, first_name, last_name):
        """Register the given user and return their auth tokens.

        Consolidates the register/login boilerplate shared by the tests
        in this class.
        """
        password = "SecurePass123!"
        await e2e_client.post(
            "/api/v1/auth/register",
            json={
                "email": email,
                "password": password,
                "first_name": first_name,
                "last_name": last_name,
            },
        )
        login_resp = await e2e_client.post(
            "/api/v1/auth/login",
            json={"email": email, "password": password},
        )
        return login_resp.json()

    async def test_cannot_access_other_users_project(self, e2e_client):
        """Test that users cannot access projects they don't own."""
        # Create two users
        owner_email = f"owner-{uuid4().hex[:8]}@example.com"
        other_email = f"other-{uuid4().hex[:8]}@example.com"

        owner_tokens = await self._register_and_login(
            e2e_client, owner_email, "Owner", "User"
        )
        other_tokens = await self._register_and_login(
            e2e_client, other_email, "Other", "User"
        )

        # Owner creates project
        project_slug = f"private-project-{uuid4().hex[:8]}"
        create_resp = await e2e_client.post(
            "/api/v1/projects",
            headers={"Authorization": f"Bearer {owner_tokens['access_token']}"},
            json={"name": "Private Project", "slug": project_slug},
        )
        project_id = create_resp.json()["id"]

        # Other user tries to access and must be rejected
        access_resp = await e2e_client.get(
            f"/api/v1/projects/{project_id}",
            headers={"Authorization": f"Bearer {other_tokens['access_token']}"},
        )
        assert access_resp.status_code == 403

    async def test_duplicate_project_slug_rejected(self, e2e_client):
        """Test that duplicate project slugs are rejected."""
        email = f"dup-test-{uuid4().hex[:8]}@example.com"
        tokens = await self._register_and_login(e2e_client, email, "Dup", "Tester")

        slug = f"unique-slug-{uuid4().hex[:8]}"

        # First creation should succeed
        resp1 = await e2e_client.post(
            "/api/v1/projects",
            headers={"Authorization": f"Bearer {tokens['access_token']}"},
            json={"name": "First Project", "slug": slug},
        )
        assert resp1.status_code == 201

        # Second creation with same slug should fail
        resp2 = await e2e_client.post(
            "/api/v1/projects",
            headers={"Authorization": f"Bearer {tokens['access_token']}"},
            json={"name": "Second Project", "slug": slug},
        )
        assert resp2.status_code == 409  # Conflict
|
||||
|
||||
|
||||
class TestIssueStats:
    """Test issue statistics endpoints."""

    async def _create_issue(
        self, e2e_client, tokens, project_id, title, priority, story_points
    ):
        """Create one issue with the given priority and story points."""
        await e2e_client.post(
            f"/api/v1/projects/{project_id}/issues",
            headers={"Authorization": f"Bearer {tokens['access_token']}"},
            json={
                "project_id": project_id,
                "title": title,
                "priority": priority,
                "story_points": story_points,
            },
        )

    async def test_issue_stats_aggregation(self, e2e_client):
        """Test that issue stats are correctly aggregated."""
        # Provision a fresh user and a project to hold the issues.
        email = f"stats-test-{uuid4().hex[:8]}@example.com"
        password = "SecurePass123!"

        await e2e_client.post(
            "/api/v1/auth/register",
            json={
                "email": email,
                "password": password,
                "first_name": "Stats",
                "last_name": "Tester",
            },
        )
        tokens = (
            await e2e_client.post(
                "/api/v1/auth/login",
                json={"email": email, "password": password},
            )
        ).json()

        project_slug = f"stats-project-{uuid4().hex[:8]}"
        create_resp = await e2e_client.post(
            "/api/v1/projects",
            headers={"Authorization": f"Bearer {tokens['access_token']}"},
            json={"name": "Stats Project", "slug": project_slug},
        )
        project_id = create_resp.json()["id"]

        # Two issues with 8 + 2 story points drive the expected totals below.
        await self._create_issue(
            e2e_client, tokens, project_id, "High Priority", "high", 8
        )
        await self._create_issue(
            e2e_client, tokens, project_id, "Low Priority", "low", 2
        )

        # Get stats and verify the aggregation
        stats_resp = await e2e_client.get(
            f"/api/v1/projects/{project_id}/issues/stats",
            headers={"Authorization": f"Bearer {tokens['access_token']}"},
        )
        assert stats_resp.status_code == 200
        stats = stats_resp.json()
        assert stats["total"] == 2
        assert stats["total_story_points"] == 10
|
||||
1
backend/tests/integration/__init__.py
Normal file
1
backend/tests/integration/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Integration tests that require the full stack to be running."""
|
||||
322
backend/tests/integration/test_mcp_integration.py
Normal file
322
backend/tests/integration/test_mcp_integration.py
Normal file
@@ -0,0 +1,322 @@
|
||||
"""
|
||||
Integration tests for MCP server connectivity.
|
||||
|
||||
These tests require the full stack to be running:
|
||||
- docker compose -f docker-compose.dev.yml up
|
||||
|
||||
Run with:
|
||||
pytest tests/integration/ -v --integration
|
||||
|
||||
Or skip with:
|
||||
pytest tests/ -v --ignore=tests/integration/
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
# Skip all tests in this module if not running integration tests
|
||||
_RUN_INTEGRATION = os.getenv("RUN_INTEGRATION_TESTS", "false").lower() == "true"

# Every test in this module is skipped unless RUN_INTEGRATION_TESTS=true is
# exported and the full stack is running.
pytestmark = pytest.mark.skipif(
    not _RUN_INTEGRATION,
    reason="Integration tests require RUN_INTEGRATION_TESTS=true and running stack",
)


# Service base URLs, overridable via environment variables.
BACKEND_URL = os.getenv("BACKEND_URL", "http://localhost:8000")
LLM_GATEWAY_URL = os.getenv("LLM_GATEWAY_URL", "http://localhost:8001")
KNOWLEDGE_BASE_URL = os.getenv("KNOWLEDGE_BASE_URL", "http://localhost:8002")
|
||||
|
||||
|
||||
class TestMCPServerHealth:
    """Test that MCP servers are healthy and reachable."""

    async def _get_health(self, base_url):
        """GET the /health endpoint of one service and return the response."""
        async with httpx.AsyncClient() as client:
            return await client.get(f"{base_url}/health", timeout=10.0)

    @staticmethod
    def _assert_reports_healthy(response) -> None:
        """Assert a 200 reply whose body uses either health payload shape:
        {"status": "healthy"} or {"healthy": true}."""
        assert response.status_code == 200
        data = response.json()
        assert data.get("status") == "healthy" or data.get("healthy") is True

    @pytest.mark.asyncio
    async def test_llm_gateway_health(self) -> None:
        """Test LLM Gateway health endpoint."""
        self._assert_reports_healthy(await self._get_health(LLM_GATEWAY_URL))

    @pytest.mark.asyncio
    async def test_knowledge_base_health(self) -> None:
        """Test Knowledge Base health endpoint."""
        self._assert_reports_healthy(await self._get_health(KNOWLEDGE_BASE_URL))

    @pytest.mark.asyncio
    async def test_backend_health(self) -> None:
        """Test Backend health endpoint."""
        # Only the HTTP status is asserted for the backend.
        response = await self._get_health(BACKEND_URL)
        assert response.status_code == 200
|
||||
|
||||
|
||||
class TestMCPClientManagerIntegration:
    """Test MCPClientManager can connect to real MCP servers."""

    @pytest.mark.asyncio
    async def test_mcp_servers_list(self) -> None:
        """Test that backend can list MCP servers via API."""
        async with httpx.AsyncClient() as http:
            # This endpoint lists configured MCP servers.
            resp = await http.get(
                f"{BACKEND_URL}/api/v1/mcp/servers",
                timeout=10.0,
            )
            # Unauthenticated access may be rejected, so 401/403 are acceptable too.
            assert resp.status_code in [200, 401, 403]

    @pytest.mark.asyncio
    async def test_mcp_health_check_endpoint(self) -> None:
        """Test backend's MCP health check endpoint."""
        async with httpx.AsyncClient() as http:
            resp = await http.get(
                f"{BACKEND_URL}/api/v1/mcp/health",
                timeout=30.0,  # MCP health checks can take time
            )
            # Only inspect the body when the endpoint is reachable without auth.
            if resp.status_code == 200:
                payload = resp.json()
                assert "servers" in payload or "healthy" in payload
|
||||
|
||||
|
||||
class TestLLMGatewayIntegration:
    """Test LLM Gateway MCP server functionality."""

    @pytest.mark.asyncio
    async def test_list_models(self) -> None:
        """Test that LLM Gateway can list available models."""
        # MCP servers use JSON-RPC 2.0 protocol at /mcp endpoint
        rpc_request = {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "tools/list",
            "params": {},
        }
        async with httpx.AsyncClient() as http:
            resp = await http.post(
                f"{LLM_GATEWAY_URL}/mcp",
                json=rpc_request,
                timeout=10.0,
            )
            assert resp.status_code == 200
            payload = resp.json()
            # A valid JSON-RPC reply carries either a result or an error.
            assert "result" in payload or "error" in payload

    @pytest.mark.asyncio
    async def test_count_tokens(self) -> None:
        """Test token counting functionality."""
        rpc_request = {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "tools/call",
            "params": {
                "name": "count_tokens",
                "arguments": {
                    "project_id": "test-project",
                    "agent_id": "test-agent",
                    "text": "Hello, world!",
                },
            },
        }
        async with httpx.AsyncClient() as http:
            resp = await http.post(
                f"{LLM_GATEWAY_URL}/mcp",
                json=rpc_request,
                timeout=10.0,
            )
            assert resp.status_code == 200
            payload = resp.json()
            # When the call succeeded, the result should mention the count.
            if "result" in payload:
                assert "content" in payload["result"] or "token_count" in str(
                    payload["result"]
                )
|
||||
|
||||
|
||||
class TestKnowledgeBaseIntegration:
    """Test Knowledge Base MCP server functionality."""

    @pytest.mark.asyncio
    async def test_list_tools(self) -> None:
        """Test that Knowledge Base can list available tools."""
        async with httpx.AsyncClient() as http:
            # Knowledge Base uses GET /mcp/tools for listing
            resp = await http.get(
                f"{KNOWLEDGE_BASE_URL}/mcp/tools",
                timeout=10.0,
            )
            assert resp.status_code == 200
            assert "tools" in resp.json()

    @pytest.mark.asyncio
    async def test_search_knowledge_empty(self) -> None:
        """Test search on empty knowledge base."""
        # Knowledge Base uses direct tool name as method
        rpc_request = {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "search_knowledge",
            "params": {
                "project_id": "test-project",
                "agent_id": "test-agent",
                "query": "test query",
                "limit": 5,
            },
        }
        async with httpx.AsyncClient() as http:
            resp = await http.post(
                f"{KNOWLEDGE_BASE_URL}/mcp",
                json=rpc_request,
                timeout=10.0,
            )
            assert resp.status_code == 200
            payload = resp.json()
            # Should return empty results or error for no collection
            assert "result" in payload or "error" in payload
|
||||
|
||||
|
||||
class TestEndToEndMCPFlow:
    """End-to-end tests for MCP integration flow."""

    @pytest.mark.asyncio
    async def test_full_mcp_discovery_flow(self) -> None:
        """Test the full flow of discovering and listing MCP tools."""
        async with httpx.AsyncClient() as http:
            # Steps 1-3: backend, LLM Gateway and Knowledge Base must all
            # answer their health probes before tool discovery is attempted.
            for base_url in (BACKEND_URL, LLM_GATEWAY_URL, KNOWLEDGE_BASE_URL):
                health_resp = await http.get(f"{base_url}/health", timeout=10.0)
                assert health_resp.status_code == 200

            # Step 4: list tools from LLM Gateway (uses JSON-RPC at /mcp)
            llm_tools = await http.post(
                f"{LLM_GATEWAY_URL}/mcp",
                json={"jsonrpc": "2.0", "id": 1, "method": "tools/list", "params": {}},
                timeout=10.0,
            )
            assert llm_tools.status_code == 200

            # Step 5: list tools from Knowledge Base (uses GET /mcp/tools)
            kb_tools = await http.get(
                f"{KNOWLEDGE_BASE_URL}/mcp/tools",
                timeout=10.0,
            )
            assert kb_tools.status_code == 200
|
||||
|
||||
|
||||
class TestContextEngineIntegration:
    """Test Context Engine integration with MCP servers."""

    @pytest.mark.asyncio
    async def test_context_health_endpoint(self) -> None:
        """Test context engine health endpoint."""
        async with httpx.AsyncClient() as http:
            resp = await http.get(
                f"{BACKEND_URL}/api/v1/context/health",
                timeout=10.0,
            )
            assert resp.status_code == 200
            payload = resp.json()
            assert payload.get("status") == "healthy"
            assert "mcp_connected" in payload
            assert "cache_enabled" in payload

    @pytest.mark.asyncio
    async def test_context_budget_endpoint(self) -> None:
        """Test token budget endpoint."""
        async with httpx.AsyncClient() as http:
            resp = await http.get(
                f"{BACKEND_URL}/api/v1/context/budget/claude-3-sonnet",
                timeout=10.0,
            )
            assert resp.status_code == 200
            payload = resp.json()
            assert "total_tokens" in payload
            assert "system_tokens" in payload
            assert payload.get("model") == "claude-3-sonnet"

    @pytest.mark.asyncio
    async def test_context_assembly_requires_auth(self) -> None:
        """Test that context assembly requires authentication."""
        request_body = {
            "project_id": "test-project",
            "agent_id": "test-agent",
            "query": "test query",
            "model": "claude-3-sonnet",
        }
        async with httpx.AsyncClient() as http:
            resp = await http.post(
                f"{BACKEND_URL}/api/v1/context/assemble",
                json=request_body,
                timeout=10.0,
            )
            # Without credentials the endpoint must reject the request.
            assert resp.status_code in [401, 403]
|
||||
|
||||
|
||||
def run_quick_health_check() -> dict[str, Any]:
    """
    Quick synchronous health check for all services.

    Probes each service's /health endpoint with a short timeout and
    returns a mapping of service name -> True when it answered HTTP 200.
    Can be run standalone to verify the stack is up.
    """
    import httpx

    # Probe order matters only for readability of the printed report.
    endpoints = {
        "backend": BACKEND_URL,
        "llm_gateway": LLM_GATEWAY_URL,
        "knowledge_base": KNOWLEDGE_BASE_URL,
    }
    statuses: dict[str, Any] = {name: False for name in endpoints}

    try:
        with httpx.Client(timeout=5.0) as client:
            for name, base_url in endpoints.items():
                try:
                    statuses[name] = (
                        client.get(f"{base_url}/health").status_code == 200
                    )
                except Exception:
                    # An unreachable service simply stays marked unhealthy.
                    pass
    except Exception:
        # Client construction failed: report everything as down.
        pass

    return statuses
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Standalone mode: print a per-service health report and next steps.
    print("Checking service health...")
    results = run_quick_health_check()
    for service, healthy in results.items():
        print(f" {service}: {'OK' if healthy else 'FAILED'}")

    if all(results.values()):
        print("\nAll services healthy! Run integration tests with:")
        print(" RUN_INTEGRATION_TESTS=true pytest tests/integration/ -v")
    else:
        print("\nSome services are not healthy. Start the stack with:")
        print(" make dev")
|
||||
1
backend/tests/services/context/__init__.py
Normal file
1
backend/tests/services/context/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for Context Management Engine."""
|
||||
518
backend/tests/services/context/test_adapters.py
Normal file
518
backend/tests/services/context/test_adapters.py
Normal file
@@ -0,0 +1,518 @@
|
||||
"""Tests for model adapters."""
|
||||
|
||||
from app.services.context.adapters import (
|
||||
ClaudeAdapter,
|
||||
DefaultAdapter,
|
||||
OpenAIAdapter,
|
||||
get_adapter,
|
||||
)
|
||||
from app.services.context.types import (
|
||||
ContextType,
|
||||
ConversationContext,
|
||||
KnowledgeContext,
|
||||
MessageRole,
|
||||
SystemContext,
|
||||
TaskContext,
|
||||
ToolContext,
|
||||
)
|
||||
|
||||
|
||||
class TestGetAdapter:
    """Tests for the get_adapter factory's model-name dispatch."""

    def test_claude_models(self) -> None:
        """Claude/Anthropic model names resolve to ClaudeAdapter."""
        for model in (
            "claude-3-sonnet",
            "claude-3-opus",
            "claude-3-haiku",
            "claude-2",
            "anthropic/claude-3-sonnet",
        ):
            assert isinstance(get_adapter(model), ClaudeAdapter)

    def test_openai_models(self) -> None:
        """GPT and o-series model names resolve to OpenAIAdapter."""
        for model in (
            "gpt-4",
            "gpt-4-turbo",
            "gpt-3.5-turbo",
            "openai/gpt-4",
            "o1-mini",
            "o3-mini",
        ):
            assert isinstance(get_adapter(model), OpenAIAdapter)

    def test_unknown_models(self) -> None:
        """Unrecognised model names fall back to DefaultAdapter."""
        for model in ("llama-2", "mistral-7b", "custom-model"):
            assert isinstance(get_adapter(model), DefaultAdapter)
|
||||
|
||||
|
||||
class TestModelAdapterBase:
    """Tests for behaviour shared by all ModelAdapter subclasses."""

    def test_get_type_order(self) -> None:
        """Default priority order: system, task, knowledge, conversation, tool."""
        expected = [
            ContextType.SYSTEM,
            ContextType.TASK,
            ContextType.KNOWLEDGE,
            ContextType.CONVERSATION,
            ContextType.TOOL,
        ]
        assert DefaultAdapter().get_type_order() == expected

    def test_group_by_type(self) -> None:
        """group_by_type buckets contexts by type and omits absent types."""
        items = [
            SystemContext(content="System", source="system"),
            TaskContext(content="Task", source="task"),
            KnowledgeContext(content="Knowledge", source="docs"),
            SystemContext(content="System 2", source="system"),
        ]

        buckets = DefaultAdapter().group_by_type(items)

        assert len(buckets[ContextType.SYSTEM]) == 2
        assert len(buckets[ContextType.TASK]) == 1
        assert len(buckets[ContextType.KNOWLEDGE]) == 1
        # No conversation context was supplied, so no bucket exists for it.
        assert ContextType.CONVERSATION not in buckets

    def test_matches_model_default(self) -> None:
        """DefaultAdapter is a catch-all: it matches any model name."""
        for model in ("anything", "claude-3", "gpt-4"):
            assert DefaultAdapter.matches_model(model)
|
||||
|
||||
|
||||
class TestDefaultAdapter:
    """Tests for DefaultAdapter's plain-text section formatting."""

    def test_format_empty(self) -> None:
        """No contexts yields an empty string."""
        assert DefaultAdapter().format([]) == ""

    def test_format_system(self) -> None:
        """System content is emitted verbatim."""
        rendered = DefaultAdapter().format(
            [SystemContext(content="You are helpful.", source="system")]
        )
        assert "You are helpful." in rendered

    def test_format_task(self) -> None:
        """Task content appears under a 'Task:' label."""
        rendered = DefaultAdapter().format(
            [TaskContext(content="Write a function.", source="task")]
        )
        assert "Task:" in rendered
        assert "Write a function." in rendered

    def test_format_knowledge(self) -> None:
        """Knowledge content appears under 'Reference Information:'."""
        rendered = DefaultAdapter().format(
            [KnowledgeContext(content="Documentation here.", source="docs")]
        )
        assert "Reference Information:" in rendered
        assert "Documentation here." in rendered

    def test_format_conversation(self) -> None:
        """Conversation content appears under 'Previous Conversation:'."""
        rendered = DefaultAdapter().format(
            [
                ConversationContext(
                    content="Hello!",
                    source="chat",
                    role=MessageRole.USER,
                ),
            ]
        )
        assert "Previous Conversation:" in rendered
        assert "Hello!" in rendered

    def test_format_tool(self) -> None:
        """Tool output appears under 'Tool Results:'."""
        rendered = DefaultAdapter().format(
            [
                ToolContext(
                    content="Result: success",
                    source="tool",
                    metadata={"tool_name": "search"},
                ),
            ]
        )
        assert "Tool Results:" in rendered
        assert "Result: success" in rendered
|
||||
|
||||
|
||||
class TestClaudeAdapter:
    """Tests for the XML-oriented ClaudeAdapter."""

    def test_matches_model(self) -> None:
        """Only Claude/Anthropic model names match."""
        for model in ("claude-3-sonnet", "claude-3-opus", "anthropic/claude-3-haiku"):
            assert ClaudeAdapter.matches_model(model)
        for model in ("gpt-4", "llama-2"):
            assert not ClaudeAdapter.matches_model(model)

    def test_format_empty(self) -> None:
        """No contexts yields an empty string."""
        assert ClaudeAdapter().format([]) == ""

    def test_format_system_uses_xml(self) -> None:
        """System content is wrapped in <system_instructions> tags."""
        rendered = ClaudeAdapter().format(
            [SystemContext(content="You are helpful.", source="system")]
        )
        assert "<system_instructions>" in rendered
        assert "</system_instructions>" in rendered
        assert "You are helpful." in rendered

    def test_format_task_uses_xml(self) -> None:
        """Task content is wrapped in <current_task> tags."""
        rendered = ClaudeAdapter().format(
            [TaskContext(content="Write a function.", source="task")]
        )
        assert "<current_task>" in rendered
        assert "</current_task>" in rendered
        assert "Write a function." in rendered

    def test_format_knowledge_uses_document_tags(self) -> None:
        """Knowledge renders as <document> entries inside <reference_documents>."""
        rendered = ClaudeAdapter().format(
            [
                KnowledgeContext(
                    content="Documentation here.",
                    source="docs/api.md",
                    relevance_score=0.9,
                ),
            ]
        )
        assert "<reference_documents>" in rendered
        assert "</reference_documents>" in rendered
        assert '<document source="docs/api.md"' in rendered
        assert "</document>" in rendered
        assert "Documentation here." in rendered

    def test_format_knowledge_with_score(self) -> None:
        """A relevance score from metadata appears as a document attribute."""
        rendered = ClaudeAdapter().format(
            [
                KnowledgeContext(
                    content="Doc content.",
                    source="docs/api.md",
                    metadata={"relevance_score": 0.95},
                ),
            ]
        )
        assert 'relevance="0.95"' in rendered

    def test_format_conversation_uses_message_tags(self) -> None:
        """Conversation turns render as role-attributed <message> tags."""
        rendered = ClaudeAdapter().format(
            [
                ConversationContext(
                    content="Hello!",
                    source="chat",
                    role=MessageRole.USER,
                    metadata={"role": "user"},
                ),
                ConversationContext(
                    content="Hi there!",
                    source="chat",
                    role=MessageRole.ASSISTANT,
                    metadata={"role": "assistant"},
                ),
            ]
        )
        assert "<conversation_history>" in rendered
        assert "</conversation_history>" in rendered
        assert '<message role="user">' in rendered
        assert '<message role="assistant">' in rendered
        assert "Hello!" in rendered
        assert "Hi there!" in rendered

    def test_format_tool_uses_tool_result_tags(self) -> None:
        """Tool output renders as <tool_result> entries with name and status."""
        rendered = ClaudeAdapter().format(
            [
                ToolContext(
                    content='{"status": "ok"}',
                    source="tool",
                    metadata={"tool_name": "search", "status": "success"},
                ),
            ]
        )
        assert "<tool_results>" in rendered
        assert "</tool_results>" in rendered
        assert '<tool_result name="search"' in rendered
        assert 'status="success"' in rendered
        assert "</tool_result>" in rendered

    def test_format_multiple_types_in_order(self) -> None:
        """Sections appear in system -> task -> knowledge order regardless of input order."""
        rendered = ClaudeAdapter().format(
            [
                KnowledgeContext(content="Knowledge", source="docs"),
                SystemContext(content="System", source="system"),
                TaskContext(content="Task", source="task"),
            ]
        )

        sys_at = rendered.find("<system_instructions>")
        task_at = rendered.find("<current_task>")
        docs_at = rendered.find("<reference_documents>")

        assert sys_at < task_at < docs_at

    def test_escape_xml_in_source(self) -> None:
        """Quotes and ampersands in source paths are XML-escaped in attributes."""
        rendered = ClaudeAdapter().format(
            [
                KnowledgeContext(
                    content="Doc content.",
                    source='path/with"quotes&stuff.md',
                ),
            ]
        )
        assert "&quot;" in rendered
        assert "&amp;" in rendered
|
||||
|
||||
|
||||
class TestOpenAIAdapter:
    """Tests for the markdown-oriented OpenAIAdapter."""

    def test_matches_model(self) -> None:
        """GPT and o-series names match; other vendors do not."""
        for model in (
            "gpt-4",
            "gpt-4-turbo",
            "gpt-3.5-turbo",
            "openai/gpt-4",
            "o1-preview",
            "o3-mini",
        ):
            assert OpenAIAdapter.matches_model(model)
        for model in ("claude-3", "llama-2"):
            assert not OpenAIAdapter.matches_model(model)

    def test_format_empty(self) -> None:
        """No contexts yields an empty string."""
        assert OpenAIAdapter().format([]) == ""

    def test_format_system_plain(self) -> None:
        """System content is emitted plain, with no markdown header."""
        rendered = OpenAIAdapter().format(
            [SystemContext(content="You are helpful.", source="system")]
        )
        assert "You are helpful." in rendered
        assert "##" not in rendered  # No markdown headers for system

    def test_format_task_uses_markdown(self) -> None:
        """Task content appears under a '## Current Task' header."""
        rendered = OpenAIAdapter().format(
            [TaskContext(content="Write a function.", source="task")]
        )
        assert "## Current Task" in rendered
        assert "Write a function." in rendered

    def test_format_knowledge_uses_markdown(self) -> None:
        """Knowledge gets a section header plus a per-source subheader."""
        rendered = OpenAIAdapter().format(
            [
                KnowledgeContext(
                    content="Documentation here.",
                    source="docs/api.md",
                    relevance_score=0.9,
                ),
            ]
        )
        assert "## Reference Documents" in rendered
        assert "### Source: docs/api.md" in rendered
        assert "Documentation here." in rendered

    def test_format_knowledge_with_score(self) -> None:
        """A relevance score from metadata is shown inline."""
        rendered = OpenAIAdapter().format(
            [
                KnowledgeContext(
                    content="Doc content.",
                    source="docs/api.md",
                    metadata={"relevance_score": 0.95},
                ),
            ]
        )
        assert "(relevance: 0.95)" in rendered

    def test_format_conversation_uses_bold_roles(self) -> None:
        """Conversation turns are labelled with bold role names."""
        rendered = OpenAIAdapter().format(
            [
                ConversationContext(
                    content="Hello!",
                    source="chat",
                    role=MessageRole.USER,
                    metadata={"role": "user"},
                ),
                ConversationContext(
                    content="Hi there!",
                    source="chat",
                    role=MessageRole.ASSISTANT,
                    metadata={"role": "assistant"},
                ),
            ]
        )
        assert "**USER**:" in rendered
        assert "**ASSISTANT**:" in rendered
        assert "Hello!" in rendered
        assert "Hi there!" in rendered

    def test_format_tool_uses_code_blocks(self) -> None:
        """Tool output renders inside fenced code blocks."""
        rendered = OpenAIAdapter().format(
            [
                ToolContext(
                    content='{"status": "ok"}',
                    source="tool",
                    metadata={"tool_name": "search", "status": "success"},
                ),
            ]
        )
        assert "## Recent Tool Results" in rendered
        assert "### Tool: search (success)" in rendered
        assert "```" in rendered  # Code block
        assert '{"status": "ok"}' in rendered

    def test_format_multiple_types_in_order(self) -> None:
        """System (plain, header-less) precedes task, which precedes knowledge."""
        rendered = OpenAIAdapter().format(
            [
                KnowledgeContext(content="Knowledge", source="docs"),
                SystemContext(content="System", source="system"),
                TaskContext(content="Task", source="task"),
            ]
        )

        sys_at = rendered.find("System")
        task_at = rendered.find("## Current Task")
        docs_at = rendered.find("## Reference Documents")

        assert sys_at < task_at < docs_at
|
||||
|
||||
|
||||
class TestAdapterIntegration:
    """End-to-end formatting checks over a realistic mix of context types."""

    @staticmethod
    def _sample_contexts() -> list:
        """Build the shared realistic mix: one context of every type."""
        return [
            SystemContext(
                content="You are an expert Python developer.",
                source="system",
            ),
            TaskContext(
                content="Implement user authentication.",
                source="task:AUTH-123",
            ),
            KnowledgeContext(
                content="JWT tokens provide stateless authentication...",
                source="docs/auth/jwt.md",
                relevance_score=0.9,
            ),
            ConversationContext(
                content="Can you help me implement JWT auth?",
                source="chat",
                role=MessageRole.USER,
                metadata={"role": "user"},
            ),
            ToolContext(
                content='{"file": "auth.py", "status": "created"}',
                source="tool",
                metadata={"tool_name": "file_create"},
            ),
        ]

    def test_full_context_formatting_claude(self) -> None:
        """Claude output carries every XML section plus all source content."""
        rendered = ClaudeAdapter().format(self._sample_contexts())

        # Every type-specific XML section must be present.
        for tag in (
            "<system_instructions>",
            "<current_task>",
            "<reference_documents>",
            "<conversation_history>",
            "<tool_results>",
        ):
            assert tag in rendered

        # Every context's content must survive formatting.
        for snippet in (
            "expert Python developer",
            "user authentication",
            "JWT tokens",
            "help me implement",
            "file_create",
        ):
            assert snippet in rendered

    def test_full_context_formatting_openai(self) -> None:
        """OpenAI output carries markdown sections plus all source content."""
        rendered = OpenAIAdapter().format(self._sample_contexts())

        # Every markdown section marker must be present.
        for marker in (
            "## Current Task",
            "## Reference Documents",
            "## Recent Tool Results",
            "**USER**:",
        ):
            assert marker in rendered

        # Every context's content must survive formatting.
        for snippet in (
            "expert Python developer",
            "user authentication",
            "JWT tokens",
            "help me implement",
            "file_create",
        ):
            assert snippet in rendered
|
||||
508
backend/tests/services/context/test_assembly.py
Normal file
508
backend/tests/services/context/test_assembly.py
Normal file
@@ -0,0 +1,508 @@
|
||||
"""Tests for context assembly pipeline."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.context.assembly import ContextPipeline, PipelineMetrics
|
||||
from app.services.context.budget import TokenBudget
|
||||
from app.services.context.types import (
|
||||
AssembledContext,
|
||||
ConversationContext,
|
||||
KnowledgeContext,
|
||||
MessageRole,
|
||||
SystemContext,
|
||||
TaskContext,
|
||||
ToolContext,
|
||||
)
|
||||
|
||||
|
||||
class TestPipelineMetrics:
    """Tests for the PipelineMetrics dataclass."""

    def test_creation(self) -> None:
        """A fresh metrics object starts with zeroed counters."""
        fresh = PipelineMetrics()

        assert fresh.total_contexts == 0
        assert fresh.selected_contexts == 0
        assert fresh.assembly_time_ms == 0.0

    def test_to_dict(self) -> None:
        """to_dict exposes every counter plus start/end timestamps."""
        metrics = PipelineMetrics(
            total_contexts=10,
            selected_contexts=8,
            excluded_contexts=2,
            total_tokens=500,
            assembly_time_ms=25.5,
        )
        metrics.end_time = datetime.now(UTC)

        snapshot = metrics.to_dict()

        assert snapshot["total_contexts"] == 10
        assert snapshot["selected_contexts"] == 8
        assert snapshot["excluded_contexts"] == 2
        assert snapshot["total_tokens"] == 500
        assert snapshot["assembly_time_ms"] == 25.5
        assert "start_time" in snapshot
        assert "end_time" in snapshot
||||
|
||||
|
||||
class TestContextPipeline:
    """Tests for ContextPipeline assembly behaviour."""

    def test_creation(self) -> None:
        """Every pipeline stage is wired up on construction."""
        pipeline = ContextPipeline()

        for stage in (
            pipeline._calculator,
            pipeline._scorer,
            pipeline._ranker,
            pipeline._compressor,
            pipeline._allocator,
        ):
            assert stage is not None

    @pytest.mark.asyncio
    async def test_assemble_empty_contexts(self) -> None:
        """An empty input list assembles to an empty result."""
        outcome = await ContextPipeline().assemble(
            contexts=[],
            query="test query",
            model="claude-3-sonnet",
        )

        assert isinstance(outcome, AssembledContext)
        assert outcome.context_count == 0
        assert outcome.total_tokens == 0

    @pytest.mark.asyncio
    async def test_assemble_single_context(self) -> None:
        """A lone system context survives assembly with its content intact."""
        outcome = await ContextPipeline().assemble(
            contexts=[
                SystemContext(
                    content="You are a helpful assistant.",
                    source="system",
                ),
            ],
            query="help me",
            model="claude-3-sonnet",
        )

        assert outcome.context_count == 1
        assert outcome.total_tokens > 0
        assert "helpful assistant" in outcome.content

    @pytest.mark.asyncio
    async def test_assemble_multiple_types(self) -> None:
        """A mix of context types assembles to a non-empty result."""
        mixed = [
            SystemContext(
                content="You are a coding assistant.",
                source="system",
            ),
            TaskContext(
                content="Implement a login feature.",
                source="task",
            ),
            KnowledgeContext(
                content="Authentication best practices include...",
                source="docs/auth.md",
                relevance_score=0.8,
            ),
        ]

        outcome = await ContextPipeline().assemble(
            contexts=mixed,
            query="implement login",
            model="claude-3-sonnet",
        )

        assert outcome.context_count >= 1
        assert outcome.total_tokens > 0

    @pytest.mark.asyncio
    async def test_assemble_with_custom_budget(self) -> None:
        """A caller-supplied TokenBudget is honoured."""
        tight_budget = TokenBudget(
            total=1000,
            system=200,
            task=200,
            knowledge=400,
            conversation=100,
            tools=50,
            response_reserve=50,
        )

        outcome = await ContextPipeline().assemble(
            contexts=[
                SystemContext(content="System prompt", source="system"),
                TaskContext(content="Task description", source="task"),
            ],
            query="test",
            model="gpt-4",
            custom_budget=tight_budget,
        )

        assert outcome.context_count >= 1

    @pytest.mark.asyncio
    async def test_assemble_with_max_tokens(self) -> None:
        """max_tokens overrides the budget total recorded in metadata."""
        outcome = await ContextPipeline().assemble(
            contexts=[
                SystemContext(content="System prompt", source="system"),
            ],
            query="test",
            model="gpt-4",
            max_tokens=5000,
        )

        assert "budget" in outcome.metadata
        assert outcome.metadata["budget"]["total"] == 5000

    @pytest.mark.asyncio
    async def test_assemble_format_output(self) -> None:
        """format_output toggles model-specific (XML) formatting."""
        pipeline = ContextPipeline()
        single = [
            SystemContext(content="System prompt", source="system"),
        ]

        pretty = await pipeline.assemble(
            contexts=single,
            query="test",
            model="claude-3-sonnet",
            format_output=True,
        )
        raw = await pipeline.assemble(
            contexts=single,
            query="test",
            model="claude-3-sonnet",
            format_output=False,
        )

        # Claude formatting wraps system content in XML; raw output must not.
        assert "<system_instructions>" in pretty.content
        assert "<system_instructions>" not in raw.content

    @pytest.mark.asyncio
    async def test_assemble_metrics(self) -> None:
        """Assembly records counts and per-stage timings in metadata."""
        outcome = await ContextPipeline().assemble(
            contexts=[
                SystemContext(content="System", source="system"),
                TaskContext(content="Task", source="task"),
                KnowledgeContext(
                    content="Knowledge",
                    source="docs",
                    relevance_score=0.9,
                ),
            ],
            query="test",
            model="claude-3-sonnet",
        )

        assert "metrics" in outcome.metadata
        stats = outcome.metadata["metrics"]

        assert stats["total_contexts"] == 3
        assert stats["assembly_time_ms"] > 0
        assert "scoring_time_ms" in stats
        assert "formatting_time_ms" in stats

    @pytest.mark.asyncio
    async def test_assemble_with_compression_disabled(self) -> None:
        """compress=False still produces a valid (uncompressed) result."""
        outcome = await ContextPipeline().assemble(
            contexts=[
                KnowledgeContext(content="A" * 1000, source="docs"),
            ],
            query="test",
            model="gpt-4",
            compress=False,
        )

        # Should still work, just with no compression applied.
        assert outcome.context_count >= 0
|
||||
|
||||
|
||||
class TestContextPipelineFormatting:
    """Tests that assembled output uses model-appropriate formatting."""

    @pytest.mark.asyncio
    async def test_format_claude_uses_xml(self) -> None:
        """Claude-family models get XML-tagged sections."""
        outcome = await ContextPipeline().assemble(
            contexts=[
                SystemContext(content="System prompt", source="system"),
                TaskContext(content="Task", source="task"),
                KnowledgeContext(
                    content="Knowledge",
                    source="docs",
                    relevance_score=0.9,
                ),
            ],
            query="test",
            model="claude-3-sonnet",
        )

        # Either the system section is XML-wrapped, or nothing was selected.
        assert "<system_instructions>" in outcome.content or outcome.context_count == 0

    @pytest.mark.asyncio
    async def test_format_openai_uses_markdown(self) -> None:
        """OpenAI-family models get markdown headers."""
        outcome = await ContextPipeline().assemble(
            contexts=[
                TaskContext(content="Task description", source="task"),
            ],
            query="test",
            model="gpt-4",
        )

        # Only check the header when the task context was actually selected.
        if outcome.context_count > 0 and "Task" in outcome.content:
            assert "## Current Task" in outcome.content

    @pytest.mark.asyncio
    async def test_format_knowledge_claude(self) -> None:
        """Knowledge contexts render as <document> entries for Claude."""
        outcome = await ContextPipeline().assemble(
            contexts=[
                KnowledgeContext(
                    content="Document content here",
                    source="docs/file.md",
                    relevance_score=0.9,
                ),
            ],
            query="test",
            model="claude-3-sonnet",
        )

        if outcome.context_count > 0:
            assert "<reference_documents>" in outcome.content
            assert "<document" in outcome.content

    @pytest.mark.asyncio
    async def test_format_conversation(self) -> None:
        """Conversation history renders with role-tagged messages."""
        outcome = await ContextPipeline().assemble(
            contexts=[
                ConversationContext(
                    content="Hello, how are you?",
                    source="chat",
                    role=MessageRole.USER,
                    metadata={"role": "user"},
                ),
                ConversationContext(
                    content="I'm doing great!",
                    source="chat",
                    role=MessageRole.ASSISTANT,
                    metadata={"role": "assistant"},
                ),
            ],
            query="test",
            model="claude-3-sonnet",
        )

        if outcome.context_count > 0:
            assert "<conversation_history>" in outcome.content
            assert (
                '<message role="user">' in outcome.content
                or 'role="user"' in outcome.content
            )

    @pytest.mark.asyncio
    async def test_format_tool_results(self) -> None:
        """Tool outputs render inside <tool_results>."""
        outcome = await ContextPipeline().assemble(
            contexts=[
                ToolContext(
                    content="Tool output here",
                    source="tool",
                    metadata={"tool_name": "search"},
                ),
            ],
            query="test",
            model="claude-3-sonnet",
        )

        if outcome.context_count > 0:
            assert "<tool_results>" in outcome.content
|
||||
|
||||
|
||||
class TestContextPipelineIntegration:
    """Integration tests exercising the full assembly pipeline end to end."""

    @pytest.mark.asyncio
    async def test_full_pipeline_workflow(self) -> None:
        """A realistic mix of context types assembles into a valid result.

        Checks the AssembledContext invariants (counts, tokens, timing,
        model echo) and that the metadata carries metrics/query/budget keys.
        """
        pipeline = ContextPipeline()

        # Create realistic context mix covering every context type but tools.
        contexts = [
            SystemContext(
                content="You are an expert Python developer.",
                source="system",
            ),
            TaskContext(
                content="Implement a user authentication system.",
                source="task:AUTH-123",
            ),
            KnowledgeContext(
                content="JWT tokens provide stateless authentication...",
                source="docs/auth/jwt.md",
                relevance_score=0.9,
            ),
            KnowledgeContext(
                content="OAuth 2.0 is an authorization framework...",
                source="docs/auth/oauth.md",
                relevance_score=0.7,
            ),
            ConversationContext(
                content="Can you help me implement JWT auth?",
                source="chat",
                role=MessageRole.USER,
                metadata={"role": "user"},
            ),
        ]

        result = await pipeline.assemble(
            contexts=contexts,
            query="implement JWT authentication",
            model="claude-3-sonnet",
        )

        # Verify result invariants.
        assert isinstance(result, AssembledContext)
        assert result.context_count > 0
        assert result.total_tokens > 0
        assert result.assembly_time_ms > 0
        assert result.model == "claude-3-sonnet"
        assert len(result.content) > 0

        # Verify metrics are attached to the result metadata.
        assert "metrics" in result.metadata
        assert "query" in result.metadata
        assert "budget" in result.metadata

    @pytest.mark.asyncio
    async def test_context_type_ordering(self) -> None:
        """Contexts are emitted in type order regardless of input order.

        Feeds the pipeline a shuffled mix and verifies the rendered section
        markers appear as System -> Task -> Knowledge -> Conversation -> Tool
        (for Claude-style formatting); each pair is only compared when both
        sections are present.
        """
        pipeline = ContextPipeline()

        # Add in random order.
        contexts = [
            KnowledgeContext(content="Knowledge", source="docs", relevance_score=0.9),
            ToolContext(content="Tool", source="tool", metadata={"tool_name": "test"}),
            SystemContext(content="System", source="system"),
            ConversationContext(
                content="Chat",
                source="chat",
                role=MessageRole.USER,
                metadata={"role": "user"},
            ),
            TaskContext(content="Task", source="task"),
        ]

        result = await pipeline.assemble(
            contexts=contexts,
            query="test",
            model="claude-3-sonnet",
        )

        # For Claude, verify order: System -> Task -> Knowledge -> Conversation -> Tool
        content = result.content
        if result.context_count > 0:
            # Find positions (str.find returns -1 when a section is absent).
            system_pos = content.find("system_instructions")
            task_pos = content.find("current_task")
            knowledge_pos = content.find("reference_documents")
            conversation_pos = content.find("conversation_history")
            tool_pos = content.find("tool_results")

            # Verify ordering (only check each pair when both sections exist).
            if system_pos >= 0 and task_pos >= 0:
                assert system_pos < task_pos
            if task_pos >= 0 and knowledge_pos >= 0:
                assert task_pos < knowledge_pos
            if knowledge_pos >= 0 and conversation_pos >= 0:
                assert knowledge_pos < conversation_pos
            if conversation_pos >= 0 and tool_pos >= 0:
                assert conversation_pos < tool_pos

    @pytest.mark.asyncio
    async def test_excluded_contexts_tracked(self) -> None:
        """Contexts that do not fit the budget are reported as excluded.

        Uses a small model window plus a tight max_tokens so some of the
        ten large knowledge contexts must be dropped, then checks the
        included/excluded accounting stays consistent with the input size.
        """
        pipeline = ContextPipeline()

        # Create many contexts to force some exclusions.
        contexts = [
            KnowledgeContext(
                content="A" * 500,  # Large content
                source=f"docs/{i}",
                relevance_score=0.1 + (i * 0.05),
            )
            for i in range(10)
        ]

        result = await pipeline.assemble(
            contexts=contexts,
            query="test",
            model="gpt-4",  # Smaller context window
            max_tokens=1000,  # Limited budget
        )

        # Should have excluded some; the split must never exceed the input.
        assert result.excluded_count >= 0
        assert result.context_count + result.excluded_count <= len(contexts)
|
||||
533
backend/tests/services/context/test_budget.py
Normal file
533
backend/tests/services/context/test_budget.py
Normal file
@@ -0,0 +1,533 @@
|
||||
"""Tests for token budget management."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.context.budget import (
|
||||
BudgetAllocator,
|
||||
TokenBudget,
|
||||
TokenCalculator,
|
||||
)
|
||||
from app.services.context.config import ContextSettings
|
||||
from app.services.context.exceptions import BudgetExceededError
|
||||
from app.services.context.types import ContextType
|
||||
|
||||
|
||||
class TestTokenBudget:
    """Tests for the TokenBudget dataclass: allocation, tracking, reporting."""

    def test_creation(self) -> None:
        """A budget with only a total starts with zero allocations and usage."""
        budget = TokenBudget(total=10000)
        assert budget.total == 10000
        assert budget.system == 0
        assert budget.total_used() == 0

    def test_creation_with_allocations(self) -> None:
        """Per-type allocations passed at construction are stored verbatim."""
        budget = TokenBudget(
            total=10000,
            system=500,
            task=1000,
            knowledge=4000,
            conversation=2000,
            tools=500,
            response_reserve=1500,
            buffer=500,
        )

        assert budget.system == 500
        assert budget.knowledge == 4000
        assert budget.response_reserve == 1500

    def test_get_allocation(self) -> None:
        """get_allocation accepts both ContextType enums and raw strings."""
        budget = TokenBudget(
            total=10000,
            system=500,
            knowledge=4000,
        )

        assert budget.get_allocation(ContextType.SYSTEM) == 500
        assert budget.get_allocation(ContextType.KNOWLEDGE) == 4000
        assert budget.get_allocation("system") == 500

    def test_remaining(self) -> None:
        """remaining() reflects the allocation minus what has been used."""
        budget = TokenBudget(
            total=10000,
            system=500,
            knowledge=4000,
        )

        # Initially full.
        assert budget.remaining(ContextType.SYSTEM) == 500
        assert budget.remaining(ContextType.KNOWLEDGE) == 4000

        # After allocation.
        budget.allocate(ContextType.SYSTEM, 200)
        assert budget.remaining(ContextType.SYSTEM) == 300

    def test_can_fit(self) -> None:
        """can_fit is inclusive of the exact allocation boundary."""
        budget = TokenBudget(
            total=10000,
            system=500,
            knowledge=4000,
        )

        assert budget.can_fit(ContextType.SYSTEM, 500) is True
        assert budget.can_fit(ContextType.SYSTEM, 501) is False
        assert budget.can_fit(ContextType.KNOWLEDGE, 4000) is True

    def test_allocate_success(self) -> None:
        """A within-budget allocate returns True and updates used/remaining."""
        budget = TokenBudget(
            total=10000,
            system=500,
        )

        result = budget.allocate(ContextType.SYSTEM, 200)
        assert result is True
        assert budget.get_used(ContextType.SYSTEM) == 200
        assert budget.remaining(ContextType.SYSTEM) == 300

    def test_allocate_exceeds_budget(self) -> None:
        """Over-budget allocate raises BudgetExceededError with details."""
        budget = TokenBudget(
            total=10000,
            system=500,
        )

        with pytest.raises(BudgetExceededError) as exc_info:
            budget.allocate(ContextType.SYSTEM, 600)

        assert exc_info.value.allocated == 500
        assert exc_info.value.requested == 600

    def test_allocate_force(self) -> None:
        """force=True allows an allocation past the per-type limit."""
        budget = TokenBudget(
            total=10000,
            system=500,
        )

        # Force should allow exceeding.
        result = budget.allocate(ContextType.SYSTEM, 600, force=True)
        assert result is True
        assert budget.get_used(ContextType.SYSTEM) == 600

    def test_deallocate(self) -> None:
        """deallocate returns tokens to the per-type pool."""
        budget = TokenBudget(
            total=10000,
            system=500,
        )

        budget.allocate(ContextType.SYSTEM, 300)
        assert budget.get_used(ContextType.SYSTEM) == 300

        budget.deallocate(ContextType.SYSTEM, 100)
        assert budget.get_used(ContextType.SYSTEM) == 200

    def test_deallocate_below_zero(self) -> None:
        """Over-deallocation clamps usage at zero rather than going negative."""
        budget = TokenBudget(
            total=10000,
            system=500,
        )

        budget.allocate(ContextType.SYSTEM, 100)
        budget.deallocate(ContextType.SYSTEM, 200)
        assert budget.get_used(ContextType.SYSTEM) == 0

    def test_total_remaining(self) -> None:
        """total_remaining excludes the response reserve and buffer."""
        budget = TokenBudget(
            total=10000,
            system=500,
            knowledge=4000,
            response_reserve=1500,
            buffer=500,
        )

        # Usable = total - response_reserve - buffer = 10000 - 1500 - 500 = 8000
        assert budget.total_remaining() == 8000

        # After allocation.
        budget.allocate(ContextType.SYSTEM, 200)
        assert budget.total_remaining() == 7800

    def test_utilization(self) -> None:
        """utilization reports per-type and overall fractions of usable budget."""
        budget = TokenBudget(
            total=10000,
            system=500,
            response_reserve=1500,
            buffer=500,
        )

        # No usage = 0%
        assert budget.utilization(ContextType.SYSTEM) == 0.0

        # Half used = 50%
        budget.allocate(ContextType.SYSTEM, 250)
        assert budget.utilization(ContextType.SYSTEM) == 0.5

        # Total utilization
        assert budget.utilization() == 250 / 8000  # 250 / (10000 - 1500 - 500)

    def test_reset(self) -> None:
        """reset clears all usage counters while keeping allocations."""
        budget = TokenBudget(
            total=10000,
            system=500,
        )

        budget.allocate(ContextType.SYSTEM, 300)
        assert budget.get_used(ContextType.SYSTEM) == 300

        budget.reset()
        assert budget.get_used(ContextType.SYSTEM) == 0
        assert budget.total_used() == 0

    def test_to_dict(self) -> None:
        """to_dict exposes total, allocations, used and remaining maps."""
        budget = TokenBudget(
            total=10000,
            system=500,
            task=1000,
            knowledge=4000,
        )

        budget.allocate(ContextType.SYSTEM, 200)

        data = budget.to_dict()
        assert data["total"] == 10000
        assert data["allocations"]["system"] == 500
        assert data["used"]["system"] == 200
        assert data["remaining"]["system"] == 300
|
||||
|
||||
|
||||
class TestBudgetAllocator:
    """Tests for BudgetAllocator: default splits, model lookup, adjustment."""

    def test_create_budget(self) -> None:
        """Default allocation ratios sum across the whole budget."""
        allocator = BudgetAllocator()
        budget = allocator.create_budget(100000)

        assert budget.total == 100000
        assert budget.system == 5000  # 5%
        assert budget.task == 10000  # 10%
        assert budget.knowledge == 40000  # 40%
        assert budget.conversation == 20000  # 20%
        assert budget.tools == 5000  # 5%
        assert budget.response_reserve == 15000  # 15%
        assert budget.buffer == 5000  # 5%

    def test_create_budget_custom_allocations(self) -> None:
        """Custom ratio overrides are applied instead of the defaults."""
        allocator = BudgetAllocator()
        budget = allocator.create_budget(
            100000,
            custom_allocations={
                "system": 0.10,
                "task": 0.10,
                "knowledge": 0.30,
                "conversation": 0.25,
                "tools": 0.05,
                "response": 0.15,
                "buffer": 0.05,
            },
        )

        assert budget.system == 10000  # 10%
        assert budget.knowledge == 30000  # 30%

    def test_create_budget_for_model(self) -> None:
        """Budgets sized from the model's known context window."""
        allocator = BudgetAllocator()

        # Claude models have 200k context.
        budget = allocator.create_budget_for_model("claude-3-sonnet")
        assert budget.total == 200000

        # GPT-4 has 8k context.
        budget = allocator.create_budget_for_model("gpt-4")
        assert budget.total == 8192

        # GPT-4-turbo has 128k context.
        budget = allocator.create_budget_for_model("gpt-4-turbo")
        assert budget.total == 128000

    def test_get_model_context_size(self) -> None:
        """Known models resolve to their window; unknown models get a default."""
        allocator = BudgetAllocator()

        # Known models.
        assert allocator.get_model_context_size("claude-3-opus") == 200000
        assert allocator.get_model_context_size("gpt-4") == 8192
        assert allocator.get_model_context_size("gemini-1.5-pro") == 2000000

        # Unknown model gets default.
        assert allocator.get_model_context_size("unknown-model") == 8192

    def test_adjust_budget(self) -> None:
        """adjust_budget moves tokens from the buffer into the target type."""
        allocator = BudgetAllocator()
        budget = allocator.create_budget(10000)

        original_system = budget.system
        original_buffer = budget.buffer

        # Increase system by taking from buffer.
        budget = allocator.adjust_budget(budget, ContextType.SYSTEM, 200)

        assert budget.system == original_system + 200
        assert budget.buffer == original_buffer - 200

    def test_adjust_budget_limited_by_buffer(self) -> None:
        """An over-large adjustment is capped by the available buffer."""
        allocator = BudgetAllocator()
        budget = allocator.create_budget(10000)

        # Capture pre-adjustment values so the assertion below is meaningful.
        original_system = budget.system
        original_buffer = budget.buffer

        # Try to increase more than buffer allows.
        budget = allocator.adjust_budget(budget, ContextType.SYSTEM, 10000)

        # Should only increase by buffer amount.
        assert budget.buffer == 0
        # BUG FIX: the original assertion compared budget.system against
        # (original_buffer + budget.system), which is vacuously true for any
        # non-negative buffer. Compare against the pre-adjustment system
        # allocation plus the buffer that could have been transferred.
        assert budget.system <= original_system + original_buffer

    def test_rebalance_budget(self) -> None:
        """rebalance_budget runs and returns a budget (smoke test).

        The redistribution policy is intentionally not pinned here; the
        original comment marks this as a fuzzy "it runs" check.
        """
        allocator = BudgetAllocator()
        budget = allocator.create_budget(10000)

        # Use most of knowledge budget.
        budget.allocate(ContextType.KNOWLEDGE, 3500)

        # Rebalance prioritizing knowledge.
        budget = allocator.rebalance_budget(
            budget,
            prioritize=[ContextType.KNOWLEDGE],
        )

        # Knowledge should have gotten more tokens
        # (This is a fuzzy test - just check it runs)
        assert budget is not None
|
||||
|
||||
|
||||
class TestTokenCalculator:
    """Tests for TokenCalculator: estimation, MCP counting, caching."""

    def test_estimate_tokens(self) -> None:
        """Estimation returns 0 for empty text and ~chars/4 otherwise."""
        calc = TokenCalculator()

        # Empty string.
        assert calc.estimate_tokens("") == 0

        # Short text (~4 chars per token).
        text = "This is a test message"
        estimate = calc.estimate_tokens(text)
        assert 4 <= estimate <= 8

    def test_estimate_tokens_model_specific(self) -> None:
        """Per-model chars-per-token ratios change the estimate."""
        calc = TokenCalculator()
        text = "a" * 100

        # Claude uses 3.5 chars per token.
        claude_estimate = calc.estimate_tokens(text, "claude-3-sonnet")
        # GPT uses 4.0 chars per token.
        gpt_estimate = calc.estimate_tokens(text, "gpt-4")

        # Claude should estimate more tokens (smaller ratio).
        assert claude_estimate >= gpt_estimate

    @pytest.mark.asyncio
    async def test_count_tokens_no_mcp(self) -> None:
        """Without an MCP manager, counting falls back to estimation."""
        calc = TokenCalculator()

        text = "This is a test"
        count = await calc.count_tokens(text)

        # Should use estimation.
        assert count > 0

    @pytest.mark.asyncio
    async def test_count_tokens_with_mcp_success(self) -> None:
        """A successful MCP tool call supplies the exact token count."""
        # Mock MCP manager.
        mock_mcp = MagicMock()
        mock_result = MagicMock()
        mock_result.success = True
        mock_result.data = {"token_count": 42}
        mock_mcp.call_tool = AsyncMock(return_value=mock_result)

        calc = TokenCalculator(mcp_manager=mock_mcp)
        count = await calc.count_tokens("test text")

        assert count == 42
        mock_mcp.call_tool.assert_called_once()

    @pytest.mark.asyncio
    async def test_count_tokens_with_mcp_failure(self) -> None:
        """An MCP failure degrades gracefully to estimation."""
        # Mock MCP manager that fails.
        mock_mcp = MagicMock()
        mock_mcp.call_tool = AsyncMock(side_effect=Exception("Connection failed"))

        calc = TokenCalculator(mcp_manager=mock_mcp)
        count = await calc.count_tokens("test text")

        # Should fall back to estimation.
        assert count > 0

    @pytest.mark.asyncio
    async def test_count_tokens_caching(self) -> None:
        """Repeated counts of identical text hit the cache, not MCP."""
        mock_mcp = MagicMock()
        mock_result = MagicMock()
        mock_result.success = True
        mock_result.data = {"token_count": 42}
        mock_mcp.call_tool = AsyncMock(return_value=mock_result)

        calc = TokenCalculator(mcp_manager=mock_mcp)

        # First call.
        count1 = await calc.count_tokens("test text")
        # Second call (should use cache).
        count2 = await calc.count_tokens("test text")

        assert count1 == count2 == 42
        # MCP should only be called once.
        assert mock_mcp.call_tool.call_count == 1

    @pytest.mark.asyncio
    async def test_count_tokens_batch(self) -> None:
        """Batch counting returns one positive count per input text."""
        calc = TokenCalculator()

        texts = ["Hello", "World", "Test message here"]
        counts = await calc.count_tokens_batch(texts)

        assert len(counts) == 3
        assert all(c > 0 for c in counts)

    def test_cache_stats(self) -> None:
        """A fresh calculator reports an enabled, empty cache."""
        calc = TokenCalculator()

        stats = calc.get_cache_stats()
        assert stats["enabled"] is True
        assert stats["size"] == 0
        assert stats["hits"] == 0
        assert stats["misses"] == 0

    @pytest.mark.asyncio
    async def test_cache_hit_rate(self) -> None:
        """Hit/miss counters track repeated vs. distinct lookups."""
        calc = TokenCalculator()

        # Make some calls.
        await calc.count_tokens("text1")
        await calc.count_tokens("text2")
        await calc.count_tokens("text1")  # Cache hit

        stats = calc.get_cache_stats()
        assert stats["hits"] == 1
        assert stats["misses"] == 2

    def test_clear_cache(self) -> None:
        """clear_cache empties entries and resets the hit counter."""
        calc = TokenCalculator()
        calc._cache["test"] = 100
        calc._cache_hits = 5

        calc.clear_cache()

        assert len(calc._cache) == 0
        assert calc._cache_hits == 0

    def test_set_mcp_manager(self) -> None:
        """The MCP manager can be attached after construction."""
        calc = TokenCalculator()
        assert calc._mcp is None

        mock_mcp = MagicMock()
        calc.set_mcp_manager(mock_mcp)

        assert calc._mcp is mock_mcp

    @pytest.mark.asyncio
    async def test_parse_token_count_formats(self) -> None:
        """_parse_token_count accepts dicts, ints, and JSON strings."""
        calc = TokenCalculator()

        # Dict with token_count.
        assert calc._parse_token_count({"token_count": 42}) == 42

        # Dict with tokens.
        assert calc._parse_token_count({"tokens": 42}) == 42

        # Dict with count.
        assert calc._parse_token_count({"count": 42}) == 42

        # Direct int.
        assert calc._parse_token_count(42) == 42

        # JSON string.
        assert calc._parse_token_count('{"token_count": 42}') == 42

        # Invalid input yields None rather than raising.
        assert calc._parse_token_count("invalid") is None
|
||||
|
||||
|
||||
class TestBudgetIntegration:
    """Integration tests combining allocator, budget, and calculator."""

    @pytest.mark.asyncio
    async def test_full_budget_workflow(self) -> None:
        """End-to-end: size a budget for a model, count text, allocate it."""
        # Create settings and allocator.
        settings = ContextSettings()
        allocator = BudgetAllocator(settings)

        # Create budget for Claude.
        budget = allocator.create_budget_for_model("claude-3-sonnet")
        assert budget.total == 200000

        # Create calculator (without MCP for test).
        calc = TokenCalculator()

        # Simulate context allocation.
        system_text = "You are a helpful assistant." * 10
        system_tokens = await calc.count_tokens(system_text)

        # Allocate.
        assert budget.can_fit(ContextType.SYSTEM, system_tokens)
        budget.allocate(ContextType.SYSTEM, system_tokens)

        # Check state.
        assert budget.get_used(ContextType.SYSTEM) == system_tokens
        assert budget.remaining(ContextType.SYSTEM) == budget.system - system_tokens

    @pytest.mark.asyncio
    async def test_budget_overflow_handling(self) -> None:
        """Overflow raises unless force=True bypasses the limit."""
        allocator = BudgetAllocator()
        budget = allocator.create_budget(1000)  # Small budget

        # Try to allocate more than available.
        with pytest.raises(BudgetExceededError):
            budget.allocate(ContextType.KNOWLEDGE, 500)

        # Force allocation should work.
        budget.allocate(ContextType.KNOWLEDGE, 500, force=True)
        assert budget.get_used(ContextType.KNOWLEDGE) == 500
|
||||
479
backend/tests/services/context/test_cache.py
Normal file
479
backend/tests/services/context/test_cache.py
Normal file
@@ -0,0 +1,479 @@
|
||||
"""Tests for context cache module."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.context.cache import ContextCache
|
||||
from app.services.context.config import ContextSettings
|
||||
from app.services.context.exceptions import CacheError
|
||||
from app.services.context.types import (
|
||||
AssembledContext,
|
||||
ContextPriority,
|
||||
KnowledgeContext,
|
||||
SystemContext,
|
||||
TaskContext,
|
||||
)
|
||||
|
||||
|
||||
class TestContextCacheBasics:
    """Basic construction and helper tests for ContextCache."""

    def test_creation(self) -> None:
        """A cache without Redis is constructed but disabled."""
        cache = ContextCache()
        assert cache._redis is None
        assert not cache.is_enabled

    def test_creation_with_settings(self) -> None:
        """Custom prefix and TTL from settings are stored on the cache."""
        settings = ContextSettings(
            cache_prefix="test",
            cache_ttl_seconds=60,
        )
        cache = ContextCache(settings=settings)
        assert cache._prefix == "test"
        assert cache._ttl == 60

    def test_set_redis(self) -> None:
        """set_redis attaches the Redis connection after construction."""
        cache = ContextCache()
        mock_redis = MagicMock()
        cache.set_redis(mock_redis)
        assert cache._redis is mock_redis

    def test_is_enabled(self) -> None:
        """is_enabled requires both a Redis connection and enabled settings."""
        settings = ContextSettings(cache_enabled=True)
        cache = ContextCache(settings=settings)
        assert not cache.is_enabled  # No Redis

        cache.set_redis(MagicMock())
        assert cache.is_enabled

        # Disabled in settings wins even with Redis attached.
        settings2 = ContextSettings(cache_enabled=False)
        cache2 = ContextCache(redis=MagicMock(), settings=settings2)
        assert not cache2.is_enabled

    def test_cache_key(self) -> None:
        """Keys are colon-joined under the configured prefix."""
        cache = ContextCache()
        key = cache._cache_key("assembled", "abc123")
        assert key == "ctx:assembled:abc123"

    def test_hash_content(self) -> None:
        """Hashing is deterministic, content-sensitive, and 32 chars long."""
        hash1 = ContextCache._hash_content("hello world")
        hash2 = ContextCache._hash_content("hello world")
        hash3 = ContextCache._hash_content("different")

        assert hash1 == hash2
        assert hash1 != hash3
        assert len(hash1) == 32
|
||||
|
||||
|
||||
class TestFingerprintComputation:
    """Tests that context fingerprints react to each relevant input."""

    def test_compute_fingerprint(self) -> None:
        """Fingerprints are deterministic and sensitive to the query."""
        cache = ContextCache()

        contexts = [
            SystemContext(content="System", source="system"),
            TaskContext(content="Task", source="task"),
        ]

        fp1 = cache.compute_fingerprint(contexts, "query", "claude-3")
        fp2 = cache.compute_fingerprint(contexts, "query", "claude-3")
        fp3 = cache.compute_fingerprint(contexts, "different", "claude-3")

        assert fp1 == fp2  # Same inputs = same fingerprint
        assert fp1 != fp3  # Different query = different fingerprint
        assert len(fp1) == 32

    def test_fingerprint_includes_priority(self) -> None:
        """Changing a context's priority changes the fingerprint."""
        cache = ContextCache()

        # Use KnowledgeContext since SystemContext has __post_init__ that may override
        ctx1 = [
            KnowledgeContext(
                content="Knowledge",
                source="docs",
                priority=ContextPriority.NORMAL.value,
            )
        ]
        ctx2 = [
            KnowledgeContext(
                content="Knowledge",
                source="docs",
                priority=ContextPriority.HIGH.value,
            )
        ]

        fp1 = cache.compute_fingerprint(ctx1, "query", "claude-3")
        fp2 = cache.compute_fingerprint(ctx2, "query", "claude-3")

        assert fp1 != fp2

    def test_fingerprint_includes_model(self) -> None:
        """Changing the target model changes the fingerprint."""
        cache = ContextCache()
        contexts = [SystemContext(content="System", source="system")]

        fp1 = cache.compute_fingerprint(contexts, "query", "claude-3")
        fp2 = cache.compute_fingerprint(contexts, "query", "gpt-4")

        assert fp1 != fp2
|
||||
|
||||
|
||||
class TestMemoryCache:
    """Tests for the in-memory fallback cache used when Redis is absent."""

    def test_memory_cache_fallback(self) -> None:
        """_set_memory stores (value, ...) entries keyed in _memory_cache."""
        cache = ContextCache()

        # Should use memory cache.
        cache._set_memory("test-key", "42")
        assert "test-key" in cache._memory_cache
        # Entries are tuples; index 0 holds the stored value.
        assert cache._memory_cache["test-key"][0] == "42"

    def test_memory_cache_eviction(self) -> None:
        """Exceeding _max_memory_items triggers eviction of older entries."""
        cache = ContextCache()
        cache._max_memory_items = 10

        # Fill cache past the limit.
        for i in range(15):
            cache._set_memory(f"key-{i}", f"value-{i}")

        # Should have evicted some items.
        assert len(cache._memory_cache) < 15
|
||||
|
||||
|
||||
class TestAssembledContextCache:
    """Tests for caching and retrieving AssembledContext via Redis."""

    @pytest.mark.asyncio
    async def test_get_assembled_no_redis(self) -> None:
        """Without Redis attached, lookups return None (cache disabled)."""
        cache = ContextCache()
        result = await cache.get_assembled("fingerprint")
        assert result is None

    @pytest.mark.asyncio
    async def test_get_assembled_not_found(self) -> None:
        """A Redis miss (get -> None) yields None."""
        mock_redis = AsyncMock()
        mock_redis.get.return_value = None

        settings = ContextSettings(cache_enabled=True)
        cache = ContextCache(redis=mock_redis, settings=settings)

        result = await cache.get_assembled("fingerprint")
        assert result is None

    @pytest.mark.asyncio
    async def test_get_assembled_found(self) -> None:
        """A Redis hit deserializes the context and marks it as a cache hit."""
        # Create a context to round-trip through JSON.
        ctx = AssembledContext(
            content="Test content",
            total_tokens=100,
            context_count=2,
        )

        mock_redis = AsyncMock()
        mock_redis.get.return_value = ctx.to_json()

        settings = ContextSettings(cache_enabled=True)
        cache = ContextCache(redis=mock_redis, settings=settings)

        result = await cache.get_assembled("fingerprint")

        assert result is not None
        assert result.content == "Test content"
        assert result.total_tokens == 100
        assert result.cache_hit is True
        assert result.cache_key == "fingerprint"

    @pytest.mark.asyncio
    async def test_set_assembled(self) -> None:
        """set_assembled writes via SETEX under the assembled key prefix."""
        mock_redis = AsyncMock()

        settings = ContextSettings(cache_enabled=True, cache_ttl_seconds=60)
        cache = ContextCache(redis=mock_redis, settings=settings)

        ctx = AssembledContext(
            content="Test content",
            total_tokens=100,
            context_count=2,
        )

        await cache.set_assembled("fingerprint", ctx)

        mock_redis.setex.assert_called_once()
        call_args = mock_redis.setex.call_args
        assert call_args[0][0] == "ctx:assembled:fingerprint"
        assert call_args[0][1] == 60  # TTL from settings

    @pytest.mark.asyncio
    async def test_set_assembled_custom_ttl(self) -> None:
        """An explicit ttl argument overrides the settings TTL."""
        mock_redis = AsyncMock()

        settings = ContextSettings(cache_enabled=True)
        cache = ContextCache(redis=mock_redis, settings=settings)

        ctx = AssembledContext(
            content="Test",
            total_tokens=10,
            context_count=1,
        )

        await cache.set_assembled("fp", ctx, ttl=120)

        call_args = mock_redis.setex.call_args
        assert call_args[0][1] == 120

    @pytest.mark.asyncio
    async def test_cache_error_on_get(self) -> None:
        """Redis failures during get surface as CacheError."""
        mock_redis = AsyncMock()
        mock_redis.get.side_effect = Exception("Redis error")

        settings = ContextSettings(cache_enabled=True)
        cache = ContextCache(redis=mock_redis, settings=settings)

        with pytest.raises(CacheError):
            await cache.get_assembled("fingerprint")

    @pytest.mark.asyncio
    async def test_cache_error_on_set(self) -> None:
        """Redis failures during set surface as CacheError."""
        mock_redis = AsyncMock()
        mock_redis.setex.side_effect = Exception("Redis error")

        settings = ContextSettings(cache_enabled=True)
        cache = ContextCache(redis=mock_redis, settings=settings)

        ctx = AssembledContext(
            content="Test",
            total_tokens=10,
            context_count=1,
        )

        with pytest.raises(CacheError):
            await cache.set_assembled("fp", ctx)
|
||||
|
||||
|
||||
class TestTokenCountCache:
    """Tests for token-count caching, including model-specific keys."""

    @pytest.mark.asyncio
    async def test_get_token_count_memory_fallback(self) -> None:
        """Without Redis, get_token_count reads from the memory cache."""
        cache = ContextCache()

        # Seed the memory cache using the same key scheme the lookup uses.
        key = cache._cache_key("tokens", "default", cache._hash_content("hello"))
        cache._set_memory(key, "42")

        result = await cache.get_token_count("hello")
        assert result == 42

    @pytest.mark.asyncio
    async def test_set_token_count_memory(self) -> None:
        """set/get round-trips a count through the memory cache."""
        cache = ContextCache()

        await cache.set_token_count("hello", 42)

        result = await cache.get_token_count("hello")
        assert result == 42

    @pytest.mark.asyncio
    async def test_set_token_count_with_model(self) -> None:
        """Counts for the same text under different models get distinct keys."""
        mock_redis = AsyncMock()
        settings = ContextSettings(cache_enabled=True)
        cache = ContextCache(redis=mock_redis, settings=settings)

        await cache.set_token_count("hello", 42, model="claude-3")
        await cache.set_token_count("hello", 50, model="gpt-4")

        # Different models should have different keys.
        assert mock_redis.setex.call_count == 2
        calls = mock_redis.setex.call_args_list

        key1 = calls[0][0][0]
        key2 = calls[1][0][0]
        assert "claude-3" in key1
        assert "gpt-4" in key2
|
||||
|
||||
|
||||
class TestScoreCache:
    """Score caching behaviour."""

    @pytest.mark.asyncio
    async def test_get_score_memory_fallback(self) -> None:
        """A score seeded in the in-memory layer is served by get_score."""
        cache = ContextCache()

        # Seed memory under the key get_score would compute.
        query_digest = cache._hash_content("query")[:16]
        mem_key = cache._cache_key("score", "relevance", "ctx-123", query_digest)
        cache._set_memory(mem_key, "0.85")

        assert await cache.get_score("relevance", "ctx-123", "query") == 0.85

    @pytest.mark.asyncio
    async def test_set_score_memory(self) -> None:
        """A stored score round-trips through the memory layer."""
        cache = ContextCache()

        await cache.set_score("relevance", "ctx-123", "query", 0.85)

        assert await cache.get_score("relevance", "ctx-123", "query") == 0.85

    @pytest.mark.asyncio
    async def test_set_score_with_redis(self) -> None:
        """With Redis attached, set_score writes through via setex."""
        redis_stub = AsyncMock()
        cache = ContextCache(
            redis=redis_stub,
            settings=ContextSettings(cache_enabled=True),
        )

        await cache.set_score("relevance", "ctx-123", "query", 0.85)

        redis_stub.setex.assert_called_once()
|
||||
|
||||
|
||||
class TestCacheInvalidation:
    """Cache invalidation behaviour."""

    @pytest.mark.asyncio
    async def test_invalidate_pattern(self) -> None:
        """invalidate() deletes every key that matches the given pattern."""
        redis_stub = AsyncMock()

        # scan_iter is an async generator; AsyncMock can't model that directly.
        async def fake_scan_iter(match=None):
            for key in ("ctx:assembled:1", "ctx:assembled:2"):
                yield key

        redis_stub.scan_iter = fake_scan_iter

        cache = ContextCache(
            redis=redis_stub,
            settings=ContextSettings(cache_enabled=True),
        )

        removed = await cache.invalidate("assembled:*")

        assert removed == 2
        assert redis_stub.delete.call_count == 2

    @pytest.mark.asyncio
    async def test_clear_all(self) -> None:
        """clear_all() wipes both Redis keys and the memory layer."""
        redis_stub = AsyncMock()

        async def fake_scan_iter(match=None):
            for key in ("ctx:1", "ctx:2", "ctx:3"):
                yield key

        redis_stub.scan_iter = fake_scan_iter

        cache = ContextCache(
            redis=redis_stub,
            settings=ContextSettings(cache_enabled=True),
        )

        # Put something in the memory layer so we can see it cleared.
        cache._set_memory("test", "value")
        assert len(cache._memory_cache) > 0

        removed = await cache.clear_all()

        assert removed == 3
        assert len(cache._memory_cache) == 0
|
||||
|
||||
|
||||
class TestCacheStats:
    """Cache statistics reporting."""

    @pytest.mark.asyncio
    async def test_get_stats_no_redis(self) -> None:
        """Without Redis, stats report memory-only operation."""
        cache = ContextCache()
        cache._set_memory("key", "value")

        report = await cache.get_stats()

        assert report["enabled"] is True
        assert report["redis_available"] is False
        assert report["memory_items"] == 1

    @pytest.mark.asyncio
    async def test_get_stats_with_redis(self) -> None:
        """With Redis, stats include TTL and Redis memory usage."""
        redis_stub = AsyncMock()
        redis_stub.info.return_value = {"used_memory_human": "1.5M"}

        cache = ContextCache(
            redis=redis_stub,
            settings=ContextSettings(cache_enabled=True, cache_ttl_seconds=300),
        )

        report = await cache.get_stats()

        assert report["enabled"] is True
        assert report["redis_available"] is True
        assert report["ttl_seconds"] == 300
        assert report["redis_memory_used"] == "1.5M"
|
||||
|
||||
|
||||
class TestCacheIntegration:
    """End-to-end cache workflow."""

    @pytest.mark.asyncio
    async def test_full_workflow(self) -> None:
        """Fingerprint, miss, store, then hit — the full round trip."""
        redis_stub = AsyncMock()
        redis_stub.get.return_value = None

        cache = ContextCache(
            redis=redis_stub,
            settings=ContextSettings(cache_enabled=True),
        )

        contexts = [
            SystemContext(content="System", source="system"),
            KnowledgeContext(content="Knowledge", source="docs"),
        ]

        # Fingerprint is a 32-char digest over contexts + query + model.
        fp = cache.compute_fingerprint(contexts, "query", "claude-3")
        assert len(fp) == 32

        # First lookup misses: Redis returns None.
        assert await cache.get_assembled(fp) is None

        # Store an assembled context under that fingerprint.
        assembled = AssembledContext(
            content="Assembled content",
            total_tokens=100,
            context_count=2,
            model="claude-3",
        )
        await cache.set_assembled(fp, assembled)
        redis_stub.setex.assert_called_once()

        # Simulate a hit by having Redis return the serialized payload.
        redis_stub.get.return_value = assembled.to_json()
        hit = await cache.get_assembled(fp)

        assert hit is not None
        assert hit.cache_hit is True
        assert hit.content == "Assembled content"
|
||||
294
backend/tests/services/context/test_compression.py
Normal file
294
backend/tests/services/context/test_compression.py
Normal file
@@ -0,0 +1,294 @@
|
||||
"""Tests for context compression module."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.context.budget import BudgetAllocator
|
||||
from app.services.context.compression import (
|
||||
ContextCompressor,
|
||||
TruncationResult,
|
||||
TruncationStrategy,
|
||||
)
|
||||
from app.services.context.types import (
|
||||
ContextType,
|
||||
KnowledgeContext,
|
||||
TaskContext,
|
||||
)
|
||||
|
||||
|
||||
class TestTruncationResult:
    """TruncationResult dataclass behaviour."""

    def test_creation(self) -> None:
        """Fields are stored as given."""
        outcome = TruncationResult(
            original_tokens=100,
            truncated_tokens=50,
            content="Truncated content",
            truncated=True,
            truncation_ratio=0.5,
        )

        assert outcome.original_tokens == 100
        assert outcome.truncated_tokens == 50
        assert outcome.truncated is True
        assert outcome.truncation_ratio == 0.5

    def test_tokens_saved(self) -> None:
        """tokens_saved is the original-minus-truncated difference."""
        outcome = TruncationResult(
            original_tokens=100,
            truncated_tokens=40,
            content="Test",
            truncated=True,
            truncation_ratio=0.6,
        )

        assert outcome.tokens_saved == 60

    def test_no_truncation(self) -> None:
        """Identical token counts mean nothing was saved."""
        outcome = TruncationResult(
            original_tokens=50,
            truncated_tokens=50,
            content="Full content",
            truncated=False,
            truncation_ratio=0.0,
        )

        assert outcome.tokens_saved == 0
        assert outcome.truncated is False
|
||||
|
||||
|
||||
class TestTruncationStrategy:
    """TruncationStrategy construction and truncation modes."""

    def test_creation(self) -> None:
        """Defaults: preserve 70% from the start, 100-char minimum."""
        strategy = TruncationStrategy()
        assert strategy._preserve_ratio_start == 0.7
        assert strategy._min_content_length == 100

    def test_creation_with_params(self) -> None:
        """Custom ratio and minimum length are honoured."""
        strategy = TruncationStrategy(
            preserve_ratio_start=0.5,
            min_content_length=50,
        )
        assert strategy._preserve_ratio_start == 0.5
        assert strategy._min_content_length == 50

    @pytest.mark.asyncio
    async def test_truncate_empty_content(self) -> None:
        """Empty input comes back untouched with zero counts."""
        outcome = await TruncationStrategy().truncate_to_tokens("", max_tokens=100)

        assert outcome.original_tokens == 0
        assert outcome.truncated_tokens == 0
        assert outcome.content == ""
        assert outcome.truncated is False

    @pytest.mark.asyncio
    async def test_truncate_content_within_limit(self) -> None:
        """Content under the budget is returned verbatim."""
        text = "Short content"

        outcome = await TruncationStrategy().truncate_to_tokens(text, max_tokens=100)

        assert outcome.content == text
        assert outcome.truncated is False

    @pytest.mark.asyncio
    async def test_truncate_end_strategy(self) -> None:
        """'end' strategy drops the tail and inserts the marker."""
        strategy = TruncationStrategy()
        long_text = "A" * 1000  # well over the budget

        outcome = await strategy.truncate_to_tokens(
            long_text, max_tokens=50, strategy="end"
        )

        assert outcome.truncated is True
        assert len(outcome.content) < len(long_text)
        assert strategy.truncation_marker in outcome.content

    @pytest.mark.asyncio
    async def test_truncate_middle_strategy(self) -> None:
        """'middle' strategy removes the centre and inserts the marker."""
        strategy = TruncationStrategy(preserve_ratio_start=0.6)
        text = "START " + "A" * 500 + " END"

        outcome = await strategy.truncate_to_tokens(
            text, max_tokens=50, strategy="middle"
        )

        assert outcome.truncated is True
        assert strategy.truncation_marker in outcome.content

    @pytest.mark.asyncio
    async def test_truncate_sentence_strategy(self) -> None:
        """'sentence' strategy prefers cutting at sentence boundaries."""
        strategy = TruncationStrategy()
        text = "First sentence. Second sentence. Third sentence. Fourth sentence."

        outcome = await strategy.truncate_to_tokens(
            text, max_tokens=10, strategy="sentence"
        )

        assert outcome.truncated is True
        # Either a clean sentence-boundary cut, or a marked truncation.
        assert (
            outcome.content.endswith(".")
            or strategy.truncation_marker in outcome.content
        )
|
||||
|
||||
|
||||
class TestContextCompressor:
    """ContextCompressor behaviour."""

    def test_creation(self) -> None:
        """A fresh compressor wires up its truncation strategy."""
        assert ContextCompressor()._truncation is not None

    @pytest.mark.asyncio
    async def test_compress_context_within_limit(self) -> None:
        """A context already under budget passes through unmodified."""
        compressor = ContextCompressor()

        ctx = KnowledgeContext(
            content="Short content",
            source="docs",
        )
        ctx.token_count = 5

        compressed = await compressor.compress_context(ctx, max_tokens=100)

        assert compressed.content == "Short content"
        assert compressed.metadata.get("truncated") is not True

    @pytest.mark.asyncio
    async def test_compress_context_exceeds_limit(self) -> None:
        """A context over budget is truncated and annotated in metadata."""
        compressor = ContextCompressor()

        ctx = KnowledgeContext(
            content="A" * 500,
            source="docs",
        )
        ctx.token_count = 125  # roughly 500 chars / 4 chars-per-token

        compressed = await compressor.compress_context(ctx, max_tokens=20)

        assert compressed.metadata.get("truncated") is True
        assert compressed.metadata.get("original_tokens") == 125
        assert len(compressed.content) < 500

    @pytest.mark.asyncio
    async def test_compress_contexts_batch(self) -> None:
        """Batch compression preserves the number of contexts."""
        compressor = ContextCompressor()
        budget = BudgetAllocator().create_budget(1000)

        batch = [
            KnowledgeContext(content="A" * 200, source="docs"),
            KnowledgeContext(content="B" * 200, source="docs"),
            TaskContext(content="C" * 200, source="task"),
        ]

        compressed = await compressor.compress_contexts(batch, budget)

        assert len(compressed) == 3

    @pytest.mark.asyncio
    async def test_strategy_selection_by_type(self) -> None:
        """Each context type maps to its expected truncation strategy."""
        compressor = ContextCompressor()

        expected = {
            ContextType.SYSTEM: "end",
            ContextType.TASK: "end",
            ContextType.KNOWLEDGE: "sentence",
            ContextType.CONVERSATION: "end",
            ContextType.TOOL: "middle",
        }
        for ctx_type, strategy_name in expected.items():
            assert compressor._get_strategy_for_type(ctx_type) == strategy_name
|
||||
|
||||
|
||||
class TestTruncationEdgeCases:
    """Regression tests for truncation edge cases."""

    @pytest.mark.asyncio
    async def test_truncation_ratio_with_zero_original_tokens(self) -> None:
        """Zero original tokens must not trigger a ZeroDivisionError."""
        outcome = await TruncationStrategy().truncate_to_tokens("", max_tokens=100)

        assert outcome.truncation_ratio == 0.0
        assert outcome.original_tokens == 0

    @pytest.mark.asyncio
    async def test_truncate_end_with_zero_available_tokens(self) -> None:
        """A budget smaller than the marker itself is handled gracefully."""
        strategy = TruncationStrategy()
        text = "Some content to truncate"

        outcome = await strategy.truncate_to_tokens(
            text, max_tokens=1, strategy="end"
        )

        # Must not crash; result is either marked or the untouched input.
        assert strategy.truncation_marker in outcome.content or outcome.content == text

    @pytest.mark.asyncio
    async def test_truncate_with_content_that_has_zero_tokens(self) -> None:
        """Content that estimates to zero tokens must not divide by zero."""
        strategy = TruncationStrategy()

        outcome = await strategy.truncate_to_tokens("a", max_tokens=100)

        assert outcome.content in ("a", "a" + strategy.truncation_marker)

    @pytest.mark.asyncio
    async def test_get_content_for_tokens_zero_target(self) -> None:
        """A zero-token target yields the empty string."""
        extracted = await TruncationStrategy()._get_content_for_tokens(
            content="Some content",
            target_tokens=0,
            from_start=True,
        )

        assert extracted == ""

    @pytest.mark.asyncio
    async def test_sentence_truncation_with_no_sentences(self) -> None:
        """Sentence mode copes with text lacking any sentence boundary."""
        text = "this is content without any sentence ending punctuation"

        outcome = await TruncationStrategy().truncate_to_tokens(
            text, max_tokens=5, strategy="sentence"
        )

        assert outcome is not None

    @pytest.mark.asyncio
    async def test_middle_truncation_very_short_content(self) -> None:
        """Middle mode copes with content shorter than the preserved spans."""
        strategy = TruncationStrategy(preserve_ratio_start=0.7)

        outcome = await strategy.truncate_to_tokens(
            "ab", max_tokens=1, strategy="middle"
        )

        # Must not crash on negative slice indices.
        assert outcome is not None
|
||||
243
backend/tests/services/context/test_config.py
Normal file
243
backend/tests/services/context/test_config.py
Normal file
@@ -0,0 +1,243 @@
|
||||
"""Tests for context management configuration."""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.context.config import (
|
||||
ContextSettings,
|
||||
get_context_settings,
|
||||
get_default_settings,
|
||||
reset_context_settings,
|
||||
)
|
||||
|
||||
|
||||
class TestContextSettings:
    """ContextSettings defaults, validation, and export helpers."""

    @staticmethod
    def _budget_total(settings: ContextSettings) -> float:
        """Sum the seven budget fractions in declaration order."""
        return (
            settings.budget_system
            + settings.budget_task
            + settings.budget_knowledge
            + settings.budget_conversation
            + settings.budget_tools
            + settings.budget_response
            + settings.budget_buffer
        )

    def test_default_values(self) -> None:
        """Default budget fractions and scoring weights both sum to 1.0."""
        settings = ContextSettings()

        assert abs(self._budget_total(settings) - 1.0) < 0.001

        weights_total = (
            settings.scoring_relevance_weight
            + settings.scoring_recency_weight
            + settings.scoring_priority_weight
        )
        assert abs(weights_total - 1.0) < 0.001

    def test_budget_allocation_values(self) -> None:
        """Each budget field carries its documented default."""
        settings = ContextSettings()

        expected = {
            "budget_system": 0.05,
            "budget_task": 0.10,
            "budget_knowledge": 0.40,
            "budget_conversation": 0.20,
            "budget_tools": 0.05,
            "budget_response": 0.15,
            "budget_buffer": 0.05,
        }
        for field, value in expected.items():
            assert getattr(settings, field) == value

    def test_scoring_weights(self) -> None:
        """Scoring weight defaults."""
        settings = ContextSettings()

        assert settings.scoring_relevance_weight == 0.5
        assert settings.scoring_recency_weight == 0.3
        assert settings.scoring_priority_weight == 0.2

    def test_cache_settings(self) -> None:
        """Cache defaults: enabled, 1-hour TTL, 'ctx' key prefix."""
        settings = ContextSettings()

        assert settings.cache_enabled is True
        assert settings.cache_ttl_seconds == 3600
        assert settings.cache_prefix == "ctx"

    def test_performance_settings(self) -> None:
        """Performance defaults."""
        settings = ContextSettings()

        assert settings.max_assembly_time_ms == 2000
        assert settings.parallel_scoring is True
        assert settings.max_parallel_scores == 10

    def test_get_budget_allocation(self) -> None:
        """get_budget_allocation exposes the budgets as a dict."""
        allocation = ContextSettings().get_budget_allocation()

        assert isinstance(allocation, dict)
        assert "system" in allocation
        assert "knowledge" in allocation
        assert allocation["system"] == 0.05
        assert allocation["knowledge"] == 0.40

    def test_get_scoring_weights(self) -> None:
        """get_scoring_weights exposes the weights as a dict."""
        weights = ContextSettings().get_scoring_weights()

        assert isinstance(weights, dict)
        for key in ("relevance", "recency", "priority"):
            assert key in weights
        assert weights["relevance"] == 0.5

    def test_to_dict(self) -> None:
        """to_dict groups settings into the expected sections."""
        exported = ContextSettings().to_dict()

        for section in (
            "budget",
            "scoring",
            "compression",
            "cache",
            "performance",
            "knowledge",
            "conversation",
        ):
            assert section in exported

    def test_budget_validation_fails_on_wrong_sum(self) -> None:
        """A budget total other than 1.0 is rejected at construction."""
        with pytest.raises(ValueError) as exc_info:
            ContextSettings(
                budget_system=0.5,
                budget_task=0.5,
                # Remaining budgets keep non-zero defaults, so total > 1.0.
            )

        assert "sum to 1.0" in str(exc_info.value)

    def test_scoring_validation_fails_on_wrong_sum(self) -> None:
        """A scoring-weight total other than 1.0 is rejected."""
        with pytest.raises(ValueError) as exc_info:
            ContextSettings(
                scoring_relevance_weight=0.8,
                scoring_recency_weight=0.8,
                scoring_priority_weight=0.8,
            )

        assert "sum to 1.0" in str(exc_info.value)

    def test_search_type_validation(self) -> None:
        """Only the known search types are accepted."""
        for search_type in ("semantic", "keyword", "hybrid"):
            ContextSettings(knowledge_search_type=search_type)

        with pytest.raises(ValueError):
            ContextSettings(knowledge_search_type="invalid")

    def test_custom_budget_allocation(self) -> None:
        """A custom allocation summing to 1.0 is accepted."""
        settings = ContextSettings(
            budget_system=0.10,
            budget_task=0.10,
            budget_knowledge=0.30,
            budget_conversation=0.25,
            budget_tools=0.05,
            budget_response=0.15,
            budget_buffer=0.05,
        )

        assert abs(self._budget_total(settings) - 1.0) < 0.001
|
||||
|
||||
|
||||
class TestSettingsSingleton:
    """Singleton behaviour of the settings accessors."""

    def setup_method(self) -> None:
        """Start each test from a clean singleton."""
        reset_context_settings()

    def teardown_method(self) -> None:
        """Leave no shared singleton behind."""
        reset_context_settings()

    def test_get_context_settings_returns_instance(self) -> None:
        """The accessor yields a ContextSettings object."""
        assert isinstance(get_context_settings(), ContextSettings)

    def test_get_context_settings_returns_same_instance(self) -> None:
        """Repeated calls return one shared instance."""
        first = get_context_settings()
        second = get_context_settings()

        assert first is second

    def test_reset_creates_new_instance(self) -> None:
        """reset_context_settings discards the cached instance."""
        before = get_context_settings()
        reset_context_settings()
        after = get_context_settings()

        assert before is not after

    def test_get_default_settings_cached(self) -> None:
        """get_default_settings memoizes its result."""
        assert get_default_settings() is get_default_settings()
|
||||
|
||||
|
||||
class TestEnvironmentOverrides:
    """CTX_* environment variables override settings fields."""

    def setup_method(self) -> None:
        """Start each test from a clean singleton."""
        reset_context_settings()

    def teardown_method(self) -> None:
        """Reset the singleton and drop any CTX_* variables left behind."""
        reset_context_settings()
        for name in [k for k in os.environ if k.startswith("CTX_")]:
            del os.environ[name]

    def test_env_override_cache_enabled(self) -> None:
        """CTX_CACHE_ENABLED=false disables the cache."""
        with patch.dict(os.environ, {"CTX_CACHE_ENABLED": "false"}):
            reset_context_settings()
            assert ContextSettings().cache_enabled is False

    def test_env_override_cache_ttl(self) -> None:
        """CTX_CACHE_TTL_SECONDS overrides the TTL."""
        with patch.dict(os.environ, {"CTX_CACHE_TTL_SECONDS": "7200"}):
            reset_context_settings()
            assert ContextSettings().cache_ttl_seconds == 7200

    def test_env_override_max_assembly_time(self) -> None:
        """CTX_MAX_ASSEMBLY_TIME_MS overrides the assembly time limit."""
        with patch.dict(os.environ, {"CTX_MAX_ASSEMBLY_TIME_MS": "200"}):
            reset_context_settings()
            assert ContextSettings().max_assembly_time_ms == 200
|
||||
456
backend/tests/services/context/test_engine.py
Normal file
456
backend/tests/services/context/test_engine.py
Normal file
@@ -0,0 +1,456 @@
|
||||
"""Tests for ContextEngine."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.context.config import ContextSettings
|
||||
from app.services.context.engine import ContextEngine, create_context_engine
|
||||
from app.services.context.types import (
|
||||
AssembledContext,
|
||||
ConversationContext,
|
||||
KnowledgeContext,
|
||||
MessageRole,
|
||||
ToolContext,
|
||||
)
|
||||
|
||||
|
||||
class TestContextEngineCreation:
    """Construction and wiring of ContextEngine."""

    def test_creation_minimal(self) -> None:
        """A bare engine wires every collaborator; MCP starts unset."""
        engine = ContextEngine()

        assert engine._mcp is None
        for attr in (
            "_settings",
            "_calculator",
            "_scorer",
            "_ranker",
            "_compressor",
            "_cache",
            "_pipeline",
        ):
            assert getattr(engine, attr) is not None

    def test_creation_with_settings(self) -> None:
        """Custom settings are carried through to the engine."""
        custom = ContextSettings(
            compression_threshold=0.7,
            cache_enabled=False,
        )
        engine = ContextEngine(settings=custom)

        assert engine._settings.compression_threshold == 0.7
        assert engine._settings.cache_enabled is False

    def test_creation_with_redis(self) -> None:
        """Providing Redis with caching enabled activates the cache."""
        engine = ContextEngine(
            redis=MagicMock(),
            settings=ContextSettings(cache_enabled=True),
        )

        assert engine._cache.is_enabled

    def test_set_mcp_manager(self) -> None:
        """set_mcp_manager stores the manager on the engine."""
        engine = ContextEngine()
        manager = MagicMock()

        engine.set_mcp_manager(manager)

        assert engine._mcp is manager

    def test_set_redis(self) -> None:
        """set_redis forwards the connection to the cache."""
        engine = ContextEngine()
        connection = MagicMock()

        engine.set_redis(connection)

        assert engine._cache._redis is connection
|
||||
|
||||
|
||||
class TestContextEngineHelpers:
    """Conversion helpers on ContextEngine."""

    def test_convert_conversation(self) -> None:
        """History dicts become ConversationContext objects with turn metadata."""
        engine = ContextEngine()

        turns = [
            {"role": "user", "content": "Hello!"},
            {"role": "assistant", "content": "Hi there!"},
            {"role": "user", "content": "How are you?"},
        ]

        converted = engine._convert_conversation(turns)

        assert len(converted) == 3
        assert all(isinstance(item, ConversationContext) for item in converted)
        assert converted[0].role == MessageRole.USER
        assert converted[1].role == MessageRole.ASSISTANT
        assert converted[0].content == "Hello!"
        assert converted[0].metadata["turn"] == 0

    def test_convert_tool_results(self) -> None:
        """Tool-result dicts become ToolContext objects; dicts are JSON-encoded."""
        engine = ContextEngine()

        raw_results = [
            {"tool_name": "search", "content": "Result 1", "status": "success"},
            {"tool_name": "read", "result": {"file": "test.txt"}, "status": "success"},
        ]

        converted = engine._convert_tool_results(raw_results)

        assert len(converted) == 2
        assert all(isinstance(item, ToolContext) for item in converted)
        assert converted[0].content == "Result 1"
        assert converted[0].metadata["tool_name"] == "search"
        # Dict payloads are serialized into the content string.
        assert "file" in converted[1].content
        assert "test.txt" in converted[1].content
||||
|
||||
|
||||
class TestContextEngineAssembly:
|
||||
"""Tests for context assembly."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assemble_minimal(self) -> None:
|
||||
"""Test assembling with minimal inputs."""
|
||||
engine = ContextEngine()
|
||||
|
||||
result = await engine.assemble_context(
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
query="test query",
|
||||
model="claude-3-sonnet",
|
||||
use_cache=False, # Disable cache for test
|
||||
)
|
||||
|
||||
assert isinstance(result, AssembledContext)
|
||||
assert result.context_count == 0 # No contexts provided
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assemble_with_system_prompt(self) -> None:
|
||||
"""Test assembling with system prompt."""
|
||||
engine = ContextEngine()
|
||||
|
||||
result = await engine.assemble_context(
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
query="test query",
|
||||
model="claude-3-sonnet",
|
||||
system_prompt="You are a helpful assistant.",
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
assert result.context_count == 1
|
||||
assert "helpful assistant" in result.content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assemble_with_task(self) -> None:
|
||||
"""Test assembling with task description."""
|
||||
engine = ContextEngine()
|
||||
|
||||
result = await engine.assemble_context(
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
query="implement feature",
|
||||
model="claude-3-sonnet",
|
||||
task_description="Implement user authentication",
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
assert result.context_count == 1
|
||||
assert "authentication" in result.content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assemble_with_conversation(self) -> None:
|
||||
"""Test assembling with conversation history."""
|
||||
engine = ContextEngine()
|
||||
|
||||
result = await engine.assemble_context(
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
query="continue",
|
||||
model="claude-3-sonnet",
|
||||
conversation_history=[
|
||||
{"role": "user", "content": "Hello!"},
|
||||
{"role": "assistant", "content": "Hi!"},
|
||||
],
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
assert result.context_count == 2
|
||||
assert "Hello" in result.content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assemble_with_tool_results(self) -> None:
|
||||
"""Test assembling with tool results."""
|
||||
engine = ContextEngine()
|
||||
|
||||
result = await engine.assemble_context(
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
query="continue",
|
||||
model="claude-3-sonnet",
|
||||
tool_results=[
|
||||
{"tool_name": "search", "content": "Found 5 results"},
|
||||
],
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
assert result.context_count == 1
|
||||
assert "Found 5 results" in result.content
|
||||
|
||||
@pytest.mark.asyncio
async def test_assemble_with_custom_contexts(self) -> None:
    """Caller-supplied custom contexts participate in assembly."""
    ctx_engine = ContextEngine()

    extra = [
        KnowledgeContext(
            content="Custom knowledge.",
            source="custom",
            relevance_score=0.9,
        )
    ]

    assembled = await ctx_engine.assemble_context(
        project_id="proj-123",
        agent_id="agent-456",
        query="test",
        model="claude-3-sonnet",
        custom_contexts=extra,
        use_cache=False,
    )

    # The single custom context is counted and its text included.
    assert assembled.context_count == 1
    assert "Custom knowledge" in assembled.content
|
||||
|
||||
@pytest.mark.asyncio
async def test_assemble_full_workflow(self) -> None:
    """End-to-end assembly with every optional input supplied at once."""
    ctx_engine = ContextEngine()

    assembled = await ctx_engine.assemble_context(
        project_id="proj-123",
        agent_id="agent-456",
        query="implement login",
        model="claude-3-sonnet",
        system_prompt="You are an expert Python developer.",
        task_description="Implement user authentication.",
        conversation_history=[
            {"role": "user", "content": "Can you help me implement JWT auth?"},
        ],
        tool_results=[
            {"tool_name": "file_create", "content": "Created auth.py"},
        ],
        use_cache=False,
    )

    # System + task + conversation + tool result at minimum.
    assert assembled.context_count >= 4
    assert assembled.total_tokens > 0
    assert assembled.model == "claude-3-sonnet"

    # Spot-check that the inputs actually surfaced in the output text.
    assert "expert Python developer" in assembled.content
    assert "authentication" in assembled.content
|
||||
|
||||
|
||||
class TestContextEngineKnowledge:
    """Knowledge-fetching behaviour of ContextEngine."""

    @pytest.mark.asyncio
    async def test_fetch_knowledge_no_mcp(self) -> None:
        """Without an MCP manager the fetch yields no contexts."""
        engine = ContextEngine()

        fetched = await engine._fetch_knowledge(
            project_id="proj-123",
            agent_id="agent-456",
            query="test",
        )

        assert fetched == []

    @pytest.mark.asyncio
    async def test_fetch_knowledge_with_mcp(self) -> None:
        """MCP search results are converted into KnowledgeContext objects."""
        mcp = AsyncMock()
        mcp.call_tool.return_value.data = {
            "results": [
                {
                    "content": "Document content",
                    "source_path": "docs/api.md",
                    "score": 0.9,
                    "chunk_id": "chunk-1",
                },
                {
                    "content": "Another document",
                    "source_path": "docs/auth.md",
                    "score": 0.8,
                },
            ]
        }

        engine = ContextEngine(mcp_manager=mcp)

        fetched = await engine._fetch_knowledge(
            project_id="proj-123",
            agent_id="agent-456",
            query="authentication",
        )

        assert len(fetched) == 2
        assert all(isinstance(item, KnowledgeContext) for item in fetched)
        # The first result's fields map straight onto the context.
        first = fetched[0]
        assert first.content == "Document content"
        assert first.source == "docs/api.md"
        assert first.relevance_score == 0.9

    @pytest.mark.asyncio
    async def test_fetch_knowledge_error_handling(self) -> None:
        """MCP failures are swallowed and an empty list is returned."""
        mcp = AsyncMock()
        mcp.call_tool.side_effect = Exception("MCP error")

        engine = ContextEngine(mcp_manager=mcp)

        # Must not propagate the MCP exception to the caller.
        fetched = await engine._fetch_knowledge(
            project_id="proj-123",
            agent_id="agent-456",
            query="test",
        )

        assert fetched == []
|
||||
|
||||
|
||||
class TestContextEngineCaching:
    """Cache behaviour of ContextEngine."""

    @pytest.mark.asyncio
    async def test_cache_disabled(self) -> None:
        """With use_cache=False the result is never reported as a cache hit."""
        engine = ContextEngine()

        assembled = await engine.assemble_context(
            project_id="proj-123",
            agent_id="agent-456",
            query="test",
            model="claude-3-sonnet",
            system_prompt="Test prompt",
            use_cache=False,
        )

        assert not assembled.cache_hit

    @pytest.mark.asyncio
    async def test_cache_hit(self) -> None:
        """A serialized result stored in Redis produces a cache hit."""
        redis = AsyncMock()
        engine = ContextEngine(
            redis=redis,
            settings=ContextSettings(cache_enabled=True),
        )

        # First assembly: Redis has nothing, so it is a miss.
        redis.get.return_value = None
        first = await engine.assemble_context(
            project_id="proj-123",
            agent_id="agent-456",
            query="test",
            model="claude-3-sonnet",
            system_prompt="Test prompt",
        )

        # Second assembly: Redis returns the serialized first result.
        redis.get.return_value = first.to_json()
        second = await engine.assemble_context(
            project_id="proj-123",
            agent_id="agent-456",
            query="test",
            model="claude-3-sonnet",
            system_prompt="Test prompt",
        )

        assert second.cache_hit
|
||||
|
||||
|
||||
class TestContextEngineUtilities:
    """Utility-method behaviour of ContextEngine."""

    @pytest.mark.asyncio
    async def test_get_budget_for_model(self) -> None:
        """A model name maps to a budget with positive allocations."""
        engine = ContextEngine()

        budget = await engine.get_budget_for_model("claude-3-sonnet")

        # Total and the per-section allocations must all be non-zero.
        for allocation in (budget.total, budget.system, budget.knowledge):
            assert allocation > 0

    @pytest.mark.asyncio
    async def test_get_budget_with_max_tokens(self) -> None:
        """An explicit max_tokens caps the budget total exactly."""
        engine = ContextEngine()

        budget = await engine.get_budget_for_model("claude-3-sonnet", max_tokens=5000)

        assert budget.total == 5000

    @pytest.mark.asyncio
    async def test_count_tokens(self) -> None:
        """Non-empty text counts as at least one token."""
        engine = ContextEngine()

        assert await engine.count_tokens("Hello world") > 0

    @pytest.mark.asyncio
    async def test_invalidate_cache(self) -> None:
        """Invalidation scans Redis keys and reports a non-negative count."""
        redis = AsyncMock()

        async def fake_scan_iter(match=None):
            # Pretend two context keys live in Redis.
            for cache_key in ("ctx:1", "ctx:2"):
                yield cache_key

        redis.scan_iter = fake_scan_iter

        engine = ContextEngine(
            redis=redis,
            settings=ContextSettings(cache_enabled=True),
        )

        removed = await engine.invalidate_cache(pattern="*test*")

        assert removed >= 0

    @pytest.mark.asyncio
    async def test_get_stats(self) -> None:
        """Stats expose cache info and the active settings."""
        engine = ContextEngine()

        stats = await engine.get_stats()

        assert "cache" in stats
        assert "settings" in stats
        assert "compression_threshold" in stats["settings"]
|
||||
|
||||
|
||||
class TestCreateContextEngine:
    """Behaviour of the create_context_engine factory."""

    def test_create_context_engine(self) -> None:
        """The factory returns a ContextEngine instance."""
        assert isinstance(create_context_engine(), ContextEngine)

    def test_create_context_engine_with_settings(self) -> None:
        """Settings passed to the factory are wired into the engine."""
        engine = create_context_engine(settings=ContextSettings(cache_enabled=False))

        assert engine._settings.cache_enabled is False
|
||||
250
backend/tests/services/context/test_exceptions.py
Normal file
250
backend/tests/services/context/test_exceptions.py
Normal file
@@ -0,0 +1,250 @@
|
||||
"""Tests for context management exceptions."""
|
||||
|
||||
from app.services.context.exceptions import (
|
||||
AssemblyTimeoutError,
|
||||
BudgetExceededError,
|
||||
CacheError,
|
||||
CompressionError,
|
||||
ContextError,
|
||||
ContextNotFoundError,
|
||||
FormattingError,
|
||||
InvalidContextError,
|
||||
ScoringError,
|
||||
TokenCountError,
|
||||
)
|
||||
|
||||
|
||||
class TestContextError:
    """Behaviour of the ContextError base class."""

    def test_basic_initialization(self) -> None:
        """Message is stored, details default to empty, str() echoes message."""
        err = ContextError("Test error")

        assert err.message == "Test error"
        assert err.details == {}
        assert str(err) == "Test error"

    def test_with_details(self) -> None:
        """A details mapping is kept exactly as passed."""
        err = ContextError("Test error", {"key": "value", "count": 42})

        assert err.details == {"key": "value", "count": 42}

    def test_to_dict(self) -> None:
        """to_dict exposes the type name, message and details."""
        payload = ContextError("Test error", {"key": "value"}).to_dict()

        assert payload["error_type"] == "ContextError"
        assert payload["message"] == "Test error"
        assert payload["details"] == {"key": "value"}

    def test_inheritance(self) -> None:
        """ContextError is an ordinary Exception subclass."""
        assert isinstance(ContextError("Test"), Exception)
|
||||
|
||||
|
||||
class TestBudgetExceededError:
    """Behaviour of BudgetExceededError."""

    def test_default_message(self) -> None:
        """The default message mentions the budget being exceeded."""
        assert "exceeded" in BudgetExceededError().message.lower()

    def test_with_budget_info(self) -> None:
        """Budget numbers are stored and the overage is derived from them."""
        err = BudgetExceededError(
            allocated=1000,
            requested=1500,
            context_type="knowledge",
        )

        assert err.allocated == 1000
        assert err.requested == 1500
        assert err.context_type == "knowledge"
        # overage = requested - allocated
        assert err.details["overage"] == 500

    def test_to_dict_includes_budget_info(self) -> None:
        """Serialization carries allocated/requested/overage."""
        payload = BudgetExceededError(
            allocated=1000,
            requested=1500,
        ).to_dict()

        details = payload["details"]
        assert details["allocated"] == 1000
        assert details["requested"] == 1500
        assert details["overage"] == 500
|
||||
|
||||
|
||||
class TestTokenCountError:
    """Behaviour of TokenCountError."""

    def test_basic_error(self) -> None:
        """The default message mentions tokens."""
        assert "token" in TokenCountError().message.lower()

    def test_with_model_info(self) -> None:
        """Model and text-length details are recorded on the error."""
        err = TokenCountError(
            message="Failed to count",
            model="claude-3-sonnet",
            text_length=5000,
        )

        assert err.model == "claude-3-sonnet"
        assert err.text_length == 5000
        assert err.details["model"] == "claude-3-sonnet"
|
||||
|
||||
|
||||
class TestCompressionError:
    """Behaviour of CompressionError."""

    def test_basic_error(self) -> None:
        """The default message mentions compression."""
        assert "compress" in CompressionError().message.lower()

    def test_with_token_info(self) -> None:
        """Original/target/achieved token counts are recorded."""
        err = CompressionError(
            original_tokens=2000,
            target_tokens=1000,
            achieved_tokens=1500,
        )

        assert err.original_tokens == 2000
        assert err.target_tokens == 1000
        assert err.achieved_tokens == 1500
|
||||
|
||||
|
||||
class TestAssemblyTimeoutError:
    """Behaviour of AssemblyTimeoutError."""

    def test_basic_error(self) -> None:
        """The default message says the assembly timed out."""
        assert "timed out" in AssemblyTimeoutError().message.lower()

    def test_with_timing_info(self) -> None:
        """Timeout, elapsed time and the failing stage are recorded."""
        err = AssemblyTimeoutError(
            timeout_ms=100,
            elapsed_ms=150.5,
            stage="scoring",
        )

        assert err.timeout_ms == 100
        assert err.elapsed_ms == 150.5
        assert err.stage == "scoring"
        assert err.details["stage"] == "scoring"
|
||||
|
||||
|
||||
class TestScoringError:
    """Behaviour of ScoringError."""

    def test_basic_error(self) -> None:
        """The default message mentions scoring."""
        assert "score" in ScoringError().message.lower()

    def test_with_scorer_info(self) -> None:
        """Scorer type and context id are recorded on the error."""
        err = ScoringError(
            scorer_type="relevance",
            context_id="ctx-123",
        )

        assert err.scorer_type == "relevance"
        assert err.context_id == "ctx-123"
|
||||
|
||||
|
||||
class TestFormattingError:
    """Behaviour of FormattingError."""

    def test_basic_error(self) -> None:
        """The default message mentions formatting."""
        assert "format" in FormattingError().message.lower()

    def test_with_model_info(self) -> None:
        """Model name and adapter are recorded on the error."""
        err = FormattingError(
            model="claude-3-opus",
            adapter="ClaudeAdapter",
        )

        assert err.model == "claude-3-opus"
        assert err.adapter == "ClaudeAdapter"
|
||||
|
||||
|
||||
class TestCacheError:
    """Behaviour of CacheError."""

    def test_basic_error(self) -> None:
        """The default message mentions the cache."""
        assert "cache" in CacheError().message.lower()

    def test_with_operation_info(self) -> None:
        """The failing operation and cache key are recorded."""
        err = CacheError(
            operation="get",
            cache_key="ctx:abc123",
        )

        assert err.operation == "get"
        assert err.cache_key == "ctx:abc123"
|
||||
|
||||
|
||||
class TestContextNotFoundError:
    """Behaviour of ContextNotFoundError."""

    def test_basic_error(self) -> None:
        """The default message says nothing was found."""
        assert "not found" in ContextNotFoundError().message.lower()

    def test_with_source_info(self) -> None:
        """The searched source and query are recorded."""
        err = ContextNotFoundError(
            source="knowledge-base",
            query="authentication flow",
        )

        assert err.source == "knowledge-base"
        assert err.query == "authentication flow"
|
||||
|
||||
|
||||
class TestInvalidContextError:
    """Behaviour of InvalidContextError."""

    def test_basic_error(self) -> None:
        """The default message mentions invalidity."""
        assert "invalid" in InvalidContextError().message.lower()

    def test_with_field_info(self) -> None:
        """Field, value and reason are recorded on the error."""
        err = InvalidContextError(
            field="content",
            value="",
            reason="Content cannot be empty",
        )

        assert err.field == "content"
        assert err.value == ""
        assert err.reason == "Content cannot be empty"

    def test_value_type_only_in_details(self) -> None:
        """Details expose only the value's type, never the raw value."""
        err = InvalidContextError(
            field="api_key",
            value="secret-key-here",
        )

        # The secret itself must never leak into serializable details.
        assert "secret-key-here" not in str(err.details)
        assert err.details["value_type"] == "str"
|
||||
499
backend/tests/services/context/test_ranker.py
Normal file
499
backend/tests/services/context/test_ranker.py
Normal file
@@ -0,0 +1,499 @@
|
||||
"""Tests for context ranking module."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.context.budget import BudgetAllocator, TokenBudget
|
||||
from app.services.context.prioritization import ContextRanker, RankingResult
|
||||
from app.services.context.scoring import CompositeScorer, ScoredContext
|
||||
from app.services.context.types import (
|
||||
ContextPriority,
|
||||
ContextType,
|
||||
ConversationContext,
|
||||
KnowledgeContext,
|
||||
MessageRole,
|
||||
SystemContext,
|
||||
TaskContext,
|
||||
)
|
||||
|
||||
|
||||
class TestRankingResult:
    """Behaviour of the RankingResult dataclass."""

    def test_creation(self) -> None:
        """A RankingResult stores its selections and token total."""
        scored = ScoredContext(
            context=TaskContext(content="Test", source="task"),
            composite_score=0.8,
        )

        ranking = RankingResult(
            selected=[scored],
            excluded=[],
            total_tokens=100,
            selection_stats={"total": 1},
        )

        assert len(ranking.selected) == 1
        assert ranking.total_tokens == 100

    def test_selected_contexts_property(self) -> None:
        """selected_contexts unwraps the raw contexts from their scores."""
        task_a = TaskContext(content="Test 1", source="task")
        task_b = TaskContext(content="Test 2", source="task")

        ranking = RankingResult(
            selected=[
                ScoredContext(context=task_a, composite_score=0.8),
                ScoredContext(context=task_b, composite_score=0.6),
            ],
            excluded=[],
            total_tokens=200,
        )

        unwrapped = ranking.selected_contexts
        assert len(unwrapped) == 2
        assert task_a in unwrapped
        assert task_b in unwrapped
|
||||
|
||||
|
||||
class TestContextRanker:
    """Behaviour of ContextRanker."""

    @staticmethod
    def _knowledge(content: str, score: float, source: str = "docs") -> KnowledgeContext:
        """Build a KnowledgeContext with the given relevance score."""
        return KnowledgeContext(
            content=content,
            source=source,
            relevance_score=score,
        )

    def test_creation(self) -> None:
        """A default ranker wires up a scorer and a token calculator."""
        ranker = ContextRanker()

        assert ranker._scorer is not None
        assert ranker._calculator is not None

    def test_creation_with_scorer(self) -> None:
        """A custom scorer instance is used verbatim."""
        custom_scorer = CompositeScorer(relevance_weight=0.8)

        assert ContextRanker(scorer=custom_scorer)._scorer is custom_scorer

    @pytest.mark.asyncio
    async def test_rank_empty_contexts(self) -> None:
        """Ranking an empty list selects and excludes nothing."""
        ranker = ContextRanker()
        budget = BudgetAllocator().create_budget(10000)

        ranking = await ranker.rank([], "query", budget)

        assert not ranking.selected
        assert not ranking.excluded
        assert ranking.total_tokens == 0

    @pytest.mark.asyncio
    async def test_rank_single_context_fits(self) -> None:
        """A single small context within budget is selected outright."""
        ranker = ContextRanker()
        budget = BudgetAllocator().create_budget(10000)

        knowledge = self._knowledge("Short content", 0.8)

        ranking = await ranker.rank([knowledge], "query", budget)

        assert len(ranking.selected) == 1
        assert not ranking.excluded
        assert ranking.selected[0].context is knowledge

    @pytest.mark.asyncio
    async def test_rank_respects_budget(self) -> None:
        """Contexts overflowing the knowledge allocation get excluded."""
        ranker = ContextRanker()

        # Tiny budget: only 50 tokens available for knowledge.
        budget = TokenBudget(
            total=100,
            knowledge=50,
        )

        # Each context is ~50 tokens, so all three cannot fit.
        contexts = [
            self._knowledge("A" * 200, 0.9),
            self._knowledge("B" * 200, 0.8),
            self._knowledge("C" * 200, 0.7),
        ]

        ranking = await ranker.rank(contexts, "query", budget)

        assert len(ranking.selected) < len(contexts)
        assert len(ranking.excluded) > 0

    @pytest.mark.asyncio
    async def test_rank_selects_highest_scores(self) -> None:
        """Selected contexts come out ordered by descending score."""
        ranker = ContextRanker()
        budget = BudgetAllocator().create_budget(1000)
        budget.knowledge = 100  # squeeze the knowledge allocation

        contexts = [
            self._knowledge("Low score", 0.2),
            self._knowledge("High score", 0.9),
            self._knowledge("Medium score", 0.5),
        ]

        ranking = await ranker.rank(contexts, "query", budget)

        if ranking.selected:
            # Highest-scored entries must appear first.
            scores = [entry.composite_score for entry in ranking.selected]
            assert scores == sorted(scores, reverse=True)

    @pytest.mark.asyncio
    async def test_rank_critical_priority_always_included(self) -> None:
        """CRITICAL contexts survive even a budget too small for them."""
        ranker = ContextRanker()

        budget = TokenBudget(
            total=100,
            system=10,  # deliberately far too small
            knowledge=10,
        )

        contexts = [
            SystemContext(
                content="Critical system prompt that must be included",
                source="system",
                priority=ContextPriority.CRITICAL.value,
            ),
            self._knowledge("Optional knowledge", 0.9),
        ]

        ranking = await ranker.rank(contexts, "query", budget, ensure_required=True)

        priorities = [entry.context.priority for entry in ranking.selected]
        assert ContextPriority.CRITICAL.value in priorities

    @pytest.mark.asyncio
    async def test_rank_without_ensure_required(self) -> None:
        """With ensure_required=False even CRITICAL contexts may be dropped."""
        ranker = ContextRanker()

        budget = TokenBudget(
            total=100,
            system=50,
            knowledge=50,
        )

        contexts = [
            SystemContext(
                content="A" * 500,  # too large for the system allocation
                source="system",
                priority=ContextPriority.CRITICAL.value,
            ),
            self._knowledge("Short", 0.9),
        ]

        ranking = await ranker.rank(contexts, "query", budget, ensure_required=False)

        # Every input ends up either selected or excluded.
        assert len(ranking.selected) + len(ranking.excluded) == len(contexts)

    @pytest.mark.asyncio
    async def test_rank_selection_stats(self) -> None:
        """Ranking reports counts, token totals and a per-type breakdown."""
        ranker = ContextRanker()
        budget = BudgetAllocator().create_budget(10000)

        contexts = [
            self._knowledge("Knowledge 1", 0.8),
            self._knowledge("Knowledge 2", 0.6),
            TaskContext(content="Task", source="task"),
        ]

        ranking = await ranker.rank(contexts, "query", budget)

        for stat_key in (
            "total_contexts",
            "selected_count",
            "excluded_count",
            "total_tokens",
            "by_type",
        ):
            assert stat_key in ranking.selection_stats

    @pytest.mark.asyncio
    async def test_rank_simple(self) -> None:
        """rank_simple returns a non-empty selection for fitting contexts."""
        ranker = ContextRanker()

        contexts = [
            self._knowledge("A", 0.9),
            self._knowledge("B", 0.7),
            self._knowledge("C", 0.5),
        ]

        ranked = await ranker.rank_simple(contexts, "query", max_tokens=1000)

        assert len(ranked) > 0

    @pytest.mark.asyncio
    async def test_rank_simple_respects_max_tokens(self) -> None:
        """rank_simple never returns more than the token limit allows."""
        ranker = ContextRanker()

        # Each context is ~25 tokens of repeated characters.
        contexts = [
            self._knowledge("A" * 100, 0.9),
            self._knowledge("B" * 100, 0.8),
            self._knowledge("C" * 100, 0.7),
        ]

        # A very small limit forces truncation of the selection.
        ranked = await ranker.rank_simple(contexts, "query", max_tokens=30)

        assert len(ranked) <= len(contexts)

    @pytest.mark.asyncio
    async def test_rank_simple_empty(self) -> None:
        """rank_simple on an empty list yields an empty list."""
        assert await ContextRanker().rank_simple([], "query", max_tokens=1000) == []

    @pytest.mark.asyncio
    async def test_rerank_for_diversity(self) -> None:
        """Diversity reranking caps how many leading items one source gets."""
        ranker = ContextRanker()

        # Five scored contexts, all from one source, with descending scores.
        scored = [
            ScoredContext(
                context=self._knowledge(
                    f"Content {i}", 0.9 - i * 0.1, source="same-source"
                ),
                composite_score=0.9 - i * 0.1,
            )
            for i in range(5)
        ]

        reranked = await ranker.rerank_for_diversity(scored, max_per_source=2)

        # Nothing is dropped; items past the cap are merely deferred.
        assert len(reranked) == 5
        leading_sources = [entry.context.source for entry in reranked[:2]]
        assert all(src == "same-source" for src in leading_sources)

    @pytest.mark.asyncio
    async def test_rerank_for_diversity_multiple_sources(self) -> None:
        """No source exceeds its cap within the leading results."""
        ranker = ContextRanker()

        def scored(content: str, source: str, score: float) -> ScoredContext:
            return ScoredContext(
                context=self._knowledge(content, score, source=source),
                composite_score=score,
            )

        entries = [
            scored("Source A - 1", "source-a", 0.9),
            scored("Source A - 2", "source-a", 0.8),
            scored("Source B - 1", "source-b", 0.7),
            scored("Source A - 3", "source-a", 0.6),
        ]

        reranked = await ranker.rerank_for_diversity(entries, max_per_source=2)

        # At most two source-a entries may appear in the first three slots.
        from_a = sum(1 for entry in reranked[:3] if entry.context.source == "source-a")
        assert from_a <= 2

    @pytest.mark.asyncio
    async def test_token_counts_set(self) -> None:
        """Ranking fills in each context's token_count as a side effect."""
        ranker = ContextRanker()
        budget = BudgetAllocator().create_budget(10000)

        knowledge = self._knowledge("Test content", 0.8)

        # Not computed until the ranker has seen the context.
        assert knowledge.token_count is None

        await ranker.rank([knowledge], "query", budget)

        assert knowledge.token_count is not None
        assert knowledge.token_count > 0
|
||||
|
||||
|
||||
class TestContextRankerIntegration:
    """End-to-end ranking scenarios across mixed context types."""

    @pytest.mark.asyncio
    async def test_full_ranking_workflow(self) -> None:
        """Ranking a mixed bag of types keeps the CRITICAL system context."""
        ranker = ContextRanker()
        budget = BudgetAllocator().create_budget(10000)

        # One of each major context type.
        contexts = [
            SystemContext(
                content="You are a helpful assistant.",
                source="system",
                priority=ContextPriority.CRITICAL.value,
            ),
            TaskContext(
                content="Help the user with their coding question.",
                source="task",
                priority=ContextPriority.HIGH.value,
            ),
            KnowledgeContext(
                content="Python is a programming language.",
                source="docs/python.md",
                relevance_score=0.9,
            ),
            KnowledgeContext(
                content="Java is also a programming language.",
                source="docs/java.md",
                relevance_score=0.4,
            ),
            ConversationContext(
                content="Hello, can you help me?",
                source="chat",
                role=MessageRole.USER,
            ),
        ]

        ranking = await ranker.rank(contexts, "Python help", budget)

        # The CRITICAL system context must make the cut.
        chosen_types = [entry.context.get_type() for entry in ranking.selected]
        assert ContextType.SYSTEM in chosen_types

        # Every input is accounted for in the stats.
        assert ranking.selection_stats["total_contexts"] == 5

    @pytest.mark.asyncio
    async def test_ranking_preserves_context_order_by_score(self) -> None:
        """With a roomy budget the output is strictly score-ordered."""
        ranker = ContextRanker()
        budget = BudgetAllocator().create_budget(100000)

        contexts = [
            KnowledgeContext(
                content="Low",
                source="docs",
                relevance_score=0.2,
            ),
            KnowledgeContext(
                content="High",
                source="docs",
                relevance_score=0.9,
            ),
            KnowledgeContext(
                content="Medium",
                source="docs",
                relevance_score=0.5,
            ),
        ]

        ranking = await ranker.rank(contexts, "query", budget)

        scores = [entry.composite_score for entry in ranking.selected]
        assert scores == sorted(scores, reverse=True)
|
||||
893
backend/tests/services/context/test_scoring.py
Normal file
893
backend/tests/services/context/test_scoring.py
Normal file
@@ -0,0 +1,893 @@
|
||||
"""Tests for context scoring module."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.context.scoring import (
|
||||
CompositeScorer,
|
||||
PriorityScorer,
|
||||
RecencyScorer,
|
||||
RelevanceScorer,
|
||||
ScoredContext,
|
||||
)
|
||||
from app.services.context.types import (
|
||||
ContextPriority,
|
||||
ContextType,
|
||||
ConversationContext,
|
||||
KnowledgeContext,
|
||||
MessageRole,
|
||||
SystemContext,
|
||||
TaskContext,
|
||||
)
|
||||
|
||||
|
||||
class TestRelevanceScorer:
    """Behavioral tests for RelevanceScorer."""

    def test_creation(self) -> None:
        """A freshly constructed scorer carries the default weight."""
        assert RelevanceScorer().weight == 1.0

    def test_creation_with_weight(self) -> None:
        """A custom weight passed at construction is preserved."""
        assert RelevanceScorer(weight=0.5).weight == 0.5

    @pytest.mark.asyncio
    async def test_score_with_precomputed_relevance(self) -> None:
        """A pre-computed relevance score on the context is returned as-is."""
        ctx = KnowledgeContext(
            content="Test content about Python",
            source="docs/python.md",
            relevance_score=0.85,
        )
        result = await RelevanceScorer().score(ctx, "Python programming")
        assert result == 0.85

    @pytest.mark.asyncio
    async def test_score_with_metadata_score(self) -> None:
        """A relevance score supplied via metadata takes effect."""
        ctx = SystemContext(
            content="System prompt",
            source="system",
            metadata={"relevance_score": 0.9},
        )
        assert await RelevanceScorer().score(ctx, "anything") == 0.9

    @pytest.mark.asyncio
    async def test_score_fallback_to_keyword_matching(self) -> None:
        """Without any stored score, keyword overlap produces a positive score."""
        scorer = RelevanceScorer(keyword_fallback_weight=0.5)
        ctx = TaskContext(
            content="Implement authentication with JWT tokens",
            source="task",
        )
        # The query shares keywords with the content.
        assert await scorer.score(ctx, "JWT authentication") > 0

    @pytest.mark.asyncio
    async def test_keyword_matching_no_overlap(self) -> None:
        """Zero keyword overlap yields a zero score."""
        ctx = TaskContext(
            content="Implement database migration",
            source="task",
        )
        assert await RelevanceScorer().score(ctx, "xyz abc 123") == 0.0

    @pytest.mark.asyncio
    async def test_keyword_matching_full_overlap(self) -> None:
        """Near-total keyword overlap drives the score above 0.5."""
        scorer = RelevanceScorer(keyword_fallback_weight=1.0)
        ctx = TaskContext(
            content="python programming language",
            source="task",
        )
        assert await scorer.score(ctx, "python programming") > 0.5

    @pytest.mark.asyncio
    async def test_score_with_mcp_success(self) -> None:
        """A successful MCP similarity call supplies the score."""
        fake_result = MagicMock()
        fake_result.success = True
        fake_result.data = {"similarity": 0.75}
        fake_mcp = MagicMock()
        fake_mcp.call_tool = AsyncMock(return_value=fake_result)

        scorer = RelevanceScorer(mcp_manager=fake_mcp)
        ctx = TaskContext(content="Test content", source="task")
        assert await scorer.score(ctx, "test query") == 0.75

    @pytest.mark.asyncio
    async def test_score_with_mcp_failure_fallback(self) -> None:
        """An MCP failure falls back to keyword matching."""
        fake_mcp = MagicMock()
        fake_mcp.call_tool = AsyncMock(side_effect=Exception("Connection failed"))

        scorer = RelevanceScorer(mcp_manager=fake_mcp, keyword_fallback_weight=0.5)
        ctx = TaskContext(content="Python programming code", source="task")
        # Keyword overlap with the query keeps the fallback score positive.
        assert await scorer.score(ctx, "Python code") > 0

    @pytest.mark.asyncio
    async def test_score_batch(self) -> None:
        """Batch scoring returns one score per context, in input order."""
        batch = [
            KnowledgeContext(content="Python", source="1", relevance_score=0.8),
            KnowledgeContext(content="Java", source="2", relevance_score=0.6),
            KnowledgeContext(content="Go", source="3", relevance_score=0.9),
        ]
        results = await RelevanceScorer().score_batch(batch, "test")
        assert results == [0.8, 0.6, 0.9]

    def test_set_mcp_manager(self) -> None:
        """set_mcp_manager attaches the manager to the scorer."""
        scorer = RelevanceScorer()
        assert scorer._mcp is None
        manager = MagicMock()
        scorer.set_mcp_manager(manager)
        assert scorer._mcp is manager
|
||||
|
||||
class TestRecencyScorer:
    """Behavioral tests for RecencyScorer's exponential time decay."""

    def test_creation(self) -> None:
        """Defaults: weight 1.0 and a 24-hour half-life."""
        scorer = RecencyScorer()
        assert scorer.weight == 1.0
        assert scorer._half_life_hours == 24.0

    def test_creation_with_custom_half_life(self) -> None:
        """A custom half-life passed at construction is stored."""
        assert RecencyScorer(half_life_hours=12.0)._half_life_hours == 12.0

    @pytest.mark.asyncio
    async def test_score_recent_context(self) -> None:
        """A brand-new context scores essentially 1.0."""
        moment = datetime.now(UTC)
        ctx = TaskContext(content="Recent task", source="task", timestamp=moment)
        value = await RecencyScorer(half_life_hours=24.0).score(
            ctx, "query", reference_time=moment
        )
        assert value > 0.99

    @pytest.mark.asyncio
    async def test_score_at_half_life(self) -> None:
        """A context exactly one half-life old scores ~0.5."""
        moment = datetime.now(UTC)
        ctx = TaskContext(
            content="Day old task",
            source="task",
            timestamp=moment - timedelta(hours=24),
        )
        value = await RecencyScorer(half_life_hours=24.0).score(
            ctx, "query", reference_time=moment
        )
        assert 0.49 <= value <= 0.51

    @pytest.mark.asyncio
    async def test_score_old_context(self) -> None:
        """A week-old context with a 24h half-life decays to near zero."""
        moment = datetime.now(UTC)
        ctx = TaskContext(
            content="Week old task",
            source="task",
            timestamp=moment - timedelta(days=7),
        )
        value = await RecencyScorer(half_life_hours=24.0).score(
            ctx, "query", reference_time=moment
        )
        assert value < 0.01

    @pytest.mark.asyncio
    async def test_type_specific_half_lives(self) -> None:
        """For the same age, conversation decays faster than knowledge."""
        scorer = RecencyScorer()
        moment = datetime.now(UTC)
        an_hour_back = moment - timedelta(hours=1)

        # Conversation half-life defaults to one hour.
        chat = ConversationContext(
            content="Hello",
            source="chat",
            role=MessageRole.USER,
            timestamp=an_hour_back,
        )
        # Knowledge half-life defaults to one week (168 hours).
        doc = KnowledgeContext(
            content="Documentation",
            source="docs",
            timestamp=an_hour_back,
        )

        chat_value = await scorer.score(chat, "query", reference_time=moment)
        doc_value = await scorer.score(doc, "query", reference_time=moment)
        assert chat_value < doc_value

    def test_get_half_life(self) -> None:
        """Per-type default half-lives are exposed via get_half_life."""
        scorer = RecencyScorer()
        assert scorer.get_half_life(ContextType.CONVERSATION) == 1.0
        assert scorer.get_half_life(ContextType.KNOWLEDGE) == 168.0
        assert scorer.get_half_life(ContextType.SYSTEM) == 720.0

    def test_set_half_life(self) -> None:
        """set_half_life overrides the per-type default."""
        scorer = RecencyScorer()
        scorer.set_half_life(ContextType.TASK, 48.0)
        assert scorer.get_half_life(ContextType.TASK) == 48.0

    def test_set_half_life_invalid(self) -> None:
        """Non-positive half-lives are rejected."""
        scorer = RecencyScorer()
        for bad in (0, -1):
            with pytest.raises(ValueError):
                scorer.set_half_life(ContextType.TASK, bad)

    @pytest.mark.asyncio
    async def test_score_batch(self) -> None:
        """Batch scores strictly decrease as contexts get older."""
        moment = datetime.now(UTC)
        batch = [
            TaskContext(content="1", source="t", timestamp=moment),
            TaskContext(
                content="2", source="t", timestamp=moment - timedelta(hours=24)
            ),
            TaskContext(
                content="3", source="t", timestamp=moment - timedelta(hours=48)
            ),
        ]
        values = await RecencyScorer().score_batch(
            batch, "query", reference_time=moment
        )
        assert len(values) == 3
        assert values[0] > values[1] > values[2]
|
||||
|
||||
class TestPriorityScorer:
    """Behavioral tests for PriorityScorer and its per-type bonuses."""

    def test_creation(self) -> None:
        """Default weight is 1.0."""
        assert PriorityScorer().weight == 1.0

    @pytest.mark.asyncio
    async def test_score_critical_priority(self) -> None:
        """CRITICAL priority saturates the score at 1.0."""
        ctx = SystemContext(
            content="Critical system prompt",
            source="system",
            priority=ContextPriority.CRITICAL.value,
        )
        # CRITICAL (100) plus the system type bonus overshoots and is clamped.
        assert await PriorityScorer().score(ctx, "query") == 1.0

    @pytest.mark.asyncio
    async def test_score_normal_priority(self) -> None:
        """NORMAL priority on a task lands around 0.65 (0.5 + task bonus)."""
        ctx = TaskContext(
            content="Normal task",
            source="task",
            priority=ContextPriority.NORMAL.value,
        )
        value = await PriorityScorer().score(ctx, "query")
        assert 0.6 <= value <= 0.7

    @pytest.mark.asyncio
    async def test_score_low_priority(self) -> None:
        """LOW priority knowledge lands around 0.2 (no type bonus)."""
        ctx = KnowledgeContext(
            content="Low priority knowledge",
            source="docs",
            priority=ContextPriority.LOW.value,
        )
        value = await PriorityScorer().score(ctx, "query")
        assert 0.15 <= value <= 0.25

    @pytest.mark.asyncio
    async def test_type_bonuses(self) -> None:
        """With equal base priority, bonuses rank system > task > knowledge."""
        scorer = PriorityScorer()
        sys_ctx = SystemContext(content="System", source="system", priority=50)
        tsk_ctx = TaskContext(content="Task", source="task", priority=50)
        kn_ctx = KnowledgeContext(content="Knowledge", source="docs", priority=50)

        sys_value = await scorer.score(sys_ctx, "query")
        tsk_value = await scorer.score(tsk_ctx, "query")
        kn_value = await scorer.score(kn_ctx, "query")

        # System carries the largest bonus (0.2), task 0.15, knowledge none.
        assert sys_value > tsk_value > kn_value

    def test_get_type_bonus(self) -> None:
        """Default per-type bonuses are exposed via get_type_bonus."""
        scorer = PriorityScorer()
        assert scorer.get_type_bonus(ContextType.SYSTEM) == 0.2
        assert scorer.get_type_bonus(ContextType.TASK) == 0.15
        assert scorer.get_type_bonus(ContextType.KNOWLEDGE) == 0.0

    def test_set_type_bonus(self) -> None:
        """set_type_bonus overrides the default bonus."""
        scorer = PriorityScorer()
        scorer.set_type_bonus(ContextType.KNOWLEDGE, 0.1)
        assert scorer.get_type_bonus(ContextType.KNOWLEDGE) == 0.1

    def test_set_type_bonus_invalid(self) -> None:
        """Bonuses outside [0, 1] are rejected."""
        scorer = PriorityScorer()
        for bad in (1.5, -0.1):
            with pytest.raises(ValueError):
                scorer.set_type_bonus(ContextType.KNOWLEDGE, bad)
|
||||
|
||||
class TestCompositeScorer:
    """Tests for CompositeScorer weights, scoring, ranking, and concurrency."""

    def test_creation(self) -> None:
        """Default weights are relevance=0.5, recency=0.3, priority=0.2."""
        weights = CompositeScorer().weights
        assert weights["relevance"] == 0.5
        assert weights["recency"] == 0.3
        assert weights["priority"] == 0.2

    def test_creation_with_custom_weights(self) -> None:
        """Constructor-supplied weights are stored verbatim."""
        scorer = CompositeScorer(
            relevance_weight=0.6,
            recency_weight=0.2,
            priority_weight=0.2,
        )
        weights = scorer.weights
        assert weights["relevance"] == 0.6
        assert weights["recency"] == 0.2
        assert weights["priority"] == 0.2

    def test_update_weights(self) -> None:
        """update_weights replaces all three weights at once."""
        scorer = CompositeScorer()
        scorer.update_weights(relevance=0.7, recency=0.2, priority=0.1)
        weights = scorer.weights
        assert weights["relevance"] == 0.7
        assert weights["recency"] == 0.2
        assert weights["priority"] == 0.1

    def test_update_weights_partial(self) -> None:
        """Updating a single weight leaves the others untouched."""
        scorer = CompositeScorer()
        recency_before = scorer.weights["recency"]
        scorer.update_weights(relevance=0.8)
        assert scorer.weights["relevance"] == 0.8
        assert scorer.weights["recency"] == recency_before

    @pytest.mark.asyncio
    async def test_score_basic(self) -> None:
        """A composite score always lands in [0, 1]."""
        scorer = CompositeScorer()
        context = KnowledgeContext(
            content="Test content",
            source="docs",
            relevance_score=0.8,
            timestamp=datetime.now(UTC),
            priority=ContextPriority.NORMAL.value,
        )
        score = await scorer.score(context, "test query")
        assert 0.0 <= score <= 1.0

    @pytest.mark.asyncio
    async def test_score_with_details(self) -> None:
        """score_with_details returns a ScoredContext with the full breakdown."""
        scorer = CompositeScorer()
        context = KnowledgeContext(
            content="Test content",
            source="docs",
            relevance_score=0.8,
            timestamp=datetime.now(UTC),
            priority=ContextPriority.HIGH.value,
        )
        scored = await scorer.score_with_details(context, "test query")
        assert isinstance(scored, ScoredContext)
        assert scored.context is context
        assert 0.0 <= scored.composite_score <= 1.0
        assert scored.relevance_score == 0.8
        assert scored.recency_score > 0.9  # timestamp is "now"
        assert scored.priority_score > 0.5  # HIGH priority

    @pytest.mark.asyncio
    async def test_score_not_cached_on_context(self) -> None:
        """Scores are NOT cached on the context.

        Scores are query-dependent, so caching them on the context would
        hand a stale score to the next, different query.
        """
        scorer = CompositeScorer()
        context = KnowledgeContext(
            content="Test",
            source="docs",
            relevance_score=0.5,
        )
        # First scoring must not stash anything on the context ...
        await scorer.score(context, "query")
        # ... so later calls with other queries compute fresh scores.
        score1 = await scorer.score(context, "query 1")
        score2 = await scorer.score(context, "query 2")
        assert 0.0 <= score1 <= 1.0
        assert 0.0 <= score2 <= 1.0

    @pytest.mark.asyncio
    async def test_score_batch(self) -> None:
        """Batch scoring preserves input order and per-context relevance."""
        scorer = CompositeScorer()
        contexts = [
            KnowledgeContext(
                content="High relevance",
                source="docs",
                relevance_score=0.9,
            ),
            KnowledgeContext(
                content="Low relevance",
                source="docs",
                relevance_score=0.2,
            ),
        ]
        scored = await scorer.score_batch(contexts, "query")
        assert len(scored) == 2
        assert scored[0].relevance_score > scored[1].relevance_score

    @pytest.mark.asyncio
    async def test_rank(self) -> None:
        """rank returns contexts sorted by score, highest first."""
        scorer = CompositeScorer()
        contexts = [
            KnowledgeContext(content="Low", source="docs", relevance_score=0.2),
            KnowledgeContext(content="High", source="docs", relevance_score=0.9),
            KnowledgeContext(content="Medium", source="docs", relevance_score=0.5),
        ]
        ranked = await scorer.rank(contexts, "query")
        assert len(ranked) == 3
        assert [sc.relevance_score for sc in ranked] == [0.9, 0.5, 0.2]

    @pytest.mark.asyncio
    async def test_rank_with_limit(self) -> None:
        """rank truncates the result to `limit` entries."""
        scorer = CompositeScorer()
        contexts = [
            KnowledgeContext(content=str(i), source="docs", relevance_score=i / 10)
            for i in range(10)
        ]
        ranked = await scorer.rank(contexts, "query", limit=3)
        assert len(ranked) == 3

    @pytest.mark.asyncio
    async def test_rank_with_min_score(self) -> None:
        """rank drops every context whose composite score is below min_score."""
        scorer = CompositeScorer()
        contexts = [
            KnowledgeContext(content="Low", source="docs", relevance_score=0.1),
            KnowledgeContext(content="High", source="docs", relevance_score=0.9),
        ]
        ranked = await scorer.rank(contexts, "query", min_score=0.5)
        # Every survivor must actually meet the threshold.  (The previous
        # `len(ranked) <= 2` assertion was vacuous: with two inputs it
        # could never fail, so the min_score filter went untested.)
        assert all(sc.composite_score >= 0.5 for sc in ranked)
        assert len(ranked) <= 2

    def test_set_mcp_manager(self) -> None:
        """set_mcp_manager forwards the manager to the relevance scorer."""
        scorer = CompositeScorer()
        mock_mcp = MagicMock()
        scorer.set_mcp_manager(mock_mcp)
        assert scorer._relevance_scorer._mcp is mock_mcp

    @pytest.mark.asyncio
    async def test_concurrent_scoring_same_context_no_race(self) -> None:
        """Concurrent scoring of one context is race-free and deterministic.

        Recency is disabled (weight 0) because it varies with wall-clock
        time between calls; without it, repeated scores must be identical.
        """
        import asyncio

        scorer = CompositeScorer(
            relevance_weight=0.5,
            recency_weight=0.0,  # disable the only time-dependent component
            priority_weight=0.5,
        )
        context = KnowledgeContext(
            content="Test content for race condition test",
            source="docs",
            relevance_score=0.75,
        )

        num_concurrent = 50
        scores = await asyncio.gather(
            *(scorer.score(context, "test query") for _ in range(num_concurrent))
        )

        # Deterministic scoring => every concurrent call agrees.
        assert all(s == scores[0] for s in scores)
        # Note: nothing is cached on the context — scores are query-dependent.

    @pytest.mark.asyncio
    async def test_concurrent_scoring_different_contexts(self) -> None:
        """Parallel scoring of distinct contexts must not cross-contaminate."""
        import asyncio

        scorer = CompositeScorer()
        contexts = [
            KnowledgeContext(
                content=f"Test content {i}",
                source="docs",
                relevance_score=i / 10,
            )
            for i in range(10)
        ]

        scores = await asyncio.gather(
            *(scorer.score(ctx, "test query") for ctx in contexts)
        )

        # Different relevance inputs must yield more than one distinct score.
        assert len(set(scores)) > 1
        # Note: nothing is cached on the context — scores are query-dependent.
|
||||
|
||||
class TestScoredContext:
    """Tests for the ScoredContext dataclass."""

    def test_creation(self) -> None:
        """All constructor fields are retained."""
        inner = TaskContext(content="Test", source="task")
        scored = ScoredContext(
            context=inner,
            composite_score=0.75,
            relevance_score=0.8,
            recency_score=0.7,
            priority_score=0.5,
        )
        assert scored.context is inner
        assert scored.composite_score == 0.75

    def test_comparison_operators(self) -> None:
        """Ordering follows composite_score."""
        lower = ScoredContext(
            context=TaskContext(content="1", source="task"),
            composite_score=0.5,
        )
        higher = ScoredContext(
            context=TaskContext(content="2", source="task"),
            composite_score=0.8,
        )
        assert lower < higher
        assert higher > lower

    def test_sorting(self) -> None:
        """sorted(..., reverse=True) puts the best score first."""
        pool = [
            ScoredContext(
                context=TaskContext(content=label, source="task"),
                composite_score=value,
            )
            for label, value in (("Low", 0.3), ("High", 0.9), ("Medium", 0.6))
        ]
        ordered = sorted(pool, reverse=True)
        assert [sc.composite_score for sc in ordered] == [0.9, 0.6, 0.3]
|
||||
|
||||
class TestBaseScorer:
    """Tests for the BaseScorer contract, exercised via a concrete subclass."""

    def test_weight_property(self) -> None:
        """The weight given at construction is readable."""
        assert RelevanceScorer(weight=0.7).weight == 0.7

    def test_weight_setter_valid(self) -> None:
        """In-range weights can be assigned after construction."""
        scorer = RelevanceScorer()
        scorer.weight = 0.5
        assert scorer.weight == 0.5

    def test_weight_setter_invalid(self) -> None:
        """Weights outside [0, 1] raise ValueError."""
        scorer = RelevanceScorer()
        for bad in (-0.1, 1.5):
            with pytest.raises(ValueError):
                scorer.weight = bad

    def test_normalize_score(self) -> None:
        """normalize_score clamps into [0, 1] and is identity inside it."""
        scorer = RelevanceScorer()
        expectations = {
            0.5: 0.5,   # already in range
            -0.5: 0.0,  # clamped up
            1.5: 1.0,   # clamped down
            0.0: 0.0,   # lower boundary
            1.0: 1.0,   # upper boundary
        }
        for raw, clamped in expectations.items():
            assert scorer.normalize_score(raw) == clamped
|
||||
|
||||
class TestCompositeScorerEdgeCases:
    """Edge cases for CompositeScorer: zero weights, sequential batch, lock management."""

    @pytest.mark.asyncio
    async def test_score_with_zero_weights(self) -> None:
        """All-zero weights yield a 0.0 score instead of dividing by zero."""
        scorer = CompositeScorer(
            relevance_weight=0.0,
            recency_weight=0.0,
            priority_weight=0.0,
        )
        context = KnowledgeContext(
            content="Test content",
            source="docs",
            relevance_score=0.8,
        )
        assert await scorer.score(context, "test query") == 0.0

    @pytest.mark.asyncio
    async def test_score_batch_sequential(self) -> None:
        """parallel=False exercises the sequential batch-scoring path."""
        scorer = CompositeScorer()
        contexts = [
            KnowledgeContext(
                content="Content 1",
                source="docs",
                relevance_score=0.8,
            ),
            KnowledgeContext(
                content="Content 2",
                source="docs",
                relevance_score=0.5,
            ),
        ]
        scored = await scorer.score_batch(contexts, "query", parallel=False)
        assert len(scored) == 2
        assert scored[0].relevance_score == 0.8
        assert scored[1].relevance_score == 0.5

    @pytest.mark.asyncio
    async def test_lock_fast_path_reuse(self) -> None:
        """A second lookup for the same context id returns the cached lock."""
        scorer = CompositeScorer()
        context = KnowledgeContext(
            content="Test",
            source="docs",
            relevance_score=0.5,
        )
        first = await scorer._get_context_lock(context.id)   # creates the lock
        second = await scorer._get_context_lock(context.id)  # fast path: reuse
        assert second is first

    @pytest.mark.asyncio
    async def test_lock_cleanup_when_limit_reached(self) -> None:
        """Expired locks are evicted once the lock table hits its limit."""
        import asyncio

        scorer = CompositeScorer()
        scorer._max_locks = 3
        scorer._lock_ttl = 0.1  # 100ms TTL

        context_ids = [f"ctx-{i}" for i in range(5)]
        # Fill the lock table up to its limit.
        for ctx_id in context_ids[:3]:
            await scorer._get_context_lock(ctx_id)

        # Wait out the TTL without stalling the event loop.  The original
        # used blocking time.sleep() inside an async test, which freezes
        # the loop; asyncio.sleep() is the correct await here.
        await asyncio.sleep(0.15)

        # Requesting a lock for a new context should trigger cleanup.
        await scorer._get_context_lock(context_ids[3])

        # Some stale locks should be gone; the exact count is cleanup-defined.
        assert len(scorer._context_locks) <= scorer._max_locks + 1

    @pytest.mark.asyncio
    async def test_lock_cleanup_preserves_held_locks(self) -> None:
        """Cleanup never evicts a lock that is currently held."""
        import asyncio

        scorer = CompositeScorer()
        scorer._max_locks = 2
        scorer._lock_ttl = 0.05  # 50ms TTL

        held = await scorer._get_context_lock("ctx-1")
        async with held:
            # While holding the lock, add more entries and let the TTL lapse.
            # (asyncio.sleep instead of time.sleep: a blocking sleep while
            # holding an asyncio.Lock would stall the whole event loop.)
            await scorer._get_context_lock("ctx-2")
            await asyncio.sleep(0.1)
            # Adding a third lock forces a cleanup pass.
            await scorer._get_context_lock("ctx-3")

        # The held lock must have survived the cleanup.
        assert any(lock is held for lock, _ in scorer._context_locks.values())

    @pytest.mark.asyncio
    async def test_concurrent_lock_acquisition_double_check(self) -> None:
        """Concurrent acquisition for one id yields a single shared lock."""
        import asyncio

        scorer = CompositeScorer()
        context_id = "test-context-id"

        async def acquire():
            return await scorer._get_context_lock(context_id)

        locks = await asyncio.gather(*[acquire() for _ in range(10)])

        # The double-check pattern guarantees one lock object per id.
        assert all(lock is locks[0] for lock in locks)
||||
570
backend/tests/services/context/test_types.py
Normal file
570
backend/tests/services/context/test_types.py
Normal file
@@ -0,0 +1,570 @@
|
||||
"""Tests for context types."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.context.types import (
|
||||
AssembledContext,
|
||||
ContextPriority,
|
||||
ContextType,
|
||||
ConversationContext,
|
||||
KnowledgeContext,
|
||||
MessageRole,
|
||||
SystemContext,
|
||||
TaskContext,
|
||||
TaskStatus,
|
||||
ToolContext,
|
||||
ToolResultStatus,
|
||||
)
|
||||
|
||||
|
||||
class TestContextType:
    """Tests for the ContextType enum."""

    def test_all_types_exist(self) -> None:
        """Every expected member is defined and truthy."""
        for member in (
            ContextType.SYSTEM,
            ContextType.TASK,
            ContextType.KNOWLEDGE,
            ContextType.CONVERSATION,
            ContextType.TOOL,
        ):
            assert member

    def test_from_string_valid(self) -> None:
        """from_string accepts names in any letter case."""
        assert ContextType.from_string("system") == ContextType.SYSTEM
        assert ContextType.from_string("KNOWLEDGE") == ContextType.KNOWLEDGE
        assert ContextType.from_string("Task") == ContextType.TASK

    def test_from_string_invalid(self) -> None:
        """Unknown names raise ValueError with a descriptive message."""
        with pytest.raises(ValueError) as exc_info:
            ContextType.from_string("invalid")
        assert "Invalid context type" in str(exc_info.value)
|
||||
|
||||
class TestContextPriority:
    """Tests for the ContextPriority enum."""

    def test_priority_ordering(self) -> None:
        """Numeric values increase strictly from LOWEST to CRITICAL."""
        ladder = [
            ContextPriority.LOWEST,
            ContextPriority.LOW,
            ContextPriority.NORMAL,
            ContextPriority.HIGH,
            ContextPriority.HIGHEST,
            ContextPriority.CRITICAL,
        ]
        values = [level.value for level in ladder]
        # Sorted with no duplicates == strictly increasing.
        assert values == sorted(values)
        assert len(set(values)) == len(values)

    def test_from_int(self) -> None:
        """Exact boundary values map to their named priorities."""
        assert ContextPriority.from_int(0) == ContextPriority.LOWEST
        assert ContextPriority.from_int(50) == ContextPriority.NORMAL
        assert ContextPriority.from_int(100) == ContextPriority.HIGHEST
        assert ContextPriority.from_int(200) == ContextPriority.CRITICAL

    def test_from_int_intermediate(self) -> None:
        """Intermediate values snap to the closest lower priority."""
        assert ContextPriority.from_int(30) == ContextPriority.LOW
        assert ContextPriority.from_int(60) == ContextPriority.NORMAL
|
||||
|
||||
class TestSystemContext:
    """Tests for SystemContext and its factory methods."""

    def test_creation(self) -> None:
        """Content, source, and type are set on construction."""
        ctx = SystemContext(
            content="You are a helpful assistant.",
            source="system_prompt",
        )
        assert ctx.content == "You are a helpful assistant."
        assert ctx.source == "system_prompt"
        assert ctx.get_type() == ContextType.SYSTEM

    def test_default_high_priority(self) -> None:
        """System contexts default to HIGH priority."""
        ctx = SystemContext(content="Test", source="test")
        assert ctx.priority == ContextPriority.HIGH.value

    def test_create_persona(self) -> None:
        """create_persona folds every persona field into the content."""
        ctx = SystemContext.create_persona(
            name="Code Assistant",
            description="A helpful coding assistant.",
            capabilities=["Write code", "Debug"],
            constraints=["Never expose secrets"],
        )
        for fragment in (
            "Code Assistant",
            "helpful coding assistant",
            "Write code",
            "Never expose secrets",
        ):
            assert fragment in ctx.content
        assert ctx.priority == ContextPriority.HIGHEST.value

    def test_create_instructions(self) -> None:
        """create_instructions includes every rule in the content."""
        ctx = SystemContext.create_instructions(
            ["Always be helpful", "Be concise"],
            source="rules",
        )
        assert "Always be helpful" in ctx.content
        assert "Be concise" in ctx.content

    def test_to_dict(self) -> None:
        """to_dict serializes role, instructions_type, and the type tag."""
        ctx = SystemContext(
            content="Test",
            source="test",
            role="assistant",
            instructions_type="general",
        )
        payload = ctx.to_dict()
        assert payload["role"] == "assistant"
        assert payload["instructions_type"] == "general"
        assert payload["type"] == "system"
|
||||
|
||||
class TestKnowledgeContext:
    """Tests for KnowledgeContext."""

    def test_creation(self) -> None:
        """Direct construction keeps content, source, collection, and type."""
        chunk = KnowledgeContext(
            content="def authenticate(user): ...",
            source="/src/auth.py",
            collection="code",
            file_type="python",
        )
        assert chunk.content == "def authenticate(user): ..."
        assert chunk.source == "/src/auth.py"
        assert chunk.collection == "code"
        assert chunk.get_type() == ContextType.KNOWLEDGE

    def test_from_search_result(self) -> None:
        """from_search_result maps result fields onto the context."""
        hit = {
            "content": "Test content",
            "source_path": "/test/file.py",
            "collection": "code",
            "file_type": "python",
            "chunk_index": 2,
            "score": 0.85,
            "id": "chunk-123",
        }
        chunk = KnowledgeContext.from_search_result(hit, "test query")
        assert chunk.content == "Test content"
        assert chunk.source == "/test/file.py"
        assert chunk.relevance_score == 0.85
        assert chunk.search_query == "test query"

    def test_from_search_results(self) -> None:
        """from_search_results converts every hit in order."""
        hits = [
            {"content": "Content 1", "source_path": "/a.py", "score": 0.9},
            {"content": "Content 2", "source_path": "/b.py", "score": 0.8},
        ]
        chunks = KnowledgeContext.from_search_results(hits, "query")
        assert len(chunks) == 2
        assert chunks[0].relevance_score == 0.9
        assert chunks[1].source == "/b.py"

    def test_is_code(self) -> None:
        """is_code is True for code file types only."""
        as_code = KnowledgeContext(content="code", source="test", file_type="python")
        as_docs = KnowledgeContext(content="docs", source="test", file_type="markdown")
        assert as_code.is_code() is True
        assert as_docs.is_code() is False

    def test_is_documentation(self) -> None:
        """is_documentation is True for documentation file types only."""
        as_docs = KnowledgeContext(content="docs", source="test", file_type="markdown")
        as_code = KnowledgeContext(content="code", source="test", file_type="python")
        assert as_docs.is_documentation() is True
        assert as_code.is_documentation() is False
|
||||
|
||||
|
||||
class TestConversationContext:
    """Tests for ConversationContext."""

    def test_creation(self) -> None:
        """Direct construction keeps content, role, and type."""
        message = ConversationContext(
            content="Hello, how can I help?",
            source="conversation",
            role=MessageRole.ASSISTANT,
            turn_index=1,
        )
        assert message.content == "Hello, how can I help?"
        assert message.role == MessageRole.ASSISTANT
        assert message.get_type() == ContextType.CONVERSATION

    def test_from_message(self) -> None:
        """from_message builds a context from a single message."""
        message = ConversationContext.from_message(
            content="What is Python?",
            role="user",
            turn_index=0,
        )
        assert message.content == "What is Python?"
        assert message.role == MessageRole.USER
        assert message.turn_index == 0

    def test_from_history(self) -> None:
        """from_history converts each message and numbers the turns."""
        history = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
            {"role": "user", "content": "Help me"},
        ]
        converted = ConversationContext.from_history(history)
        assert len(converted) == 3
        assert converted[0].role == MessageRole.USER
        assert converted[1].role == MessageRole.ASSISTANT
        assert converted[2].turn_index == 2

    def test_is_user_message(self) -> None:
        """is_user_message is True only for USER-role contexts."""
        from_user = ConversationContext(
            content="test", source="test", role=MessageRole.USER
        )
        from_assistant = ConversationContext(
            content="test", source="test", role=MessageRole.ASSISTANT
        )
        assert from_user.is_user_message() is True
        assert from_assistant.is_user_message() is False

    def test_format_for_prompt(self) -> None:
        """format_for_prompt prefixes the speaker label."""
        rendered = ConversationContext.from_message(
            content="What is 2+2?",
            role="user",
        ).format_for_prompt()
        assert "User:" in rendered
        assert "What is 2+2?" in rendered
|
||||
|
||||
|
||||
class TestTaskContext:
    """Tests for TaskContext."""

    def test_creation(self) -> None:
        """Direct construction keeps content, title, and type."""
        task = TaskContext(
            content="Implement login feature",
            source="task",
            title="Login Feature",
        )
        assert task.content == "Implement login feature"
        assert task.title == "Login Feature"
        assert task.get_type() == ContextType.TASK

    def test_default_normal_priority(self) -> None:
        """TaskContext inherits NORMAL priority from BaseContext.

        Use TaskContext.create() for the default HIGH priority behavior.
        """
        task = TaskContext(content="Test", source="test")
        assert task.priority == ContextPriority.NORMAL.value

    def test_explicit_high_priority(self) -> None:
        """An explicit HIGH priority is honored."""
        task = TaskContext(
            content="Test",
            source="test",
            priority=ContextPriority.HIGH.value,
        )
        assert task.priority == ContextPriority.HIGH.value

    def test_create_factory(self) -> None:
        """create populates title, content, criteria, constraints, status."""
        task = TaskContext.create(
            title="Add Auth",
            description="Implement authentication",
            acceptance_criteria=["Tests pass", "Code reviewed"],
            constraints=["Use JWT"],
            issue_id="123",
        )
        assert task.title == "Add Auth"
        assert task.content == "Implement authentication"
        assert len(task.acceptance_criteria) == 2
        assert "Use JWT" in task.constraints
        assert task.status == TaskStatus.IN_PROGRESS

    def test_format_for_prompt(self) -> None:
        """format_for_prompt includes the title, description, and criteria."""
        rendered = TaskContext.create(
            title="Test Task",
            description="Do something",
            acceptance_criteria=["Works correctly"],
        ).format_for_prompt()
        for fragment in ("Task: Test Task", "Do something", "Works correctly"):
            assert fragment in rendered

    def test_status_checks(self) -> None:
        """Status predicates reflect the stored TaskStatus."""

        def make(status: TaskStatus) -> TaskContext:
            # One-line factory so each assertion reads as a single case.
            return TaskContext(content="test", source="test", status=status)

        assert make(TaskStatus.PENDING).is_active() is True
        assert make(TaskStatus.COMPLETED).is_complete() is True
        assert make(TaskStatus.BLOCKED).is_blocked() is True
|
||||
|
||||
|
||||
class TestToolContext:
    """Tests for ToolContext."""

    def test_creation(self) -> None:
        """Direct construction keeps tool_name and type."""
        tool = ToolContext(
            content="Tool result here",
            source="tool:search",
            tool_name="search",
        )
        assert tool.tool_name == "search"
        assert tool.get_type() == ContextType.TOOL

    def test_from_tool_definition(self) -> None:
        """from_tool_definition builds a non-result context for a tool."""
        definition = ToolContext.from_tool_definition(
            name="search_knowledge",
            description="Search the knowledge base",
            parameters={
                "query": {"type": "string", "required": True},
                "limit": {"type": "integer", "required": False},
            },
            server_name="knowledge-base",
        )
        assert definition.tool_name == "search_knowledge"
        assert "Search the knowledge base" in definition.content
        assert definition.is_result is False
        assert definition.server_name == "knowledge-base"

    def test_from_tool_result(self) -> None:
        """from_tool_result builds a result context carrying the status."""
        outcome = ToolContext.from_tool_result(
            tool_name="search",
            result={"found": 5, "items": ["a", "b"]},
            status=ToolResultStatus.SUCCESS,
            execution_time_ms=150.5,
        )
        assert outcome.tool_name == "search"
        assert outcome.is_result is True
        assert outcome.result_status == ToolResultStatus.SUCCESS
        assert "found" in outcome.content

    def test_is_successful(self) -> None:
        """is_successful mirrors the result status."""
        ok = ToolContext.from_tool_result("test", "ok", ToolResultStatus.SUCCESS)
        failed = ToolContext.from_tool_result("test", "error", ToolResultStatus.ERROR)
        assert ok.is_successful() is True
        assert failed.is_successful() is False

    def test_format_for_prompt(self) -> None:
        """format_for_prompt names the tool and its status."""
        rendered = ToolContext.from_tool_result(
            "search",
            "Found 3 results",
            ToolResultStatus.SUCCESS,
        ).format_for_prompt()
        for fragment in ("Tool Result", "search", "success"):
            assert fragment in rendered
|
||||
|
||||
|
||||
class TestAssembledContext:
    """Tests for AssembledContext."""

    def test_creation(self) -> None:
        """Direct construction keeps the canonical fields and their aliases."""
        assembled = AssembledContext(
            content="Assembled content here",
            total_tokens=500,
            context_count=5,
        )
        assert assembled.content == "Assembled content here"
        assert assembled.total_tokens == 500
        assert assembled.context_count == 5
        # Backward-compatibility aliases mirror the canonical fields.
        assert assembled.token_count == 500
        assert assembled.contexts_included == 5

    def test_budget_utilization(self) -> None:
        """budget_utilization is used / total."""
        assembled = AssembledContext(
            content="test",
            total_tokens=800,
            context_count=5,
            budget_total=1000,
            budget_used=800,
        )
        assert assembled.budget_utilization == 0.8

    def test_budget_utilization_zero_budget(self) -> None:
        """A zero budget yields 0.0 rather than dividing by zero."""
        empty = AssembledContext(
            content="test",
            total_tokens=0,
            context_count=0,
            budget_total=0,
            budget_used=0,
        )
        assert empty.budget_utilization == 0.0

    def test_to_dict(self) -> None:
        """to_dict serializes fields and rounds the assembly time."""
        payload = AssembledContext(
            content="test",
            total_tokens=100,
            context_count=2,
            assembly_time_ms=50.123,
        ).to_dict()
        assert payload["content"] == "test"
        assert payload["total_tokens"] == 100
        assert payload["context_count"] == 2
        assert payload["assembly_time_ms"] == 50.12  # Rounded

    def test_to_json_and_from_json(self) -> None:
        """A to_json / from_json round-trip preserves every field checked."""
        original = AssembledContext(
            content="Test content",
            total_tokens=100,
            context_count=3,
            excluded_count=2,
            assembly_time_ms=45.5,
            model="claude-3-sonnet",
            budget_total=1000,
            budget_used=100,
            by_type={"system": 20, "knowledge": 80},
            cache_hit=True,
            cache_key="abc123",
        )
        restored = AssembledContext.from_json(original.to_json())
        for attr in (
            "content",
            "total_tokens",
            "context_count",
            "excluded_count",
            "model",
            "cache_hit",
            "cache_key",
        ):
            assert getattr(restored, attr) == getattr(original, attr)
|
||||
|
||||
|
||||
class TestBaseContextMethods:
    """Tests for BaseContext methods."""

    def test_get_age_seconds(self) -> None:
        """get_age_seconds measures elapsed time since the timestamp."""
        two_hours_ago = datetime.now(UTC) - timedelta(hours=2)
        aged = SystemContext(content="test", source="test", timestamp=two_hours_ago)
        # Roughly two hours, with slack for test execution time.
        assert 7100 < aged.get_age_seconds() < 7300

    def test_get_age_hours(self) -> None:
        """get_age_hours reports the age in fractional hours."""
        five_hours_ago = datetime.now(UTC) - timedelta(hours=5)
        aged = SystemContext(content="test", source="test", timestamp=five_hours_ago)
        assert 4.9 < aged.get_age_hours() < 5.1

    def test_is_stale(self) -> None:
        """is_stale flags contexts older than the default 168h (7 days)."""
        stale = SystemContext(
            content="test",
            source="test",
            timestamp=datetime.now(UTC) - timedelta(days=10),
        )
        fresh = SystemContext(
            content="test",
            source="test",
            timestamp=datetime.now(UTC) - timedelta(hours=1),
        )
        assert stale.is_stale() is True
        assert fresh.is_stale() is False

    def test_token_count_property(self) -> None:
        """token_count starts as None and accepts assignment."""
        context = SystemContext(content="test", source="test")
        assert context.token_count is None
        context.token_count = 100
        assert context.token_count == 100

    def test_score_property_clamping(self) -> None:
        """score assignments are clamped into [0.0, 1.0]."""
        context = SystemContext(content="test", source="test")
        context.score = 1.5
        assert context.score == 1.0
        context.score = -0.5
        assert context.score == 0.0
        context.score = 0.75
        assert context.score == 0.75

    def test_hash_and_equality(self) -> None:
        """Equality and hashing are identity-by-ID, so sets deduplicate."""
        first = SystemContext(content="test", source="test")
        second = SystemContext(content="test", source="test")
        clone = SystemContext(content="test", source="test")
        clone.id = first.id  # Force the same identity as `first`

        assert first != second  # distinct IDs are never equal
        assert first == clone  # a shared ID means equal
        assert len({first, second, clone}) == 2  # first and clone collapse

    def test_truncate(self) -> None:
        """truncate shortens the content and appends a marker."""
        body = "word " * 1000  # Deliberately long content
        context = SystemContext(content=body, source="test")
        context.token_count = 1000

        shortened = context.truncate(100)
        assert len(shortened) < len(body)
        assert "[truncated]" in shortened
|
||||
@@ -14,20 +14,19 @@ from app.services.mcp.client_manager import (
|
||||
shutdown_mcp_client,
|
||||
)
|
||||
from app.services.mcp.config import MCPConfig, MCPServerConfig
|
||||
from app.services.mcp.connection import ConnectionState
|
||||
from app.services.mcp.exceptions import MCPServerNotFoundError
|
||||
from app.services.mcp.registry import MCPServerRegistry
|
||||
from app.services.mcp.routing import ToolInfo, ToolResult
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def reset_registry():
|
||||
async def reset_registry():
|
||||
"""Reset the singleton registry before and after each test."""
|
||||
MCPServerRegistry.reset_instance()
|
||||
reset_mcp_client()
|
||||
await reset_mcp_client()
|
||||
yield
|
||||
MCPServerRegistry.reset_instance()
|
||||
reset_mcp_client()
|
||||
await reset_mcp_client()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -389,7 +388,8 @@ class TestModuleLevelFunctions:
|
||||
mock_shutdown.return_value = None
|
||||
await shutdown_mcp_client()
|
||||
|
||||
def test_reset_mcp_client(self, reset_registry):
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_mcp_client(self, reset_registry):
|
||||
"""Test resetting the global client."""
|
||||
reset_mcp_client()
|
||||
await reset_mcp_client()
|
||||
# Should not raise
|
||||
|
||||
@@ -4,10 +4,8 @@ Tests for MCP Configuration System
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from app.services.mcp.config import (
|
||||
MCPConfig,
|
||||
@@ -217,9 +215,7 @@ mcp_servers:
|
||||
default_timeout: 60
|
||||
connection_pool_size: 20
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".yaml", delete=False
|
||||
) as f:
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
|
||||
f.write(yaml_content)
|
||||
f.flush()
|
||||
|
||||
@@ -248,9 +244,7 @@ mcp_servers:
|
||||
explicit-server:
|
||||
url: http://explicit:8000
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".yaml", delete=False
|
||||
) as f:
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
|
||||
f.write(yaml_content)
|
||||
f.flush()
|
||||
|
||||
@@ -267,9 +261,7 @@ mcp_servers:
|
||||
env-server:
|
||||
url: http://env:8000
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".yaml", delete=False
|
||||
) as f:
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
|
||||
f.write(yaml_content)
|
||||
f.flush()
|
||||
|
||||
|
||||
@@ -220,9 +220,7 @@ class TestMCPConnection:
|
||||
MockClient.return_value = mock_client
|
||||
|
||||
await conn.connect()
|
||||
result = await conn.execute_request(
|
||||
"POST", "/mcp", data={"method": "test"}
|
||||
)
|
||||
result = await conn.execute_request("POST", "/mcp", data={"method": "test"})
|
||||
|
||||
assert result == {"result": "success"}
|
||||
|
||||
|
||||
@@ -160,11 +160,21 @@ class TestMCPToolNotFoundError:
|
||||
"""Test tool not found with available tools listed."""
|
||||
error = MCPToolNotFoundError(
|
||||
"unknown-tool",
|
||||
available_tools=["tool-1", "tool-2", "tool-3", "tool-4", "tool-5", "tool-6"],
|
||||
available_tools=[
|
||||
"tool-1",
|
||||
"tool-2",
|
||||
"tool-3",
|
||||
"tool-4",
|
||||
"tool-5",
|
||||
"tool-6",
|
||||
],
|
||||
)
|
||||
assert len(error.available_tools) == 6
|
||||
# Should show first 5 tools with ellipsis
|
||||
assert "available_tools=['tool-1', 'tool-2', 'tool-3', 'tool-4', 'tool-5']..." in str(error)
|
||||
assert (
|
||||
"available_tools=['tool-1', 'tool-2', 'tool-3', 'tool-4', 'tool-5']..."
|
||||
in str(error)
|
||||
)
|
||||
|
||||
|
||||
class TestMCPCircuitOpenError:
|
||||
|
||||
@@ -4,7 +4,7 @@ Tests for MCP Server Registry
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.mcp.config import MCPConfig, MCPServerConfig, TransportType
|
||||
from app.services.mcp.config import MCPConfig, MCPServerConfig
|
||||
from app.services.mcp.exceptions import MCPServerNotFoundError
|
||||
from app.services.mcp.registry import (
|
||||
MCPServerRegistry,
|
||||
|
||||
@@ -2,15 +2,13 @@
|
||||
Tests for MCP Tool Call Routing
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.mcp.config import MCPConfig, MCPServerConfig
|
||||
from app.services.mcp.connection import ConnectionPool
|
||||
from app.services.mcp.exceptions import (
|
||||
MCPCircuitOpenError,
|
||||
MCPToolError,
|
||||
MCPToolNotFoundError,
|
||||
)
|
||||
from app.services.mcp.registry import MCPServerRegistry
|
||||
@@ -79,7 +77,10 @@ class TestToolInfo:
|
||||
name="create_issue",
|
||||
description="Create a new issue",
|
||||
server_name="issues",
|
||||
input_schema={"type": "object", "properties": {"title": {"type": "string"}}},
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {"title": {"type": "string"}},
|
||||
},
|
||||
)
|
||||
assert info.name == "create_issue"
|
||||
assert info.description == "Create a new issue"
|
||||
@@ -174,9 +175,7 @@ class TestToolRouter:
|
||||
|
||||
# Mock the pool connection and request
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.execute_request = AsyncMock(
|
||||
return_value={"result": {"status": "ok"}}
|
||||
)
|
||||
mock_conn.execute_request = AsyncMock(return_value={"result": {"status": "ok"}})
|
||||
mock_conn.is_connected = True
|
||||
|
||||
with patch.object(router._pool, "get_connection", return_value=mock_conn):
|
||||
@@ -232,9 +231,7 @@ class TestToolRouter:
|
||||
await router.register_tool_mapping("tool-on-server-1", "server-1")
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.execute_request = AsyncMock(
|
||||
return_value={"result": "routed"}
|
||||
)
|
||||
mock_conn.execute_request = AsyncMock(return_value={"result": "routed"})
|
||||
mock_conn.is_connected = True
|
||||
|
||||
with patch.object(router._pool, "get_connection", return_value=mock_conn):
|
||||
@@ -339,7 +336,11 @@ class TestToolRouter:
|
||||
delay2 = router._calculate_retry_delay(2, config)
|
||||
delay3 = router._calculate_retry_delay(3, config)
|
||||
|
||||
# Delays should increase with attempts
|
||||
# All delays should be positive
|
||||
assert delay1 > 0
|
||||
# Allow for jitter variation
|
||||
assert delay2 > 0
|
||||
assert delay3 > 0
|
||||
# All delays should be within max bounds (allow for jitter variation)
|
||||
assert delay1 <= config.retry_max_delay * 1.25
|
||||
assert delay2 <= config.retry_max_delay * 1.25
|
||||
assert delay3 <= config.retry_max_delay * 1.25
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user